The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .github
    └── workflows
    │   └── python-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── gigagan-architecture.png
├── gigagan-sample.png
├── gigagan_pytorch
    ├── __init__.py
    ├── attend.py
    ├── data.py
    ├── distributed.py
    ├── gigagan_pytorch.py
    ├── open_clip.py
    ├── optimizer.py
    ├── unet_upsampler.py
    └── version.py
├── pyproject.toml
└── setup.py


/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 | # This workflow will upload a Python Package using Twine when a release is created
 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
 5 | 
 6 | # This workflow uses actions that are not certified by GitHub.
 7 | # They are provided by a third-party and are governed by
 8 | # separate terms of service, privacy policy, and support
 9 | # documentation.
10 | 
11 | name: Upload Python Package
12 | 
13 | on:
14 |   release:
15 |     types: [published]
16 | 
17 | jobs:
18 |   deploy:
19 | 
20 |     runs-on: ubuntu-latest
21 | 
22 |     steps:
23 |     - uses: actions/checkout@v2
24 |     - name: Set up Python
25 |       uses: actions/setup-python@v2
26 |       with:
27 |         python-version: '3.x'
28 |     - name: Install dependencies
29 |       run: |
30 |         python -m pip install --upgrade pip
31 |         pip install build
32 |     - name: Build package
33 |       run: python -m build
34 |     - name: Publish package
35 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 |       with:
37 |         user: __token__
38 |         password: ${{ secrets.PYPI_API_TOKEN }}
39 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | pip-wheel-metadata/
 24 | share/python-wheels/
 25 | *.egg-info/
 26 | .installed.cfg
 27 | *.egg
 28 | MANIFEST
 29 | 
 30 | # PyInstaller
 31 | #  Usually these files are written by a python script from a template
 32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 | 
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 | 
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .nox/
 44 | .coverage
 45 | .coverage.*
 46 | .cache
 47 | nosetests.xml
 48 | coverage.xml
 49 | *.cover
 50 | *.py,cover
 51 | .hypothesis/
 52 | .pytest_cache/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | db.sqlite3-journal
 63 | 
 64 | # Flask stuff:
 65 | instance/
 66 | .webassets-cache
 67 | 
 68 | # Scrapy stuff:
 69 | .scrapy
 70 | 
 71 | # Sphinx documentation
 72 | docs/_build/
 73 | 
 74 | # PyBuilder
 75 | target/
 76 | 
 77 | # Jupyter Notebook
 78 | .ipynb_checkpoints
 79 | 
 80 | # IPython
 81 | profile_default/
 82 | ipython_config.py
 83 | 
 84 | # pyenv
 85 | .python-version
 86 | 
 87 | # pipenv
 88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 91 | #   install all needed dependencies.
 92 | #Pipfile.lock
 93 | 
 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 95 | __pypackages__/
 96 | 
 97 | # Celery stuff
 98 | celerybeat-schedule
 99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | repos:
3 |   - repo: https://github.com/astral-sh/ruff-pre-commit
4 |     rev: v0.0.278
5 |     hooks:
6 |       - id: ruff
7 |         args: [ --fix, --exit-non-zero-on-fix]
8 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Phil Wang
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | <img src="./gigagan-sample.png" width="500px"></img>
  2 | 
  3 | <img src="./gigagan-architecture.png" width="500px"></img>
  4 | 
  5 | ## GigaGAN - Pytorch
  6 | 
  7 | Implementation of <a href="https://arxiv.org/abs/2303.05511v2">GigaGAN</a> <a href="https://mingukkang.github.io/GigaGAN/">(project page)</a>, new SOTA GAN out of Adobe.
  8 | 
  9 | I will also add a few findings from <a href="https://github.com/lucidrains/lightweight-gan">lightweight gan</a>, for faster convergence (skip layer excitation) and better stability (reconstruction auxiliary loss in discriminator)
 10 | 
 11 | It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper.
 12 | 
 13 | Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community
 14 | 
 15 | ## Appreciation
 16 | 
 17 | - <a href="https://stability.ai/">StabilityAI</a> and <a href="https://huggingface.co/">🤗 Huggingface</a> for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
 18 | 
 19 | - <a href="https://huggingface.co/">🤗 Huggingface</a> for their accelerate library
 20 | 
 21 | - All the maintainers at <a href="https://github.com/mlfoundations/open_clip">OpenClip</a>, for their SOTA open sourced contrastive learning text-image models
 22 | 
 23 | - <a href="https://github.com/XavierXiao">Xavier</a> for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built!
 24 | 
 25 | - <a href="https://github.com/CerebralSeed">@CerebralSeed</a> for pull requesting the initial sampling code for both the generator and upsampler!
 26 | 
 27 | - <a href="https://github.com/randintgenr ">Keerth</a> for the code review and pointing out some discrepancies with the paper!
 28 | 
 29 | ## Install
 30 | 
 31 | ```bash
 32 | $ pip install gigagan-pytorch
 33 | ```
 34 | 
 35 | ## Usage
 36 | 
 37 | Simple unconditional GAN, for starters
 38 | 
 39 | ```python
 40 | import torch
 41 | 
 42 | from gigagan_pytorch import (
 43 |     GigaGAN,
 44 |     ImageDataset
 45 | )
 46 | 
 47 | gan = GigaGAN(
 48 |     generator = dict(
 49 |         dim_capacity = 8,
 50 |         style_network = dict(
 51 |             dim = 64,
 52 |             depth = 4
 53 |         ),
 54 |         image_size = 256,
 55 |         dim_max = 512,
 56 |         num_skip_layers_excite = 4,
 57 |         unconditional = True
 58 |     ),
 59 |     discriminator = dict(
 60 |         dim_capacity = 16,
 61 |         dim_max = 512,
 62 |         image_size = 256,
 63 |         num_skip_layers_excite = 4,
 64 |         unconditional = True
 65 |     ),
 66 |     amp = True
 67 | ).cuda()
 68 | 
 69 | # dataset
 70 | 
 71 | dataset = ImageDataset(
 72 |     folder = '/path/to/your/data',
 73 |     image_size = 256
 74 | )
 75 | 
 76 | dataloader = dataset.get_dataloader(batch_size = 1)
 77 | 
 78 | # you must then set the dataloader for the GAN before training
 79 | 
 80 | gan.set_dataloader(dataloader)
 81 | 
 82 | # training the discriminator and generator alternating
 83 | # for 100 steps in this example, batch size 1, gradient accumulated 8 times
 84 | 
 85 | gan(
 86 |     steps = 100,
 87 |     grad_accum_every = 8
 88 | )
 89 | 
 90 | # after much training
 91 | 
 92 | images = gan.generate(batch_size = 4) # (4, 3, 256, 256)
 93 | ```
 94 | 
 95 | For unconditional Unet Upsampler
 96 | 
 97 | ```python
 98 | import torch
 99 | from gigagan_pytorch import (
100 |     GigaGAN,
101 |     ImageDataset
102 | )
103 | 
104 | gan = GigaGAN(
105 |     train_upsampler = True,     # set this to True
106 |     generator = dict(
107 |         style_network = dict(
108 |             dim = 64,
109 |             depth = 4
110 |         ),
111 |         dim = 32,
112 |         image_size = 256,
113 |         input_image_size = 64,
114 |         unconditional = True
115 |     ),
116 |     discriminator = dict(
117 |         dim_capacity = 16,
118 |         dim_max = 512,
119 |         image_size = 256,
120 |         num_skip_layers_excite = 4,
121 |         multiscale_input_resolutions = (128,),
122 |         unconditional = True
123 |     ),
124 |     amp = True
125 | ).cuda()
126 | 
127 | dataset = ImageDataset(
128 |     folder = '/path/to/your/data',
129 |     image_size = 256
130 | )
131 | 
132 | dataloader = dataset.get_dataloader(batch_size = 1)
133 | 
134 | gan.set_dataloader(dataloader)
135 | 
136 | # training the discriminator and generator alternating
137 | # for 100 steps in this example, batch size 1, gradient accumulated 8 times
138 | 
139 | gan(
140 |     steps = 100,
141 |     grad_accum_every = 8
142 | )
143 | 
144 | # after much training
145 | 
146 | lowres = torch.randn(1, 3, 64, 64).cuda()
147 | 
148 | images = gan.generate(lowres) # (1, 3, 256, 256)
149 | ```
150 | 
151 | ## Losses
152 | 
153 | * `G` - Generator
154 | * `MSG` - Multiscale Generator
155 | * `D` - Discriminator
156 | * `MSD` - Multiscale Discriminator
157 | * `GP` - Gradient Penalty
158 | * `SSL` - Auxiliary Reconstruction in Discriminator (from Lightweight GAN)
159 | * `VD` - Vision-aided Discriminator
160 | * `VG` - Vision-aided Generator
161 | * `CL` - Generator Constrastive Loss
162 | * `MAL` - Matching Aware Loss
163 | 
164 | A healthy run would have `G`, `MSG`, `D`, `MSD` with values hovering between `0` to `10`, and usually staying pretty constant. If at any time after 1k training steps these values persist at triple digits, that would mean something is wrong. It is ok for generator and discriminator values to occasionally dip negative, but it should swing back up to the range above.
165 | 
166 | `GP` and `SSL` should be pushed towards `0`. `GP` can occasionally spike; I like to imagine it as the networks undergoing some epiphany
167 | 
168 | ## Multi-GPU Training
169 | 
170 | The `GigaGAN` class is now equipped with <a href="https://huggingface.co/docs/accelerate/en/package_reference/accelerator">🤗 Accelerator</a>. You can easily do multi-gpu training in two steps using their `accelerate` CLI
171 | 
172 | At the project root directory, where the training script is, run
173 | 
174 | ```python
175 | $ accelerate config
176 | ```
177 | 
178 | Then, in the same directory
179 | 
180 | ```python
181 | $ accelerate launch train.py
182 | ```
183 | 
184 | ## Todo
185 | 
186 | - [x] make sure it can be trained unconditionally
187 | - [x] read the relevant papers and knock out all 3 auxiliary losses
188 |     - [x] matching aware loss
189 |     - [x] clip loss
190 |     - [x] vision-aided discriminator loss
191 |     - [x] add reconstruction losses on arbitrary stages in the discriminator (lightweight gan)
192 |     - [x] figure out how the random projections are used from projected-gan
193 |     - [x] vision aided discriminator needs to extract N layers from the vision model in CLIP
194 |     - [x] figure out whether to discard CLS token and reshape into image dimensions for convolution, or stick with attention and condition with adaptive layernorm - also turn off vision aided gan in unconditional case
195 | - [x] unet upsampler
196 |     - [x] add adaptive conv
197 |     - [x] modify latter stage of unet to also output rgb residuals, and pass the rgb into discriminator. make discriminator agnostic to rgb being passed in
198 |     - [x] do pixel shuffle upsamples for unet
199 | - [x] get a code review for the multi-scale inputs and outputs, as the paper was a bit vague
200 | - [x] add upsampling network architecture
201 | - [x] make unconditional work for both base generator and upsampler
202 | - [x] make text conditioned training work for both base and upsampler
203 | - [x] make recon more efficient by random sampling patches
204 | - [x] make sure generator and discriminator can also accept pre-encoded CLIP text encodings
205 | - [x] do a review of the auxiliary losses
206 |     - [x] add contrastive loss for generator
207 |     - [x] add vision aided loss
208 |     - [x] add gradient penalty for vision aided discr - make optional
209 |     - [x] add matching awareness loss - figure out if rotating text conditions by one is good enough for mismatching (without drawing an additional batch from dataloader)
210 |     - [x] make sure gradient accumulation works with matching aware loss
211 |     - [x] matching awareness loss runs and is stable
212 |     - [x] vision aided trains
213 | - [x] add some differentiable augmentations, proven technique from the old GAN days
214 |     - [x] remove any magic being done with automatic rgbs processing, and have it explicitly passed in - offer functions on the discriminator that can process real images into the right multi-scales
215 |     - [x] add horizontal flip for starters
216 | 
217 | - [ ] move all modulation projections into the adaptive conv2d class
218 | - [ ] add accelerate
219 |     - [x] works single machine
220 |     - [x] works for mixed precision (make sure gradient penalty is scaled correctly), take care of manual scaler saving and reloading, borrow from imagen-pytorch
221 |     - [x] make sure it works multi-GPU for one machine
222 |     - [ ] have someone else try multiple machines
223 | 
224 | - [ ] clip should be optional for all modules, and managed by `GigaGAN`, with text -> text embeds processed once
225 | - [ ] add ability to select a random subset from multiscale dimension, for efficiency
226 | 
227 | - [ ] port over CLI from lightweight|stylegan2-pytorch
228 | - [ ] hook up laion dataset for text-image
229 | 
230 | ## Citations
231 | 
232 | ```bibtex
233 | @misc{https://doi.org/10.48550/arxiv.2303.05511,
234 |     url     = {https://arxiv.org/abs/2303.05511},
235 |     author  = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},  
236 |     title   = {Scaling up GANs for Text-to-Image Synthesis},
237 |     publisher = {arXiv},
238 |     year    = {2023},
239 |     copyright = {arXiv.org perpetual, non-exclusive license}
240 | }
241 | ```
242 | 
243 | ```bibtex
244 | @article{Liu2021TowardsFA,
245 |     title   = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
246 |     author  = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
247 |     journal = {ArXiv},
248 |     year    = {2021},
249 |     volume  = {abs/2101.04775}
250 | }
251 | ```
252 | 
253 | ```bibtex
254 | @inproceedings{dao2022flashattention,
255 |     title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
256 |     author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
257 |     booktitle = {Advances in Neural Information Processing Systems},
258 |     year    = {2022}
259 | }
260 | ```
261 | 
262 | ```bibtex
263 | @inproceedings{Karras2020ada,
264 |     title     = {Training Generative Adversarial Networks with Limited Data},
265 |     author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
266 |     booktitle = {Proc. NeurIPS},
267 |     year      = {2020}
268 | }
269 | ```
270 | 
271 | ```bibtex
272 | @article{Xu2024VideoGigaGANTD,
273 |     title   = {VideoGigaGAN: Towards Detail-rich Video Super-Resolution},
274 |     author  = {Yiran Xu and Taesung Park and Richard Zhang and Yang Zhou and Eli Shechtman and Feng Liu and Jia-Bin Huang and Difan Liu},
275 |     journal = {ArXiv},
276 |     year    = {2024},
277 |     volume  = {abs/2404.12388},
278 |     url     ={https://api.semanticscholar.org/CorpusID:269214195}
279 | }
280 | ```
281 | 
282 | ```bibtex
283 | @inproceedings{Huang2025TheGI,
284 |     title   = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
285 |     author  = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
286 |     year    = {2025},
287 |     url     = {https://api.semanticscholar.org/CorpusID:275405495}
288 | }
289 | ```
290 | 


--------------------------------------------------------------------------------
/gigagan-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gigagan-pytorch/0806433f1e8eadbe888162b5c5a5ab625ce4e0a5/gigagan-architecture.png


--------------------------------------------------------------------------------
/gigagan-sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gigagan-pytorch/0806433f1e8eadbe888162b5c5a5ab625ce4e0a5/gigagan-sample.png


--------------------------------------------------------------------------------
/gigagan_pytorch/__init__.py:
--------------------------------------------------------------------------------
 1 | from gigagan_pytorch.gigagan_pytorch import (
 2 |     GigaGAN,
 3 |     Generator,
 4 |     Discriminator,
 5 |     VisionAidedDiscriminator,
 6 |     AdaptiveConv2DMod,
 7 |     StyleNetwork,
 8 |     TextEncoder
 9 | )
10 | 
11 | from gigagan_pytorch.unet_upsampler import UnetUpsampler
12 | 
13 | from gigagan_pytorch.data import (
14 |     ImageDataset,
15 |     TextImageDataset,
16 |     MockTextImageDataset
17 | )
18 | 
19 | __all__ = [
20 |     GigaGAN,
21 |     Generator,
22 |     Discriminator,
23 |     VisionAidedDiscriminator,
24 |     AdaptiveConv2DMod,
25 |     StyleNetwork,
26 |     UnetUpsampler,
27 |     TextEncoder,
28 |     ImageDataset,
29 |     TextImageDataset,
30 |     MockTextImageDataset
31 | ]
32 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/attend.py:
--------------------------------------------------------------------------------
  1 | from functools import wraps
  2 | from packaging import version
  3 | from collections import namedtuple
  4 | 
  5 | import torch
  6 | from torch import nn, einsum
  7 | import torch.nn.functional as F
  8 | 
  9 | 
 10 | # constants
 11 | 
 12 | AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
 13 | 
 14 | # helpers
 15 | 
 16 | def exists(val):
 17 |     return val is not None
 18 | 
 19 | def once(fn):
 20 |     called = False
 21 |     @wraps(fn)
 22 |     def inner(x):
 23 |         nonlocal called
 24 |         if called:
 25 |             return
 26 |         called = True
 27 |         return fn(x)
 28 |     return inner
 29 | 
 30 | print_once = once(print)
 31 | 
 32 | # main class
 33 | 
 34 | class Attend(nn.Module):
 35 |     def __init__(
 36 |         self,
 37 |         dropout = 0.,
 38 |         flash = False
 39 |     ):
 40 |         super().__init__()
 41 |         self.dropout = dropout
 42 |         self.attn_dropout = nn.Dropout(dropout)
 43 | 
 44 |         self.flash = flash
 45 |         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
 46 | 
 47 |         # determine efficient attention configs for cuda and cpu
 48 | 
 49 |         self.cpu_config = AttentionConfig(True, True, True)
 50 |         self.cuda_config = None
 51 | 
 52 |         if not torch.cuda.is_available() or not flash:
 53 |             return
 54 | 
 55 |         device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
 56 | 
 57 |         if device_properties.major == 8 and device_properties.minor == 0:
 58 |             print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
 59 |             self.cuda_config = AttentionConfig(True, False, False)
 60 |         else:
 61 |             print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
 62 |             self.cuda_config = AttentionConfig(False, True, True)
 63 | 
 64 |     def flash_attn(self, q, k, v):
 65 |         is_cuda = q.is_cuda
 66 | 
 67 |         q, k, v = map(lambda t: t.contiguous(), (q, k, v))
 68 | 
 69 |         # Check if there is a compatible device for flash attention
 70 | 
 71 |         config = self.cuda_config if is_cuda else self.cpu_config
 72 | 
 73 |         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 74 | 
 75 |         with torch.backends.cuda.sdp_kernel(**config._asdict()):
 76 |             out = F.scaled_dot_product_attention(
 77 |                 q, k, v,
 78 |                 dropout_p = self.dropout if self.training else 0.
 79 |             )
 80 | 
 81 |         return out
 82 | 
 83 |     def forward(self, q, k, v):
 84 |         """
 85 |         einstein notation
 86 |         b - batch
 87 |         h - heads
 88 |         n, i, j - sequence length (base sequence length, source, target)
 89 |         d - feature dimension
 90 |         """
 91 | 
 92 |         if self.flash:
 93 |             return self.flash_attn(q, k, v)
 94 | 
 95 |         scale = q.shape[-1] ** -0.5
 96 | 
 97 |         # similarity
 98 | 
 99 |         sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
100 | 
101 |         # attention
102 | 
103 |         attn = sim.softmax(dim = -1)
104 |         attn = self.attn_dropout(attn)
105 | 
106 |         # aggregate values
107 | 
108 |         out = einsum("b h i j, b h j d -> b h i d", attn, v)
109 | 
110 |         return out
111 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/data.py:
--------------------------------------------------------------------------------
  1 | from functools import partial
  2 | from pathlib import Path
  3 | 
  4 | import torch
  5 | from torch import nn
  6 | from torch.utils.data import Dataset, DataLoader
  7 | 
  8 | from PIL import Image
  9 | from torchvision import transforms as T
 10 | 
 11 | from beartype.door import is_bearable
 12 | from beartype.typing import Tuple
 13 | 
 14 | # helper functions
 15 | 
 16 | def exists(val):
 17 |     return val is not None
 18 | 
 19 | def convert_image_to_fn(img_type, image):
 20 |     if image.mode == img_type:
 21 |         return image
 22 | 
 23 |     return image.convert(img_type)
 24 | 
 25 | # custom collation function
 26 | # so dataset can return a str and it will collate into List[str]
 27 | 
 28 | def collate_tensors_or_str(data):
 29 |     is_one_data = not isinstance(data[0], tuple)
 30 | 
 31 |     if is_one_data:
 32 |         data = torch.stack(data)
 33 |         return (data,)
 34 | 
 35 |     outputs = []
 36 |     for datum in zip(*data):
 37 |         if is_bearable(datum, Tuple[str, ...]):
 38 |             output = list(datum)
 39 |         else:
 40 |             output = torch.stack(datum)
 41 | 
 42 |         outputs.append(output)
 43 | 
 44 |     return tuple(outputs)
 45 | 
 46 | # dataset classes
 47 | 
 48 | class ImageDataset(Dataset):
 49 |     def __init__(
 50 |         self,
 51 |         folder,
 52 |         image_size,
 53 |         exts = ['jpg', 'jpeg', 'png', 'tiff'],
 54 |         augment_horizontal_flip = False,
 55 |         convert_image_to = None
 56 |     ):
 57 |         super().__init__()
 58 |         self.folder = folder
 59 |         self.image_size = image_size
 60 | 
 61 |         self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
 62 | 
 63 |         assert len(self.paths) > 0, 'your folder contains no images'
 64 |         assert len(self.paths) > 100, 'you need at least 100 images, 10k for research paper, millions for miraculous results (try Laion-5B)'
 65 | 
 66 |         maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
 67 | 
 68 |         self.transform = T.Compose([
 69 |             T.Lambda(maybe_convert_fn),
 70 |             T.Resize(image_size),
 71 |             T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
 72 |             T.CenterCrop(image_size),
 73 |             T.ToTensor()
 74 |         ])
 75 | 
 76 |     def get_dataloader(self, *args, **kwargs):
 77 |         return DataLoader(self, *args, shuffle = True, drop_last = True, **kwargs)
 78 | 
 79 |     def __len__(self):
 80 |         return len(self.paths)
 81 | 
 82 |     def __getitem__(self, index):
 83 |         path = self.paths[index]
 84 |         img = Image.open(path)
 85 |         return self.transform(img)
 86 | 
 87 | class TextImageDataset(Dataset):
 88 |     def __init__(self):
 89 |         raise NotImplementedError
 90 | 
 91 |     def get_dataloader(self, *args, **kwargs):
 92 |         return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)
 93 | 
 94 | class MockTextImageDataset(TextImageDataset):
 95 |     def __init__(
 96 |         self,
 97 |         image_size,
 98 |         length = int(1e5),
 99 |         channels = 3
100 |     ):
101 |         self.image_size = image_size
102 |         self.channels = channels
103 |         self.length = length
104 | 
105 |     def get_dataloader(self, *args, **kwargs):
106 |         return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)
107 | 
108 |     def __len__(self):
109 |         return self.length
110 | 
111 |     def __getitem__(self, index):
112 |         mock_image = torch.randn(self.channels, self.image_size, self.image_size)
113 |         return mock_image, 'mock text'
114 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/distributed.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | from torch.autograd import Function
 4 | import torch.distributed as dist
 5 | 
 6 | from einops import rearrange
 7 | 
 8 | # helpers
 9 | 
10 | def exists(val):
11 |     return val is not None
12 | 
13 | def pad_dim_to(t, length, dim = 0):
14 |     pad_length = length - t.shape[dim]
15 |     zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
16 |     return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
17 | 
18 | # distributed helpers
19 | 
20 | def all_gather_variable_dim(t, dim = 0, sizes = None):
21 |     device, world_size = t.device, dist.get_world_size()
22 | 
23 |     if not exists(sizes):
24 |         size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
25 |         sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
26 |         dist.all_gather(sizes, size)
27 |         sizes = torch.stack(sizes)
28 | 
29 |     max_size = sizes.amax().item()
30 |     padded_t = pad_dim_to(t, max_size, dim = dim)
31 | 
32 |     gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
33 |     dist.all_gather(gathered_tensors, padded_t)
34 | 
35 |     gathered_tensor = torch.cat(gathered_tensors, dim = dim)
36 |     seq = torch.arange(max_size, device = device)
37 | 
38 |     mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
39 |     mask = rearrange(mask, 'i j -> (i j)')
40 |     seq = torch.arange(mask.shape[-1], device = device)
41 |     indices = seq[mask]
42 | 
43 |     gathered_tensor = gathered_tensor.index_select(dim, indices)
44 | 
45 |     return gathered_tensor, sizes
46 | 
47 | class AllGather(Function):
48 |     @staticmethod
49 |     def forward(ctx, x, dim, sizes):
50 |         is_dist = dist.is_initialized() and dist.get_world_size() > 1
51 |         ctx.is_dist = is_dist
52 | 
53 |         if not is_dist:
54 |             return x, None
55 | 
56 |         x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
57 |         ctx.batch_sizes = batch_sizes.tolist()
58 |         ctx.dim = dim
59 |         return x, batch_sizes
60 | 
61 |     @staticmethod
62 |     def backward(ctx, grads, _):
63 |         if not ctx.is_dist:
64 |             return grads, None, None
65 | 
66 |         batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
67 |         grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
68 |         return grads_by_rank[rank], None, None
69 | 
70 | all_gather = AllGather.apply
71 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/gigagan_pytorch.py:
--------------------------------------------------------------------------------
   1 | from __future__ import annotations
   2 | 
   3 | from collections import namedtuple
   4 | from pathlib import Path
   5 | from math import log2, sqrt
   6 | from random import random
   7 | from functools import partial
   8 | 
   9 | from torchvision import utils
  10 | 
  11 | import torch
  12 | import torch.nn.functional as F
  13 | from torch import nn, einsum, Tensor
  14 | from torch.autograd import grad as torch_grad
  15 | from torch.utils.data import DataLoader
  16 | from torch.cuda.amp import GradScaler
  17 | 
  18 | from beartype import beartype
  19 | from beartype.typing import List, Tuple, Dict, Iterable
  20 | 
  21 | from einops import rearrange, pack, unpack, repeat, reduce
  22 | from einops.layers.torch import Rearrange, Reduce
  23 | 
  24 | from kornia.filters import filter2d
  25 | 
  26 | from ema_pytorch import EMA
  27 | 
  28 | from gigagan_pytorch.version import __version__
  29 | from gigagan_pytorch.open_clip import OpenClipAdapter
  30 | from gigagan_pytorch.optimizer import get_optimizer
  31 | from gigagan_pytorch.distributed import all_gather
  32 | 
  33 | from tqdm import tqdm
  34 | 
  35 | from numerize import numerize
  36 | 
  37 | from accelerate import Accelerator, DistributedType
  38 | from accelerate.utils import DistributedDataParallelKwargs
  39 | 
  40 | # helpers
  41 | 
  42 | def exists(val):
  43 |     return val is not None
  44 | 
  45 | @beartype
  46 | def is_empty(arr: Iterable):
  47 |     return len(arr) == 0
  48 | 
  49 | def default(*vals):
  50 |     for val in vals:
  51 |         if exists(val):
  52 |             return val
  53 |     return None
  54 | 
  55 | def cast_tuple(t, length = 1):
  56 |     return t if isinstance(t, tuple) else ((t,) * length)
  57 | 
  58 | def is_power_of_two(n):
  59 |     return log2(n).is_integer()
  60 | 
  61 | def safe_unshift(arr):
  62 |     if len(arr) == 0:
  63 |         return None
  64 |     return arr.pop(0)
  65 | 
  66 | def divisible_by(numer, denom):
  67 |     return (numer % denom) == 0
  68 | 
  69 | def group_by_num_consecutive(arr, num):
  70 |     out = []
  71 |     for ind, el in enumerate(arr):
  72 |         if ind > 0 and divisible_by(ind, num):
  73 |             yield out
  74 |             out = []
  75 | 
  76 |         out.append(el)
  77 | 
  78 |     if len(out) > 0:
  79 |         yield out
  80 | 
  81 | def is_unique(arr):
  82 |     return len(set(arr)) == len(arr)
  83 | 
  84 | def cycle(dl):
  85 |     while True:
  86 |         for data in dl:
  87 |             yield data
  88 | 
  89 | def num_to_groups(num, divisor):
  90 |     groups, remainder = divmod(num, divisor)
  91 |     arr = [divisor] * groups
  92 |     if remainder > 0:
  93 |         arr.append(remainder)
  94 |     return arr
  95 | 
  96 | def mkdir_if_not_exists(path):
  97 |     path.mkdir(exist_ok = True, parents = True)
  98 | 
  99 | @beartype
 100 | def set_requires_grad_(
 101 |     m: nn.Module,
 102 |     requires_grad: bool
 103 | ):
 104 |     for p in m.parameters():
 105 |         p.requires_grad = requires_grad
 106 | 
 107 | # activation functions
 108 | 
 109 | def leaky_relu(neg_slope = 0.2):
 110 |     return nn.LeakyReLU(neg_slope)
 111 | 
 112 | def conv2d_3x3(dim_in, dim_out):
 113 |     return nn.Conv2d(dim_in, dim_out, 3, padding = 1)
 114 | 
 115 | # tensor helpers
 116 | 
 117 | def log(t, eps = 1e-20):
 118 |     return t.clamp(min = eps).log()
 119 | 
 120 | def gradient_penalty(
 121 |     images,
 122 |     outputs,
 123 |     grad_output_weights = None,
 124 |     weight = 10,
 125 |     center = 0.,
 126 |     scaler: GradScaler | None = None,
 127 |     eps = 1e-4
 128 | ):
 129 |     if not isinstance(outputs, (list, tuple)):
 130 |         outputs = [outputs]
 131 | 
 132 |     if exists(scaler):
 133 |         outputs = [*map(scaler.scale, outputs)]
 134 | 
 135 |     if not exists(grad_output_weights):
 136 |         grad_output_weights = (1,) * len(outputs)
 137 | 
 138 |     maybe_scaled_gradients, *_ = torch_grad(
 139 |         outputs = outputs,
 140 |         inputs = images,
 141 |         grad_outputs = [(torch.ones_like(output) * weight) for output, weight in zip(outputs, grad_output_weights)],
 142 |         create_graph = True,
 143 |         retain_graph = True,
 144 |         only_inputs = True
 145 |     )
 146 | 
 147 |     gradients = maybe_scaled_gradients
 148 | 
 149 |     if exists(scaler):
 150 |         scale = scaler.get_scale()
 151 |         inv_scale = 1. / max(scale, eps)
 152 |         gradients = maybe_scaled_gradients * inv_scale
 153 | 
 154 |     gradients = rearrange(gradients, 'b ... -> b (...)')
 155 |     return weight * ((gradients.norm(2, dim = 1) - center) ** 2).mean()
 156 | 
 157 | # hinge gan losses
 158 | 
 159 | def generator_hinge_loss(fake):
 160 |     return fake.mean()
 161 | 
 162 | def discriminator_hinge_loss(real, fake):
 163 |     return (F.relu(1 + real) + F.relu(1 - fake)).mean()
 164 | 
 165 | # auxiliary losses
 166 | 
 167 | def aux_matching_loss(real, fake):
 168 |     """
 169 |     making logits negative, as in this framework, discriminator is 0 for real, high value for fake. GANs can have this arbitrarily swapped, as it only matters if the generator and discriminator are opposites
 170 |     """
 171 |     return (log(1 + (-real).exp()) + log(1 + (-fake).exp())).mean()
 172 | 
 173 | @beartype
 174 | def aux_clip_loss(
 175 |     clip: OpenClipAdapter,
 176 |     images: Tensor,
 177 |     texts: List[str] | None = None,
 178 |     text_embeds: Tensor | None = None
 179 | ):
 180 |     assert exists(texts) ^ exists(text_embeds)
 181 | 
 182 |     images, batch_sizes = all_gather(images, 0, None)
 183 | 
 184 |     if exists(texts):
 185 |         text_embeds, _ = clip.embed_texts(texts)
 186 |         text_embeds, _ = all_gather(text_embeds, 0, batch_sizes)
 187 | 
 188 |     return clip.contrastive_loss(images = images, text_embeds = text_embeds)
 189 | 
 190 | # differentiable augmentation - Karras et al. stylegan-ada
 191 | # start with horizontal flip
 192 | 
 193 | class DiffAugment(nn.Module):
 194 |     def __init__(
 195 |         self,
 196 |         *,
 197 |         prob,
 198 |         horizontal_flip,
 199 |         horizontal_flip_prob = 0.5
 200 |     ):
 201 |         super().__init__()
 202 |         self.prob = prob
 203 |         assert 0 <= prob <= 1.
 204 | 
 205 |         self.horizontal_flip = horizontal_flip
 206 |         self.horizontal_flip_prob = horizontal_flip_prob
 207 | 
 208 |     def forward(
 209 |         self,
 210 |         images,
 211 |         rgbs: List[Tensor]
 212 |     ):
 213 |         if random() >= self.prob:
 214 |             return images, rgbs
 215 | 
 216 |         if random() < self.horizontal_flip_prob:
 217 |             images = torch.flip(images, (-1,))
 218 |             rgbs = [torch.flip(rgb, (-1,)) for rgb in rgbs]
 219 | 
 220 |         return images, rgbs
 221 | 
 222 | # rmsnorm (newer papers show mean-centering in layernorm not necessary)
 223 | 
 224 | class ChannelRMSNorm(nn.Module):
 225 |     def __init__(self, dim):
 226 |         super().__init__()
 227 |         self.scale = dim ** 0.5
 228 |         self.gamma = nn.Parameter(torch.ones(dim, 1, 1))
 229 | 
 230 |     def forward(self, x):
 231 |         normed = F.normalize(x, dim = 1)
 232 |         return normed * self.scale * self.gamma
 233 | 
 234 | class RMSNorm(nn.Module):
 235 |     def __init__(self, dim):
 236 |         super().__init__()
 237 |         self.scale = dim ** 0.5
 238 |         self.gamma = nn.Parameter(torch.ones(dim))
 239 | 
 240 |     def forward(self, x):
 241 |         normed = F.normalize(x, dim = -1)
 242 |         return normed * self.scale * self.gamma
 243 | 
 244 | # down and upsample
 245 | 
 246 | class Blur(nn.Module):
 247 |     def __init__(self):
 248 |         super().__init__()
 249 |         f = torch.Tensor([1, 2, 1])
 250 |         self.register_buffer('f', f)
 251 | 
 252 |     def forward(self, x):
 253 |         f = self.f
 254 |         f = f[None, None, :] * f[None, :, None]
 255 |         return filter2d(x, f, normalized = True)
 256 | 
 257 | def Upsample(*args):
 258 |     return nn.Sequential(
 259 |         nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),
 260 |         Blur()
 261 |     )
 262 | 
 263 | class PixelShuffleUpsample(nn.Module):
 264 |     def __init__(self, dim, dim_out = None):
 265 |         super().__init__()
 266 |         dim_out = default(dim_out, dim)
 267 |         conv = nn.Conv2d(dim, dim_out * 4, 1)
 268 | 
 269 |         self.net = nn.Sequential(
 270 |             conv,
 271 |             nn.SiLU(),
 272 |             nn.PixelShuffle(2)
 273 |         )
 274 | 
 275 |         self.init_conv_(conv)
 276 | 
 277 |     def init_conv_(self, conv):
 278 |         o, i, h, w = conv.weight.shape
 279 |         conv_weight = torch.empty(o // 4, i, h, w)
 280 |         nn.init.kaiming_uniform_(conv_weight)
 281 |         conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
 282 | 
 283 |         conv.weight.data.copy_(conv_weight)
 284 |         nn.init.zeros_(conv.bias.data)
 285 | 
 286 |     def forward(self, x):
 287 |         return self.net(x)
 288 | 
 289 | def Downsample(dim):
 290 |     return nn.Sequential(
 291 |         Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
 292 |         nn.Conv2d(dim * 4, dim, 1)
 293 |     )
 294 | 
 295 | # skip layer excitation
 296 | 
 297 | def SqueezeExcite(dim, dim_out, reduction = 4, dim_min = 32):
 298 |     dim_hidden = max(dim_out // reduction, dim_min)
 299 | 
 300 |     return nn.Sequential(
 301 |         Reduce('b c h w -> b c', 'mean'),
 302 |         nn.Linear(dim, dim_hidden),
 303 |         nn.SiLU(),
 304 |         nn.Linear(dim_hidden, dim_out),
 305 |         nn.Sigmoid(),
 306 |         Rearrange('b c -> b c 1 1')
 307 |     )
 308 | 
 309 | # adaptive conv
 310 | # the main novelty of the paper - they propose to learn a softmax weighted sum of N convolutional kernels, depending on the text embedding
 311 | 
 312 | def get_same_padding(size, kernel, dilation, stride):
 313 |     return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
 314 | 
 315 | class AdaptiveConv2DMod(nn.Module):
 316 |     def __init__(
 317 |         self,
 318 |         dim,
 319 |         dim_out,
 320 |         kernel,
 321 |         *,
 322 |         demod = True,
 323 |         stride = 1,
 324 |         dilation = 1,
 325 |         eps = 1e-8,
 326 |         num_conv_kernels = 1 # set this to be greater than 1 for adaptive
 327 |     ):
 328 |         super().__init__()
 329 |         self.eps = eps
 330 | 
 331 |         self.dim_out = dim_out
 332 | 
 333 |         self.kernel = kernel
 334 |         self.stride = stride
 335 |         self.dilation = dilation
 336 |         self.adaptive = num_conv_kernels > 1
 337 | 
 338 |         self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)))
 339 | 
 340 |         self.demod = demod
 341 | 
 342 |         nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
 343 | 
 344 |     def forward(
 345 |         self,
 346 |         fmap,
 347 |         mod: Tensor,
 348 |         kernel_mod: Tensor | None = None
 349 |     ):
 350 |         """
 351 |         notation
 352 | 
 353 |         b - batch
 354 |         n - convs
 355 |         o - output
 356 |         i - input
 357 |         k - kernel
 358 |         """
 359 | 
 360 |         b, h = fmap.shape[0], fmap.shape[-2]
 361 | 
 362 |         # account for feature map that has been expanded by the scale in the first dimension
 363 |         # due to multiscale inputs and outputs
 364 | 
 365 |         if mod.shape[0] != b:
 366 |             mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])
 367 | 
 368 |         if exists(kernel_mod):
 369 |             kernel_mod_has_el = kernel_mod.numel() > 0
 370 | 
 371 |             assert self.adaptive or not kernel_mod_has_el
 372 | 
 373 |             if kernel_mod_has_el and kernel_mod.shape[0] != b:
 374 |                 kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
 375 | 
 376 |         # prepare weights for modulation
 377 | 
 378 |         weights = self.weights
 379 | 
 380 |         if self.adaptive:
 381 |             weights = repeat(weights, '... -> b ...', b = b)
 382 | 
 383 |             # determine an adaptive weight and 'select' the kernel to use with softmax
 384 | 
 385 |             assert exists(kernel_mod) and kernel_mod.numel() > 0
 386 | 
 387 |             kernel_attn = kernel_mod.softmax(dim = -1)
 388 |             kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')
 389 | 
 390 |             weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum')
 391 | 
 392 |         # do the modulation, demodulation, as done in stylegan2
 393 | 
 394 |         mod = rearrange(mod, 'b i -> b 1 i 1 1')
 395 | 
 396 |         weights = weights * (mod + 1)
 397 | 
 398 |         if self.demod:
 399 |             inv_norm = reduce(weights ** 2, 'b o i k1 k2 -> b o 1 1 1', 'sum').clamp(min = self.eps).rsqrt()
 400 |             weights = weights * inv_norm
 401 | 
 402 |         fmap = rearrange(fmap, 'b c h w -> 1 (b c) h w')
 403 | 
 404 |         weights = rearrange(weights, 'b o ... -> (b o) ...')
 405 | 
 406 |         padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
 407 |         fmap = F.conv2d(fmap, weights, padding = padding, groups = b)
 408 | 
 409 |         return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
 410 | 
 411 | class AdaptiveConv1DMod(nn.Module):
 412 |     """ 1d version of adaptive conv, for time dimension in videogigagan """
 413 | 
 414 |     def __init__(
 415 |         self,
 416 |         dim,
 417 |         dim_out,
 418 |         kernel,
 419 |         *,
 420 |         demod = True,
 421 |         stride = 1,
 422 |         dilation = 1,
 423 |         eps = 1e-8,
 424 |         num_conv_kernels = 1 # set this to be greater than 1 for adaptive
 425 |     ):
 426 |         super().__init__()
 427 |         self.eps = eps
 428 | 
 429 |         self.dim_out = dim_out
 430 | 
 431 |         self.kernel = kernel
 432 |         self.stride = stride
 433 |         self.dilation = dilation
 434 |         self.adaptive = num_conv_kernels > 1
 435 | 
 436 |         self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel)))
 437 | 
 438 |         self.demod = demod
 439 | 
 440 |         nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
 441 | 
 442 |     def forward(
 443 |         self,
 444 |         fmap,
 445 |         mod: Tensor,
 446 |         kernel_mod: Tensor | None = None
 447 |     ):
 448 |         """
 449 |         notation
 450 | 
 451 |         b - batch
 452 |         n - convs
 453 |         o - output
 454 |         i - input
 455 |         k - kernel
 456 |         """
 457 | 
 458 |         b, t = fmap.shape[0], fmap.shape[-1]
 459 | 
 460 |         # account for feature map that has been expanded by the scale in the first dimension
 461 |         # due to multiscale inputs and outputs
 462 | 
 463 |         if mod.shape[0] != b:
 464 |             mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])
 465 | 
 466 |         if exists(kernel_mod):
 467 |             kernel_mod_has_el = kernel_mod.numel() > 0
 468 | 
 469 |             assert self.adaptive or not kernel_mod_has_el
 470 | 
 471 |             if kernel_mod_has_el and kernel_mod.shape[0] != b:
 472 |                 kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
 473 | 
 474 |         # prepare weights for modulation
 475 | 
 476 |         weights = self.weights
 477 | 
 478 |         if self.adaptive:
 479 |             weights = repeat(weights, '... -> b ...', b = b)
 480 | 
 481 |             # determine an adaptive weight and 'select' the kernel to use with softmax
 482 | 
 483 |             assert exists(kernel_mod) and kernel_mod.numel() > 0
 484 | 
 485 |             kernel_attn = kernel_mod.softmax(dim = -1)
 486 |             kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1')
 487 | 
 488 |             weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum')
 489 | 
 490 |         # do the modulation, demodulation, as done in stylegan2
 491 | 
 492 |         mod = rearrange(mod, 'b i -> b 1 i 1')
 493 | 
 494 |         weights = weights * (mod + 1)
 495 | 
 496 |         if self.demod:
 497 |             inv_norm = reduce(weights ** 2, 'b o i k -> b o 1 1', 'sum').clamp(min = self.eps).rsqrt()
 498 |             weights = weights * inv_norm
 499 | 
 500 |         fmap = rearrange(fmap, 'b c t -> 1 (b c) t')
 501 | 
 502 |         weights = rearrange(weights, 'b o ... -> (b o) ...')
 503 | 
 504 |         padding = get_same_padding(t, self.kernel, self.dilation, self.stride)
 505 |         fmap = F.conv1d(fmap, weights, padding = padding, groups = b)
 506 | 
 507 |         return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
 508 | 
 509 | # attention
 510 | # they use an attention with a better Lipchitz constant - l2 distance similarity instead of dot product - also shared query / key space - shown in vitgan to be more stable
 511 | # not sure what they did about token attention to self, so masking out, as done in some other papers using shared query / key space
 512 | 
 513 | class SelfAttention(nn.Module):
 514 |     def __init__(
 515 |         self,
 516 |         dim,
 517 |         dim_head = 64,
 518 |         heads = 8,
 519 |         dot_product = False
 520 |     ):
 521 |         super().__init__()
 522 |         self.heads = heads
 523 |         self.scale = dim_head ** -0.5
 524 |         dim_inner = dim_head * heads
 525 | 
 526 |         self.dot_product = dot_product
 527 | 
 528 |         self.norm = ChannelRMSNorm(dim)
 529 | 
 530 |         self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
 531 |         self.to_k = nn.Conv2d(dim, dim_inner, 1, bias = False) if dot_product else None
 532 |         self.to_v = nn.Conv2d(dim, dim_inner, 1, bias = False)
 533 | 
 534 |         self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))
 535 | 
 536 |         self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)
 537 | 
 538 |     def forward(self, fmap):
 539 |         """
 540 |         einstein notation
 541 | 
 542 |         b - batch
 543 |         h - heads
 544 |         x - height
 545 |         y - width
 546 |         d - dimension
 547 |         i - source seq (attend from)
 548 |         j - target seq (attend to)
 549 |         """
 550 |         batch = fmap.shape[0]
 551 | 
 552 |         fmap = self.norm(fmap)
 553 | 
 554 |         x, y = fmap.shape[-2:]
 555 | 
 556 |         h = self.heads
 557 | 
 558 |         q, v = self.to_q(fmap), self.to_v(fmap)
 559 | 
 560 |         k = self.to_k(fmap) if exists(self.to_k) else q
 561 | 
 562 |         q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = self.heads), (q, k, v))
 563 | 
 564 |         # add a null key / value, so network can choose to pay attention to nothing
 565 | 
 566 |         nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)
 567 | 
 568 |         k = torch.cat((nk, k), dim = -2)
 569 |         v = torch.cat((nv, v), dim = -2)
 570 | 
 571 |         # l2 distance or dot product
 572 | 
 573 |         if self.dot_product:
 574 |             sim = einsum('b i d, b j d -> b i j', q, k)
 575 |         else:
 576 |             # using pytorch cdist leads to nans in lightweight gan training framework, at least
 577 |             q_squared = (q * q).sum(dim = -1)
 578 |             k_squared = (k * k).sum(dim = -1)
 579 |             l2dist_squared = rearrange(q_squared, 'b i -> b i 1') + rearrange(k_squared, 'b j -> b 1 j') - 2 * einsum('b i d, b j d -> b i j', q, k) # hope i'm mathing right
 580 |             sim = -l2dist_squared
 581 | 
 582 |         # scale
 583 | 
 584 |         sim = sim * self.scale
 585 | 
 586 |         # attention
 587 | 
 588 |         attn = sim.softmax(dim = -1)
 589 | 
 590 |         out = einsum('b i j, b j d -> b i d', attn, v)
 591 | 
 592 |         out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)
 593 | 
 594 |         return self.to_out(out)
 595 | 
 596 | class CrossAttention(nn.Module):
 597 |     def __init__(
 598 |         self,
 599 |         dim,
 600 |         dim_context,
 601 |         dim_head = 64,
 602 |         heads = 8
 603 |     ):
 604 |         super().__init__()
 605 |         self.heads = heads
 606 |         self.scale = dim_head ** -0.5
 607 |         dim_inner = dim_head * heads
 608 |         kv_input_dim = default(dim_context, dim)
 609 | 
 610 |         self.norm = ChannelRMSNorm(dim)
 611 |         self.norm_context = RMSNorm(kv_input_dim)
 612 | 
 613 |         self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
 614 |         self.to_kv = nn.Linear(kv_input_dim, dim_inner * 2, bias = False)
 615 |         self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)
 616 | 
 617 |     def forward(self, fmap, context, mask = None):
 618 |         """
 619 |         einstein notation
 620 | 
 621 |         b - batch
 622 |         h - heads
 623 |         x - height
 624 |         y - width
 625 |         d - dimension
 626 |         i - source seq (attend from)
 627 |         j - target seq (attend to)
 628 |         """
 629 | 
 630 |         fmap = self.norm(fmap)
 631 |         context = self.norm_context(context)
 632 | 
 633 |         x, y = fmap.shape[-2:]
 634 | 
 635 |         h = self.heads
 636 | 
 637 |         q, k, v = (self.to_q(fmap), *self.to_kv(context).chunk(2, dim = -1))
 638 | 
 639 |         k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (k, v))
 640 | 
 641 |         q = rearrange(q, 'b (h d) x y -> (b h) (x y) d', h = self.heads)
 642 | 
 643 |         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
 644 | 
 645 |         if exists(mask):
 646 |             mask = repeat(mask, 'b j -> (b h) 1 j', h = self.heads)
 647 |             sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
 648 | 
 649 |         attn = sim.softmax(dim = -1)
 650 | 
 651 |         out = einsum('b i j, b j d -> b i d', attn, v)
 652 | 
 653 |         out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)
 654 | 
 655 |         return self.to_out(out)
 656 | 
 657 | # classic transformer attention, stick with l2 distance
 658 | 
 659 | class TextAttention(nn.Module):
 660 |     def __init__(
 661 |         self,
 662 |         dim,
 663 |         dim_head = 64,
 664 |         heads = 8
 665 |     ):
 666 |         super().__init__()
 667 |         self.heads = heads
 668 |         self.scale = dim_head ** -0.5
 669 |         dim_inner = dim_head * heads
 670 | 
 671 |         self.norm = RMSNorm(dim)
 672 |         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
 673 | 
 674 |         self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))
 675 | 
 676 |         self.to_out = nn.Linear(dim_inner, dim, bias = False)
 677 | 
 678 |     def forward(self, encodings, mask = None):
 679 |         """
 680 |         einstein notation
 681 | 
 682 |         b - batch
 683 |         h - heads
 684 |         x - height
 685 |         y - width
 686 |         d - dimension
 687 |         i - source seq (attend from)
 688 |         j - target seq (attend to)
 689 |         """
 690 |         batch = encodings.shape[0]
 691 | 
 692 |         encodings = self.norm(encodings)
 693 | 
 694 |         h = self.heads
 695 | 
 696 |         q, k, v = self.to_qkv(encodings).chunk(3, dim = -1)
 697 |         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))
 698 | 
 699 |         # add a null key / value, so network can choose to pay attention to nothing
 700 | 
 701 |         nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)
 702 | 
 703 |         k = torch.cat((nk, k), dim = -2)
 704 |         v = torch.cat((nv, v), dim = -2)
 705 | 
 706 |         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
 707 | 
 708 |         # key padding mask
 709 | 
 710 |         if exists(mask):
 711 |             mask = F.pad(mask, (1, 0), value = True)
 712 |             mask = repeat(mask, 'b n -> (b h) 1 n', h = h)
 713 |             sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
 714 | 
 715 |         # attention
 716 | 
 717 |         attn = sim.softmax(dim = -1)
 718 |         out = einsum('b i j, b j d -> b i d', attn, v)
 719 | 
 720 |         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
 721 | 
 722 |         return self.to_out(out)
 723 | 
 724 | # feedforward
 725 | 
 726 | def FeedForward(
 727 |     dim,
 728 |     mult = 4,
 729 |     channel_first = False
 730 | ):
 731 |     dim_hidden = int(dim * mult)
 732 |     norm_klass = ChannelRMSNorm if channel_first else RMSNorm
 733 |     proj = partial(nn.Conv2d, kernel_size = 1) if channel_first else nn.Linear
 734 | 
 735 |     return nn.Sequential(
 736 |         norm_klass(dim),
 737 |         proj(dim, dim_hidden),
 738 |         nn.GELU(),
 739 |         proj(dim_hidden, dim)
 740 |     )
 741 | 
 742 | # different types of transformer blocks or transformers (multiple blocks)
 743 | 
 744 | class SelfAttentionBlock(nn.Module):
 745 |     def __init__(
 746 |         self,
 747 |         dim,
 748 |         dim_head = 64,
 749 |         heads = 8,
 750 |         ff_mult = 4,
 751 |         dot_product = False
 752 |     ):
 753 |         super().__init__()
 754 |         self.attn = SelfAttention(dim = dim, dim_head = dim_head, heads = heads, dot_product = dot_product)
 755 |         self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)
 756 | 
 757 |     def forward(self, x):
 758 |         x = self.attn(x) + x
 759 |         x = self.ff(x) + x
 760 |         return x
 761 | 
 762 | class CrossAttentionBlock(nn.Module):
 763 |     def __init__(
 764 |         self,
 765 |         dim,
 766 |         dim_context,
 767 |         dim_head = 64,
 768 |         heads = 8,
 769 |         ff_mult = 4
 770 |     ):
 771 |         super().__init__()
 772 |         self.attn = CrossAttention(dim = dim, dim_context = dim_context, dim_head = dim_head, heads = heads)
 773 |         self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)
 774 | 
 775 |     def forward(self, x, context, mask = None):
 776 |         x = self.attn(x, context = context, mask = mask) + x
 777 |         x = self.ff(x) + x
 778 |         return x
 779 | 
 780 | class Transformer(nn.Module):
 781 |     def __init__(
 782 |         self,
 783 |         dim,
 784 |         depth,
 785 |         dim_head = 64,
 786 |         heads = 8,
 787 |         ff_mult = 4
 788 |     ):
 789 |         super().__init__()
 790 |         self.layers = nn.ModuleList([])
 791 |         for _ in range(depth):
 792 |             self.layers.append(nn.ModuleList([
 793 |                 TextAttention(dim = dim, dim_head = dim_head, heads = heads),
 794 |                 FeedForward(dim = dim, mult = ff_mult)
 795 |             ]))
 796 | 
 797 |         self.norm = RMSNorm(dim)
 798 | 
 799 |     def forward(self, x, mask = None):
 800 |         for attn, ff in self.layers:
 801 |             x = attn(x, mask = mask) + x
 802 |             x = ff(x) + x
 803 | 
 804 |         return self.norm(x)
 805 | 
 806 | # text encoder
 807 | 
 808 | class TextEncoder(nn.Module):
 809 |     @beartype
 810 |     def __init__(
 811 |         self,
 812 |         *,
 813 |         dim,
 814 |         depth,
 815 |         clip: OpenClipAdapter | None = None,
 816 |         dim_head = 64,
 817 |         heads = 8,
 818 |     ):
 819 |         super().__init__()
 820 |         self.dim = dim
 821 | 
 822 |         if not exists(clip):
 823 |             clip = OpenClipAdapter()
 824 | 
 825 |         self.clip = clip
 826 |         set_requires_grad_(clip, False)
 827 | 
 828 |         self.learned_global_token = nn.Parameter(torch.randn(dim))
 829 | 
 830 |         self.project_in = nn.Linear(clip.dim_latent, dim) if clip.dim_latent != dim else nn.Identity()
 831 | 
 832 |         self.transformer = Transformer(
 833 |             dim = dim,
 834 |             depth = depth,
 835 |             dim_head = dim_head,
 836 |             heads = heads
 837 |         )
 838 | 
 839 |     @beartype
 840 |     def forward(
 841 |         self,
 842 |         texts: List[str] | None = None,
 843 |         text_encodings: Tensor | None = None
 844 |     ):
 845 |         assert exists(texts) ^ exists(text_encodings)
 846 | 
 847 |         if not exists(text_encodings):
 848 |             with torch.no_grad():
 849 |                 self.clip.eval()
 850 |                 _, text_encodings = self.clip.embed_texts(texts)
 851 | 
 852 |         mask = (text_encodings != 0.).any(dim = -1)
 853 | 
 854 |         text_encodings = self.project_in(text_encodings)
 855 | 
 856 |         mask_with_global = F.pad(mask, (1, 0), value = True)
 857 | 
 858 |         batch = text_encodings.shape[0]
 859 |         global_tokens = repeat(self.learned_global_token, 'd -> b d', b = batch)
 860 | 
 861 |         text_encodings, ps = pack([global_tokens, text_encodings], 'b * d')
 862 | 
 863 |         text_encodings = self.transformer(text_encodings, mask = mask_with_global)
 864 | 
 865 |         global_tokens, text_encodings = unpack(text_encodings, ps, 'b * d')
 866 | 
 867 |         return global_tokens, text_encodings, mask
 868 | 
 869 | # style mapping network
 870 | 
 871 | class EqualLinear(nn.Module):
 872 |     def __init__(
 873 |         self,
 874 |         dim,
 875 |         dim_out,
 876 |         lr_mul = 1,
 877 |         bias = True
 878 |     ):
 879 |         super().__init__()
 880 |         self.weight = nn.Parameter(torch.randn(dim_out, dim))
 881 |         if bias:
 882 |             self.bias = nn.Parameter(torch.zeros(dim_out))
 883 | 
 884 |         self.lr_mul = lr_mul
 885 | 
 886 |     def forward(self, input):
 887 |         return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
 888 | 
 889 | class StyleNetwork(nn.Module):
 890 |     def __init__(
 891 |         self,
 892 |         dim,
 893 |         depth,
 894 |         lr_mul = 0.1,
 895 |         dim_text_latent = 0
 896 |     ):
 897 |         super().__init__()
 898 |         self.dim = dim
 899 |         self.dim_text_latent = dim_text_latent
 900 | 
 901 |         layers = []
 902 |         for i in range(depth):
 903 |             is_first = i == 0
 904 |             dim_in = (dim + dim_text_latent) if is_first else dim
 905 | 
 906 |             layers.extend([EqualLinear(dim_in, dim, lr_mul), leaky_relu()])
 907 | 
 908 |         self.net = nn.Sequential(*layers)
 909 | 
 910 |     def forward(
 911 |         self,
 912 |         x,
 913 |         text_latent = None
 914 |     ):
 915 |         x = F.normalize(x, dim = 1)
 916 | 
 917 |         if self.dim_text_latent > 0:
 918 |             assert exists(text_latent)
 919 |             x = torch.cat((x, text_latent), dim = -1)
 920 | 
 921 |         return self.net(x)
 922 | 
 923 | # noise
 924 | 
 925 | class Noise(nn.Module):
 926 |     def __init__(self, dim):
 927 |         super().__init__()
 928 |         self.weight = nn.Parameter(torch.zeros(dim, 1, 1))
 929 | 
 930 |     def forward(
 931 |         self,
 932 |         x,
 933 |         noise = None
 934 |     ):
 935 |         b, _, h, w, device = *x.shape, x.device
 936 | 
 937 |         if not exists(noise):
 938 |             noise = torch.randn(b, 1, h, w, device = device)
 939 | 
 940 |         return x + self.weight * noise
 941 | 
 942 | # generator
 943 | 
 944 | class BaseGenerator(nn.Module):
 945 |     pass
 946 | 
 947 | class Generator(BaseGenerator):
 948 |     @beartype
 949 |     def __init__(
 950 |         self,
 951 |         *,
 952 |         image_size,
 953 |         dim_capacity = 16,
 954 |         dim_max = 2048,
 955 |         channels = 3,
 956 |         style_network: StyleNetwork | Dict | None = None,
 957 |         style_network_dim = None,
 958 |         text_encoder: TextEncoder | Dict | None = None,
 959 |         dim_latent = 512,
 960 |         self_attn_resolutions: Tuple[int, ...] = (32, 16),
 961 |         self_attn_dim_head = 64,
 962 |         self_attn_heads = 8,
 963 |         self_attn_dot_product = True,
 964 |         self_attn_ff_mult = 4,
 965 |         cross_attn_resolutions: Tuple[int, ...] = (32, 16),
 966 |         cross_attn_dim_head = 64,
 967 |         cross_attn_heads = 8,
 968 |         cross_attn_ff_mult = 4,
 969 |         num_conv_kernels = 2,  # the number of adaptive conv kernels
 970 |         num_skip_layers_excite = 0,
 971 |         unconditional = False,
 972 |         pixel_shuffle_upsample = False
 973 |     ):
 974 |         super().__init__()
 975 |         self.channels = channels
 976 | 
 977 |         if isinstance(style_network, dict):
 978 |             style_network = StyleNetwork(**style_network)
 979 | 
 980 |         self.style_network = style_network
 981 | 
 982 |         assert exists(style_network) ^ exists(style_network_dim), 'style_network_dim must be given to the generator if StyleNetwork not passed in as style_network'
 983 | 
 984 |         if not exists(style_network_dim):
 985 |             style_network_dim = style_network.dim
 986 | 
 987 |         self.style_network_dim = style_network_dim
 988 | 
 989 |         if isinstance(text_encoder, dict):
 990 |             text_encoder = TextEncoder(**text_encoder)
 991 | 
 992 |         self.text_encoder = text_encoder
 993 | 
 994 |         self.unconditional = unconditional
 995 | 
 996 |         assert not (unconditional and exists(text_encoder))
 997 |         assert not (unconditional and exists(style_network) and style_network.dim_text_latent > 0)
 998 |         assert unconditional or (exists(text_encoder) and text_encoder.dim == style_network.dim_text_latent), 'the `dim_text_latent` on your StyleNetwork must be equal to the `dim` set for the TextEncoder'
 999 | 
1000 |         assert is_power_of_two(image_size)
1001 |         num_layers = int(log2(image_size) - 1)
1002 |         self.num_layers = num_layers
1003 | 
1004 |         # generator requires convolutions conditioned by the style vector
1005 |         # and also has N convolutional kernels adaptively selected (one of the only novelties of the paper)
1006 | 
1007 |         is_adaptive = num_conv_kernels > 1
1008 |         dim_kernel_mod = num_conv_kernels if is_adaptive else 0
1009 | 
1010 |         style_embed_split_dims = []
1011 | 
1012 |         adaptive_conv = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels)
1013 | 
1014 |         # initial 4x4 block and conv
1015 | 
1016 |         self.init_block = nn.Parameter(torch.randn(dim_latent, 4, 4))
1017 |         self.init_conv = adaptive_conv(dim_latent, dim_latent)
1018 | 
1019 |         style_embed_split_dims.extend([
1020 |             dim_latent,
1021 |             dim_kernel_mod
1022 |         ])
1023 | 
1024 |         # main network
1025 | 
1026 |         num_layers = int(log2(image_size) - 1)
1027 |         self.num_layers = num_layers
1028 | 
1029 |         resolutions = image_size / ((2 ** torch.arange(num_layers).flip(0)))
1030 |         resolutions = resolutions.long().tolist()
1031 | 
1032 |         dim_layers = (2 ** (torch.arange(num_layers) + 1)) * dim_capacity
1033 |         dim_layers.clamp_(max = dim_max)
1034 | 
1035 |         dim_layers = torch.flip(dim_layers, (0,))
1036 |         dim_layers = F.pad(dim_layers, (1, 0), value = dim_latent)
1037 | 
1038 |         dim_layers = dim_layers.tolist()
1039 | 
1040 |         dim_pairs = list(zip(dim_layers[:-1], dim_layers[1:]))
1041 | 
1042 |         self.num_skip_layers_excite = num_skip_layers_excite
1043 | 
1044 |         self.layers = nn.ModuleList([])
1045 | 
1046 |         # go through layers and construct all parameters
1047 | 
1048 |         for ind, ((dim_in, dim_out), resolution) in enumerate(zip(dim_pairs, resolutions)):
1049 |             is_last = (ind + 1) == len(dim_pairs)
1050 |             is_first = ind == 0
1051 | 
1052 |             should_upsample = not is_first
1053 |             should_upsample_rgb = not is_last
1054 |             should_skip_layer_excite = num_skip_layers_excite > 0 and (ind + num_skip_layers_excite) < len(dim_pairs)
1055 | 
1056 |             has_self_attn = resolution in self_attn_resolutions
1057 |             has_cross_attn = resolution in cross_attn_resolutions and not unconditional
1058 | 
1059 |             skip_squeeze_excite = None
1060 |             if should_skip_layer_excite:
1061 |                 dim_skip_in, _ = dim_pairs[ind + num_skip_layers_excite]
1062 |                 skip_squeeze_excite = SqueezeExcite(dim_in, dim_skip_in)
1063 | 
1064 |             resnet_block = nn.ModuleList([
1065 |                 adaptive_conv(dim_in, dim_out),
1066 |                 Noise(dim_out),
1067 |                 leaky_relu(),
1068 |                 adaptive_conv(dim_out, dim_out),
1069 |                 Noise(dim_out),
1070 |                 leaky_relu()
1071 |             ])
1072 | 
1073 |             to_rgb = AdaptiveConv2DMod(dim_out, channels, 1, num_conv_kernels = 1, demod = False)
1074 | 
1075 |             self_attn = cross_attn = rgb_upsample = upsample = None
1076 | 
1077 |             upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
1078 | 
1079 |             upsample = upsample_klass(dim_in) if should_upsample else None
1080 |             rgb_upsample = upsample_klass(channels) if should_upsample_rgb else None
1081 | 
1082 |             if has_self_attn:
1083 |                 self_attn = SelfAttentionBlock(
1084 |                     dim_out,
1085 |                     dim_head = self_attn_dim_head,
1086 |                     heads = self_attn_heads,
1087 |                     ff_mult = self_attn_ff_mult,
1088 |                     dot_product = self_attn_dot_product
1089 |             )
1090 | 
1091 |             if has_cross_attn:
1092 |                 cross_attn = CrossAttentionBlock(
1093 |                     dim_out,
1094 |                     dim_context = text_encoder.dim,
1095 |                     dim_head = cross_attn_dim_head,
1096 |                     heads = cross_attn_heads,
1097 |                     ff_mult = cross_attn_ff_mult,
1098 |                 )
1099 | 
1100 |             style_embed_split_dims.extend([
1101 |                 dim_in,             # for first conv in resnet block
1102 |                 dim_kernel_mod,     # first conv kernel selection
1103 |                 dim_out,            # second conv in resnet block
1104 |                 dim_kernel_mod,     # second conv kernel selection
1105 |                 dim_out,            # to RGB conv
1106 |                 0,                  # RGB conv kernel selection
1107 |             ])
1108 | 
1109 |             self.layers.append(nn.ModuleList([
1110 |                 skip_squeeze_excite,
1111 |                 resnet_block,
1112 |                 to_rgb,
1113 |                 self_attn,
1114 |                 cross_attn,
1115 |                 upsample,
1116 |                 rgb_upsample
1117 |             ]))
1118 | 
1119 |         # determine the projection of the style embedding to convolutional modulation weights (+ adaptive kernel selection weights) for all layers
1120 | 
1121 |         self.style_to_conv_modulations = nn.Linear(style_network_dim, sum(style_embed_split_dims))
1122 |         self.style_embed_split_dims = style_embed_split_dims
1123 | 
1124 |         self.apply(self.init_)
1125 |         nn.init.normal_(self.init_block, std = 0.02)
1126 | 
1127 |     def init_(self, m):
1128 |         if type(m) in {nn.Conv2d, nn.Linear}:
1129 |             nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
1130 | 
1131 |     @property
1132 |     def total_params(self):
1133 |         return sum([p.numel() for p in self.parameters() if p.requires_grad])
1134 | 
1135 |     @property
1136 |     def device(self):
1137 |         return next(self.parameters()).device
1138 | 
1139 |     @beartype
1140 |     def forward(
1141 |         self,
1142 |         styles = None,
1143 |         noise = None,
1144 |         texts: List[str] | None = None,
1145 |         text_encodings: Tensor | None = None,
1146 |         global_text_tokens = None,
1147 |         fine_text_tokens = None,
1148 |         text_mask = None,
1149 |         batch_size = 1,
1150 |         return_all_rgbs = False
1151 |     ):
1152 |         # take care of text encodings
1153 |         # which requires global text tokens to adaptively select the kernels from the main contribution in the paper
1154 |         # and fine text tokens to attend to using cross attention
1155 | 
1156 |         if not self.unconditional:
1157 |             if exists(texts) or exists(text_encodings):
1158 |                 assert exists(texts) ^ exists(text_encodings), 'either raw texts as List[str] or text_encodings (from clip) as Tensor is passed in, but not both'
1159 |                 assert exists(self.text_encoder)
1160 | 
1161 |                 if exists(texts):
1162 |                     text_encoder_kwargs = dict(texts = texts)
1163 |                 elif exists(text_encodings):
1164 |                     text_encoder_kwargs = dict(text_encodings = text_encodings)
1165 | 
1166 |                 global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(**text_encoder_kwargs)
1167 |             else:
1168 |                 assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))]), 'raw text or text embeddings were not passed in for conditional training'
1169 |         else:
1170 |             assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])
1171 | 
1172 |         # determine styles
1173 | 
1174 |         if not exists(styles):
1175 |             assert exists(self.style_network)
1176 | 
1177 |             if not exists(noise):
1178 |                 noise = torch.randn((batch_size, self.style_network_dim), device = self.device)
1179 | 
1180 |             styles = self.style_network(noise, global_text_tokens)
1181 | 
1182 |         # project styles to conv modulations
1183 | 
1184 |         conv_mods = self.style_to_conv_modulations(styles)
1185 |         conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
1186 |         conv_mods = iter(conv_mods)
1187 | 
1188 |         # prepare initial block
1189 | 
1190 |         batch_size = styles.shape[0]
1191 | 
1192 |         x = repeat(self.init_block, 'c h w -> b c h w', b = batch_size)
1193 |         x = self.init_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1194 | 
1195 |         rgb = torch.zeros((batch_size, self.channels, 4, 4), device = self.device, dtype = x.dtype)
1196 | 
1197 |         # skip layer squeeze excitations
1198 | 
1199 |         excitations = [None] * self.num_skip_layers_excite
1200 | 
1201 |         # all the rgb's of each layer of the generator is to be saved for multi-resolution input discrimination
1202 | 
1203 |         rgbs = []
1204 | 
1205 |         # main network
1206 | 
1207 |         for squeeze_excite, (resnet_conv1, noise1, act1, resnet_conv2, noise2, act2), to_rgb_conv, self_attn, cross_attn, upsample, upsample_rgb in self.layers:
1208 | 
1209 |             if exists(upsample):
1210 |                 x = upsample(x)
1211 | 
1212 |             if exists(squeeze_excite):
1213 |                 skip_excite = squeeze_excite(x)
1214 |                 excitations.append(skip_excite)
1215 | 
1216 |             excite = safe_unshift(excitations)
1217 |             if exists(excite):
1218 |                 x = x * excite
1219 | 
1220 |             x = resnet_conv1(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1221 |             x = noise1(x)
1222 |             x = act1(x)
1223 | 
1224 |             x = resnet_conv2(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1225 |             x = noise2(x)
1226 |             x = act2(x)
1227 | 
1228 |             if exists(self_attn):
1229 |                 x = self_attn(x)
1230 | 
1231 |             if exists(cross_attn):
1232 |                 x = cross_attn(x, context = fine_text_tokens, mask = text_mask)
1233 | 
1234 |             layer_rgb = to_rgb_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1235 | 
1236 |             rgb = rgb + layer_rgb
1237 | 
1238 |             rgbs.append(rgb)
1239 | 
1240 |             if exists(upsample_rgb):
1241 |                 rgb = upsample_rgb(rgb)
1242 | 
1243 |         # sanity check
1244 | 
1245 |         assert is_empty([*conv_mods]), 'convolutions were incorrectly modulated'
1246 | 
1247 |         if return_all_rgbs:
1248 |             return rgb, rgbs
1249 | 
1250 |         return rgb
1251 | 
1252 | # discriminator
1253 | 
1254 | @beartype
1255 | class SimpleDecoder(nn.Module):
1256 |     def __init__(
1257 |         self,
1258 |         dim,
1259 |         *,
1260 |         dims: Tuple[int, ...],
1261 |         patch_dim: int = 1,
1262 |         frac_patches: float = 1.,
1263 |         dropout: float = 0.5
1264 |     ):
1265 |         super().__init__()
1266 |         assert 0 < frac_patches <= 1.
1267 | 
1268 |         self.patch_dim = patch_dim
1269 |         self.frac_patches = frac_patches
1270 | 
1271 |         self.dropout = nn.Dropout(dropout)
1272 | 
1273 |         dims = [dim, *dims]
1274 | 
1275 |         layers = [conv2d_3x3(dim, dim)]
1276 | 
1277 |         for dim_in, dim_out in zip(dims[:-1], dims[1:]):
1278 |             layers.append(nn.Sequential(
1279 |                 Upsample(dim_in),
1280 |                 conv2d_3x3(dim_in, dim_out),
1281 |                 leaky_relu()
1282 |             ))
1283 | 
1284 |         self.net = nn.Sequential(*layers)
1285 | 
1286 |     @property
1287 |     def device(self):
1288 |         return next(self.parameters()).device
1289 | 
1290 |     def forward(
1291 |         self,
1292 |         fmap,
1293 |         orig_image
1294 |     ):
1295 |         fmap = self.dropout(fmap)
1296 | 
1297 |         if self.frac_patches < 1.:
1298 |             batch, patch_dim = fmap.shape[0], self.patch_dim
1299 |             fmap_size, img_size = fmap.shape[-1], orig_image.shape[-1]
1300 | 
1301 |             assert divisible_by(fmap_size, patch_dim), f'feature map dimensions are {fmap_size}, but the patch dim was designated to be {patch_dim}'
1302 |             assert divisible_by(img_size, patch_dim), f'image size is {img_size} but the patch dim was specified to be {patch_dim}'
1303 | 
1304 |             fmap, orig_image = map(lambda t: rearrange(t, 'b c (p1 h) (p2 w) -> b (p1 p2) c h w', p1 = patch_dim, p2 = patch_dim), (fmap, orig_image))
1305 | 
1306 |             total_patches = patch_dim ** 2
1307 |             num_patches_recon = max(int(self.frac_patches * total_patches), 1)
1308 | 
1309 |             batch_arange = torch.arange(batch, device = self.device)[..., None]
1310 |             batch_randperm = torch.randn((batch, total_patches)).sort(dim = -1).indices
1311 |             patch_indices = batch_randperm[..., :num_patches_recon]
1312 | 
1313 |             fmap, orig_image = map(lambda t: t[batch_arange, patch_indices], (fmap, orig_image))
1314 |             fmap, orig_image = map(lambda t: rearrange(t, 'b p ... -> (b p) ...'), (fmap, orig_image))
1315 | 
1316 |         recon = self.net(fmap)
1317 |         return F.mse_loss(recon, orig_image)
1318 | 
1319 | class RandomFixedProjection(nn.Module):
1320 |     def __init__(
1321 |         self,
1322 |         dim,
1323 |         dim_out,
1324 |         channel_first = True
1325 |     ):
1326 |         super().__init__()
1327 |         weights = torch.randn(dim, dim_out)
1328 |         nn.init.kaiming_normal_(weights, mode = 'fan_out', nonlinearity = 'linear')
1329 | 
1330 |         self.channel_first = channel_first
1331 |         self.register_buffer('fixed_weights', weights)
1332 | 
1333 |     def forward(self, x):
1334 |         if not self.channel_first:
1335 |             return x @ self.fixed_weights
1336 | 
1337 |         return einsum('b c ..., c d -> b d ...', x, self.fixed_weights)
1338 | 
1339 | class VisionAidedDiscriminator(nn.Module):
1340 |     """ the vision-aided gan loss """
1341 | 
1342 |     @beartype
1343 |     def __init__(
1344 |         self,
1345 |         *,
1346 |         depth = 2,
1347 |         dim_head = 64,
1348 |         heads = 8,
1349 |         clip: OpenClipAdapter | None = None,
1350 |         layer_indices = (-1, -2, -3),
1351 |         conv_dim = None,
1352 |         text_dim = None,
1353 |         unconditional = False,
1354 |         num_conv_kernels = 2
1355 |     ):
1356 |         super().__init__()
1357 | 
1358 |         if not exists(clip):
1359 |             clip = OpenClipAdapter()
1360 | 
1361 |         self.clip = clip
1362 |         dim = clip._dim_image_latent
1363 | 
1364 |         self.unconditional = unconditional
1365 |         text_dim = default(text_dim, dim)
1366 |         conv_dim = default(conv_dim, dim)
1367 | 
1368 |         self.layer_discriminators = nn.ModuleList([])
1369 |         self.layer_indices = layer_indices
1370 | 
1371 |         conv_klass = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels) if not unconditional else conv2d_3x3
1372 | 
1373 |         for _ in layer_indices:
1374 |             self.layer_discriminators.append(nn.ModuleList([
1375 |                 RandomFixedProjection(dim, conv_dim),
1376 |                 conv_klass(conv_dim, conv_dim),
1377 |                 nn.Linear(text_dim, conv_dim) if not unconditional else None,
1378 |                 nn.Linear(text_dim, num_conv_kernels) if not unconditional else None,
1379 |                 nn.Sequential(
1380 |                     conv2d_3x3(conv_dim, 1),
1381 |                     Rearrange('b 1 ... -> b ...')
1382 |                 )
1383 |             ]))
1384 | 
1385 |     def parameters(self):
1386 |         return self.layer_discriminators.parameters()
1387 | 
1388 |     @property
1389 |     def total_params(self):
1390 |         return sum([p.numel() for p in self.parameters()])
1391 | 
1392 |     @beartype
1393 |     def forward(
1394 |         self,
1395 |         images,
1396 |         texts: List[str] | None = None,
1397 |         text_embeds: Tensor | None = None,
1398 |         return_clip_encodings = False
1399 |     ):
1400 | 
1401 |         assert self.unconditional or (exists(text_embeds) ^ exists(texts))
1402 | 
1403 |         with torch.no_grad():
1404 |             if not self.unconditional and exists(texts):
1405 |                 self.clip.eval()
1406 |                 text_embeds = self.clip.embed_texts
1407 | 
1408 |         _, image_encodings = self.clip.embed_images(images)
1409 | 
1410 |         logits = []
1411 | 
1412 |         for layer_index, (rand_proj, conv, to_conv_mod, to_conv_kernel_mod, to_logits) in zip(self.layer_indices, self.layer_discriminators):
1413 |             image_encoding = image_encodings[layer_index]
1414 | 
1415 |             cls_token, rest_tokens = image_encoding[:, :1], image_encoding[:, 1:]
1416 |             height_width = int(sqrt(rest_tokens.shape[-2])) # assume square
1417 | 
1418 |             img_fmap = rearrange(rest_tokens, 'b (h w) d -> b d h w', h = height_width)
1419 | 
1420 |             img_fmap = img_fmap + rearrange(cls_token, 'b 1 d -> b d 1 1 ') # pool the cls token into the rest of the tokens
1421 | 
1422 |             img_fmap = rand_proj(img_fmap)
1423 | 
1424 |             if self.unconditional:
1425 |                 img_fmap = conv(img_fmap)
1426 |             else:
1427 |                 assert exists(text_embeds)
1428 | 
1429 |                 img_fmap = conv(
1430 |                     img_fmap,
1431 |                     mod = to_conv_mod(text_embeds),
1432 |                     kernel_mod = to_conv_kernel_mod(text_embeds)
1433 |                 )
1434 | 
1435 |             layer_logits = to_logits(img_fmap)
1436 | 
1437 |             logits.append(layer_logits)
1438 | 
1439 |         if not return_clip_encodings:
1440 |             return logits
1441 | 
1442 |         return logits, image_encodings
1443 | 
1444 | class Predictor(nn.Module):
1445 |     def __init__(
1446 |         self,
1447 |         dim,
1448 |         depth = 4,
1449 |         num_conv_kernels = 2,
1450 |         unconditional = False
1451 |     ):
1452 |         super().__init__()
1453 |         self.unconditional = unconditional
1454 |         self.residual_fn = nn.Conv2d(dim, dim, 1)
1455 |         self.residual_scale = 2 ** -0.5
1456 | 
1457 |         self.layers = nn.ModuleList([])
1458 | 
1459 |         klass = nn.Conv2d if unconditional else partial(AdaptiveConv2DMod, num_conv_kernels = num_conv_kernels)
1460 |         klass_kwargs = dict(padding = 1) if unconditional else dict()
1461 | 
1462 |         for ind in range(depth):
1463 |             self.layers.append(nn.ModuleList([
1464 |                 klass(dim, dim, 3, **klass_kwargs),
1465 |                 leaky_relu(),
1466 |                 klass(dim, dim, 3, **klass_kwargs),
1467 |                 leaky_relu()
1468 |             ]))
1469 | 
1470 |         self.to_logits = nn.Conv2d(dim, 1, 1)
1471 | 
1472 |     def forward(
1473 |         self,
1474 |         x,
1475 |         mod = None,
1476 |         kernel_mod = None
1477 |     ):
1478 |         residual = self.residual_fn(x)
1479 | 
1480 |         kwargs = dict()
1481 | 
1482 |         if not self.unconditional:
1483 |             kwargs = dict(mod = mod, kernel_mod = kernel_mod)
1484 | 
1485 |         for conv1, activation, conv2, activation in self.layers:
1486 | 
1487 |             inner_residual = x
1488 | 
1489 |             x = conv1(x, **kwargs)
1490 |             x = activation(x)
1491 |             x = conv2(x, **kwargs)
1492 |             x = activation(x)
1493 | 
1494 |             x = x + inner_residual
1495 |             x = x * self.residual_scale
1496 | 
1497 |         x = x + residual
1498 |         return self.to_logits(x)
1499 | 
1500 | class Discriminator(nn.Module):
1501 |     @beartype
1502 |     def __init__(
1503 |         self,
1504 |         *,
1505 |         dim_capacity = 16,
1506 |         image_size,
1507 |         dim_max = 2048,
1508 |         channels = 3,
1509 |         attn_resolutions: Tuple[int, ...] = (32, 16),
1510 |         attn_dim_head = 64,
1511 |         attn_heads = 8,
1512 |         self_attn_dot_product = False,
1513 |         ff_mult = 4,
1514 |         text_encoder: TextEncoder | Dict | None = None,
1515 |         text_dim = None,
1516 |         filter_input_resolutions: bool = True,
1517 |         multiscale_input_resolutions: Tuple[int, ...] = (64, 32, 16, 8),
1518 |         multiscale_output_skip_stages: int = 1,
1519 |         aux_recon_resolutions: Tuple[int, ...] = (8,),
1520 |         aux_recon_patch_dims: Tuple[int, ...] = (2,),
1521 |         aux_recon_frac_patches: Tuple[float, ...] = (0.25,),
1522 |         aux_recon_fmap_dropout: float = 0.5,
1523 |         resize_mode = 'bilinear',
1524 |         num_conv_kernels = 2,
1525 |         num_skip_layers_excite = 0,
1526 |         unconditional = False,
1527 |         predictor_depth = 2
1528 |     ):
1529 |         super().__init__()
1530 |         self.unconditional = unconditional
1531 |         assert not (unconditional and exists(text_encoder))
1532 | 
1533 |         assert is_power_of_two(image_size)
1534 |         assert all([*map(is_power_of_two, attn_resolutions)])
1535 | 
1536 |         if filter_input_resolutions:
1537 |             multiscale_input_resolutions = [*filter(lambda t: t < image_size, multiscale_input_resolutions)]
1538 | 
1539 |         assert is_unique(multiscale_input_resolutions)
1540 |         assert all([*map(is_power_of_two, multiscale_input_resolutions)])
1541 |         assert all([*map(lambda t: t < image_size, multiscale_input_resolutions)])
1542 | 
1543 |         self.multiscale_input_resolutions = multiscale_input_resolutions
1544 | 
1545 |         assert multiscale_output_skip_stages > 0
1546 |         multiscale_output_resolutions = [resolution // (2 ** multiscale_output_skip_stages) for resolution in multiscale_input_resolutions]
1547 | 
1548 |         assert all([*map(lambda t: t >= 4, multiscale_output_resolutions)])
1549 | 
1550 |         assert all([*map(lambda t: t < image_size, multiscale_output_resolutions)])
1551 | 
1552 |         if len(multiscale_input_resolutions) > 0 and len(multiscale_output_resolutions) > 0:
1553 |             assert max(multiscale_input_resolutions) > max(multiscale_output_resolutions)
1554 |             assert min(multiscale_input_resolutions) > min(multiscale_output_resolutions)
1555 | 
1556 |         self.multiscale_output_resolutions = multiscale_output_resolutions
1557 | 
1558 |         assert all([*map(is_power_of_two, aux_recon_resolutions)])
1559 |         assert len(aux_recon_resolutions) == len(aux_recon_patch_dims) == len(aux_recon_frac_patches)
1560 | 
1561 |         self.aux_recon_resolutions_to_patches = {resolution: (patch_dim, frac_patches) for resolution, patch_dim, frac_patches in zip(aux_recon_resolutions, aux_recon_patch_dims, aux_recon_frac_patches)}
1562 | 
1563 |         self.resize_mode = resize_mode
1564 | 
1565 |         num_layers = int(log2(image_size) - 1)
1566 |         self.num_layers = num_layers
1567 |         self.image_size = image_size
1568 | 
1569 |         resolutions = image_size / ((2 ** torch.arange(num_layers)))
1570 |         resolutions = resolutions.long().tolist()
1571 | 
1572 |         dim_layers = (2 ** (torch.arange(num_layers) + 1)) * dim_capacity
1573 |         dim_layers = F.pad(dim_layers, (1, 0), value = channels)
1574 |         dim_layers.clamp_(max = dim_max)
1575 | 
1576 |         dim_layers = dim_layers.tolist()
1577 |         dim_last = dim_layers[-1]
1578 |         dim_pairs = list(zip(dim_layers[:-1], dim_layers[1:]))
1579 | 
1580 |         self.num_skip_layers_excite = num_skip_layers_excite
1581 | 
1582 |         self.residual_scale = 2 ** -0.5
1583 |         self.layers = nn.ModuleList([])
1584 | 
1585 |         upsample_dims = []
1586 |         predictor_dims = []
1587 |         dim_kernel_attn = (num_conv_kernels if num_conv_kernels > 1 else 0)
1588 | 
1589 |         for ind, ((dim_in, dim_out), resolution) in enumerate(zip(dim_pairs, resolutions)):
1590 |             is_first = ind == 0
1591 |             is_last = (ind + 1) == len(dim_pairs)
1592 |             should_downsample = not is_last
1593 |             should_skip_layer_excite = not is_first and num_skip_layers_excite > 0 and (ind + num_skip_layers_excite) < len(dim_pairs)
1594 | 
1595 |             has_attn = resolution in attn_resolutions
1596 |             has_multiscale_output = resolution in multiscale_output_resolutions
1597 | 
1598 |             has_aux_recon_decoder = resolution in aux_recon_resolutions
1599 |             upsample_dims.insert(0, dim_in)
1600 | 
1601 |             skip_squeeze_excite = None
1602 |             if should_skip_layer_excite:
1603 |                 dim_skip_in, _ = dim_pairs[ind + num_skip_layers_excite]
1604 |                 skip_squeeze_excite = SqueezeExcite(dim_in, dim_skip_in)
1605 | 
1606 |             # multi-scale rgb input to feature dimension
1607 | 
1608 |             from_rgb = nn.Conv2d(channels, dim_in, 7, padding = 3)
1609 | 
1610 |             # residual convolution
1611 | 
1612 |             residual_conv = nn.Conv2d(dim_in, dim_out, 1, stride = (2 if should_downsample else 1))
1613 | 
1614 |             # main resnet block
1615 | 
1616 |             resnet_block = nn.Sequential(
1617 |                 conv2d_3x3(dim_in, dim_out),
1618 |                 leaky_relu(),
1619 |                 conv2d_3x3(dim_out, dim_out),
1620 |                 leaky_relu()
1621 |             )
1622 | 
1623 |             # multi-scale output
1624 | 
1625 |             multiscale_output_predictor = None
1626 | 
1627 |             if has_multiscale_output:
1628 |                 multiscale_output_predictor = Predictor(dim_out, num_conv_kernels = num_conv_kernels, depth = 2, unconditional = unconditional)
1629 |                 predictor_dims.extend([dim_out, dim_kernel_attn])
1630 | 
1631 |             aux_recon_decoder = None
1632 | 
1633 |             if has_aux_recon_decoder:
1634 |                 patch_dim, frac_patches = self.aux_recon_resolutions_to_patches[resolution]
1635 | 
1636 |                 aux_recon_decoder = SimpleDecoder(
1637 |                     dim_out,
1638 |                     dims = tuple(upsample_dims),
1639 |                     patch_dim = patch_dim,
1640 |                     frac_patches = frac_patches,
1641 |                     dropout = aux_recon_fmap_dropout
1642 |                 )
1643 | 
1644 |             self.layers.append(nn.ModuleList([
1645 |                 skip_squeeze_excite,
1646 |                 from_rgb,
1647 |                 resnet_block,
1648 |                 residual_conv,
1649 |                 SelfAttentionBlock(dim_out, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult, dot_product = self_attn_dot_product) if has_attn else None,
1650 |                 multiscale_output_predictor,
1651 |                 aux_recon_decoder,
1652 |                 Downsample(dim_out) if should_downsample else None,
1653 |             ]))
1654 | 
1655 |         self.to_logits = nn.Sequential(
1656 |             conv2d_3x3(dim_last, dim_last),
1657 |             Rearrange('b c h w -> b (c h w)'),
1658 |             nn.Linear(dim_last * (4 ** 2), 1),
1659 |             Rearrange('b 1 -> b')
1660 |         )
1661 | 
1662 |         # take care of text conditioning in the multiscale predictor branches
1663 | 
1664 |         assert unconditional or (exists(text_dim) ^ exists(text_encoder))
1665 | 
1666 |         if not unconditional:
1667 |             if isinstance(text_encoder, dict):
1668 |                 text_encoder = TextEncoder(**text_encoder)
1669 | 
1670 |             self.text_dim = default(text_dim, text_encoder.dim)
1671 | 
1672 |             self.predictor_dims = predictor_dims
1673 |             self.text_to_conv_conditioning = nn.Linear(self.text_dim, sum(predictor_dims)) if exists(self.text_dim) else None
1674 | 
1675 |         self.text_encoder = text_encoder
1676 | 
1677 |         self.apply(self.init_)
1678 | 
1679 |     def init_(self, m):
1680 |         if type(m) in {nn.Conv2d, nn.Linear}:
1681 |             nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
1682 | 
1683 |     def resize_image_to(self, images, resolution):
1684 |         return F.interpolate(images, resolution, mode = self.resize_mode)
1685 | 
1686 |     def real_images_to_rgbs(self, images):
1687 |         return [self.resize_image_to(images, resolution) for resolution in self.multiscale_input_resolutions]
1688 | 
1689 |     @property
1690 |     def total_params(self):
1691 |         return sum([p.numel() for p in self.parameters()])
1692 | 
1693 |     @property
1694 |     def device(self):
1695 |         return next(self.parameters()).device
1696 | 
1697 |     @beartype
1698 |     def forward(
1699 |         self,
1700 |         images,
1701 |         rgbs: List[Tensor],                   # multi-resolution inputs (rgbs) from the generator
1702 |         texts: List[str] | None = None,
1703 |         text_encodings: Tensor | None = None,
1704 |         text_embeds = None,
1705 |         real_images = None,                   # if this were passed in, the network will automatically append the real to the presumably generated images passed in as the first argument, and generate all intermediate resolutions through resizing and concat appropriately
1706 |         return_multiscale_outputs = True,     # can force it not to return multi-scale logits
1707 |         calc_aux_loss = True
1708 |     ):
1709 |         if not self.unconditional:
1710 |             assert (exists(texts) ^ exists(text_encodings)) ^ exists(text_embeds), 'either texts as List[str] is passed in, or clip text_encodings as Tensor'
1711 | 
1712 |             if exists(texts):
1713 |                 assert exists(self.text_encoder)
1714 |                 text_embeds, *_ = self.text_encoder(texts = texts)
1715 | 
1716 |             elif exists(text_encodings):
1717 |                 assert exists(self.text_encoder)
1718 |                 text_embeds, *_ = self.text_encoder(text_encodings = text_encodings)
1719 | 
1720 |             assert exists(text_embeds), 'raw text or text embeddings were not passed into discriminator for conditional training'
1721 | 
1722 |             conv_mods = self.text_to_conv_conditioning(text_embeds).split(self.predictor_dims, dim = -1)
1723 |             conv_mods = iter(conv_mods)
1724 | 
1725 |         else:
1726 |             assert not any([*map(exists, (texts, text_embeds))])
1727 | 
1728 |         x = images
1729 | 
1730 |         image_size = (self.image_size, self.image_size)
1731 | 
1732 |         assert x.shape[-2:] == image_size
1733 | 
1734 |         batch = x.shape[0]
1735 | 
1736 |         # index the rgbs by resolution
1737 | 
1738 |         rgbs_index = {t.shape[-1]: t for t in rgbs} if exists(rgbs) else {}
1739 | 
1740 |         # assert that the necessary resolutions are there
1741 | 
1742 |         assert is_empty(set(self.multiscale_input_resolutions) - set(rgbs_index.keys())), f'rgbs of necessary resolution {self.multiscale_input_resolutions} were not passed in'
1743 | 
1744 |         # hold multiscale outputs
1745 | 
1746 |         multiscale_outputs = []
1747 | 
1748 |         # hold auxiliary recon losses
1749 | 
1750 |         aux_recon_losses = []
1751 | 
1752 |         # excitations
1753 | 
1754 |         excitations = [None] * (self.num_skip_layers_excite + 1) # +1 since first image in pixel space is not excited
1755 | 
1756 |         for squeeze_excite, from_rgb, block, residual_fn, attn, predictor, recon_decoder, downsample in self.layers:
1757 |             resolution = x.shape[-1]
1758 | 
1759 |             if exists(squeeze_excite):
1760 |                 skip_excite = squeeze_excite(x)
1761 |                 excitations.append(skip_excite)
1762 | 
1763 |             excite = safe_unshift(excitations)
1764 | 
1765 |             if exists(excite):
1766 |                 excite = repeat(excite, 'b ... -> (s b) ...', s = x.shape[0] // excite.shape[0])
1767 |                 x = x * excite
1768 | 
1769 |             batch_prev_stage = x.shape[0]
1770 |             has_multiscale_input = resolution in self.multiscale_input_resolutions
1771 | 
1772 |             if has_multiscale_input:
1773 |                 rgb = rgbs_index.get(resolution, None)
1774 | 
1775 |                 # multi-scale input features
1776 | 
1777 |                 multi_scale_input_feats = from_rgb(rgb)
1778 | 
1779 |                 # expand multi-scale input features, as could include extra scales from previous stage
1780 | 
1781 |                 multi_scale_input_feats = repeat(multi_scale_input_feats, 'b ... -> (s b) ...', s = x.shape[0] // rgb.shape[0])
1782 | 
1783 |                 # add the multi-scale input features to the current hidden state from main stem
1784 | 
1785 |                 x = x + multi_scale_input_feats
1786 | 
1787 |                 # and also concat for scale invariance
1788 | 
1789 |                 x = torch.cat((x, multi_scale_input_feats), dim = 0)
1790 | 
1791 |             residual = residual_fn(x)
1792 |             x = block(x)
1793 | 
1794 |             if exists(attn):
1795 |                 x = attn(x)
1796 | 
1797 |             if exists(predictor):
1798 |                 pred_kwargs = dict()
1799 |                 if not self.unconditional:
1800 |                     pred_kwargs = dict(mod = next(conv_mods), kernel_mod = next(conv_mods))
1801 | 
1802 |                 if return_multiscale_outputs:
1803 |                     predictor_input = x[:batch_prev_stage]
1804 |                     multiscale_outputs.append(predictor(predictor_input, **pred_kwargs))
1805 | 
1806 |             if exists(downsample):
1807 |                 x = downsample(x)
1808 | 
1809 |             x = x + residual
1810 |             x = x * self.residual_scale
1811 | 
1812 |             if exists(recon_decoder) and calc_aux_loss:
1813 | 
1814 |                 recon_output = x[:batch_prev_stage]
1815 |                 recon_output = rearrange(x, '(s b) ... -> s b ...', b = batch)
1816 | 
1817 |                 aux_recon_target = images
1818 | 
1819 |                 # only use the input real images for aux recon
1820 | 
1821 |                 recon_output = recon_output[0]
1822 | 
1823 |                 # only reconstruct a fraction of images across batch and scale
1824 |                 # for efficiency
1825 | 
1826 |                 aux_recon_loss = recon_decoder(recon_output, aux_recon_target)
1827 |                 aux_recon_losses.append(aux_recon_loss)
1828 | 
1829 |         # sanity check
1830 | 
1831 |         assert self.unconditional or is_empty([*conv_mods]), 'convolutions were incorrectly modulated'
1832 | 
1833 |         # to logits
1834 | 
1835 |         logits = self.to_logits(x)   
1836 |         logits = rearrange(logits, '(s b) ... -> s b ...', b = batch)
1837 | 
1838 |         return logits, multiscale_outputs, aux_recon_losses
1839 | 
1840 | # gan
1841 | 
1842 | TrainDiscrLosses = namedtuple('TrainDiscrLosses', [
1843 |     'divergence',
1844 |     'multiscale_divergence',
1845 |     'vision_aided_divergence',
1846 |     'total_matching_aware_loss',
1847 |     'gradient_penalty',
1848 |     'aux_reconstruction'
1849 | ])
1850 | 
1851 | TrainGenLosses = namedtuple('TrainGenLosses', [
1852 |     'divergence',
1853 |     'multiscale_divergence',
1854 |     'total_vd_divergence',
1855 |     'contrastive_loss'
1856 | ])
1857 | 
1858 | class GigaGAN(nn.Module):
1859 |     @beartype
1860 |     def __init__(
1861 |         self,
1862 |         *,
1863 |         generator: BaseGenerator | Dict,
1864 |         discriminator: Discriminator | Dict,
1865 |         vision_aided_discriminator: VisionAidedDiscriminator | Dict | None = None,
1866 |         diff_augment: DiffAugment | Dict | None = None,
1867 |         learning_rate = 2e-4,
1868 |         betas = (0.5, 0.9),
1869 |         weight_decay = 0.,
1870 |         discr_aux_recon_loss_weight = 1.,
1871 |         multiscale_divergence_loss_weight = 0.1,
1872 |         vision_aided_divergence_loss_weight = 0.5,
1873 |         generator_contrastive_loss_weight = 0.1,
1874 |         matching_awareness_loss_weight = 0.1,
1875 |         calc_multiscale_loss_every = 1,
1876 |         apply_gradient_penalty_every = 4,
1877 |         resize_image_mode = 'bilinear',
1878 |         train_upsampler = False,
1879 |         log_steps_every = 20,
1880 |         create_ema_generator_at_init = True,
1881 |         save_and_sample_every = 1000,
1882 |         early_save_thres_steps = 2500,
1883 |         early_save_and_sample_every = 100,
1884 |         num_samples = 25,
1885 |         model_folder = './gigagan-models',
1886 |         results_folder = './gigagan-results',
1887 |         sample_upsampler_dl: DataLoader | None = None,
1888 |         accelerator: Accelerator | None = None,
1889 |         accelerate_kwargs: dict = {},
1890 |         find_unused_parameters = True,
1891 |         amp = False,
1892 |         mixed_precision_type = 'fp16'
1893 |     ):
1894 |         super().__init__()
1895 | 
1896 |         # create accelerator
1897 | 
1898 |         if accelerator:
1899 |             self.accelerator = accelerator
1900 |             assert is_empty(accelerate_kwargs)
1901 |         else:
1902 |             kwargs = DistributedDataParallelKwargs(find_unused_parameters = find_unused_parameters)
1903 | 
1904 |             self.accelerator = Accelerator(
1905 |                 kwargs_handlers = [kwargs],
1906 |                 mixed_precision = mixed_precision_type if amp else 'no',
1907 |                 **accelerate_kwargs
1908 |             )
1909 | 
1910 |         # whether to train upsampler or not
1911 | 
1912 |         self.train_upsampler = train_upsampler
1913 | 
1914 |         if train_upsampler:
1915 |             from gigagan_pytorch.unet_upsampler import UnetUpsampler
1916 |             generator_klass = UnetUpsampler
1917 |         else:
1918 |             generator_klass = Generator
1919 | 
1920 |         # gradient penalty and auxiliary recon loss
1921 | 
1922 |         self.apply_gradient_penalty_every = apply_gradient_penalty_every
1923 |         self.calc_multiscale_loss_every = calc_multiscale_loss_every
1924 | 
1925 |         if isinstance(generator, dict):
1926 |             generator = generator_klass(**generator)
1927 | 
1928 |         if isinstance(discriminator, dict):
1929 |             discriminator = Discriminator(**discriminator)
1930 | 
1931 |         if exists(vision_aided_discriminator) and isinstance(vision_aided_discriminator, dict):
1932 |             vision_aided_discriminator = VisionAidedDiscriminator(**vision_aided_discriminator)
1933 | 
1934 |         assert isinstance(generator, generator_klass)
1935 | 
1936 |         # diff augment
1937 | 
1938 |         if isinstance(diff_augment, dict):
1939 |             diff_augment = DiffAugment(**diff_augment)
1940 | 
1941 |         self.diff_augment = diff_augment
1942 | 
1943 |         # use _base to designate unwrapped models
1944 | 
1945 |         self.G = generator
1946 |         self.D = discriminator
1947 |         self.VD = vision_aided_discriminator
1948 | 
1949 |         # validate multiscale input resolutions
1950 | 
1951 |         if train_upsampler:
1952 |             assert is_empty(set(discriminator.multiscale_input_resolutions) - set(generator.allowable_rgb_resolutions)), f'only multiscale input resolutions of {generator.allowable_rgb_resolutions} is allowed based on the unet input and output image size. simply do Discriminator(multiscale_input_resolutions = unet.allowable_rgb_resolutions) to resolve this error'
1953 | 
1954 |         # ema
1955 | 
1956 |         self.has_ema_generator = False
1957 | 
1958 |         if self.is_main and create_ema_generator_at_init:
1959 |             self.create_ema_generator()
1960 | 
1961 |         # print number of parameters
1962 | 
1963 |         self.print('\n')
1964 | 
1965 |         self.print(f'Generator: {numerize.numerize(generator.total_params)}')
1966 |         self.print(f'Discriminator: {numerize.numerize(discriminator.total_params)}')
1967 | 
1968 |         if exists(self.VD):
1969 |             self.print(f'Vision Discriminator: {numerize.numerize(vision_aided_discriminator.total_params)}')
1970 | 
1971 |         self.print('\n')
1972 | 
1973 |         # text encoder
1974 | 
1975 |         assert generator.unconditional == discriminator.unconditional
1976 |         assert not exists(vision_aided_discriminator) or vision_aided_discriminator.unconditional == generator.unconditional
1977 | 
1978 |         self.unconditional = generator.unconditional
1979 | 
1980 |         # optimizers
1981 | 
1982 |         self.G_opt = get_optimizer(self.G.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
1983 |         self.D_opt = get_optimizer(self.D.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
1984 | 
1985 |         # prepare for distributed
1986 | 
1987 |         self.G, self.D, self.G_opt, self.D_opt = self.accelerator.prepare(self.G, self.D, self.G_opt, self.D_opt)
1988 | 
1989 |         # vision aided discriminator optimizer
1990 | 
1991 |         if exists(self.VD):
1992 |             self.VD_opt = get_optimizer(self.VD.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
1993 |             self.VD_opt = self.accelerator.prepare(self.VD_opt)
1994 | 
1995 |         # loss related
1996 | 
1997 |         self.discr_aux_recon_loss_weight = discr_aux_recon_loss_weight
1998 |         self.multiscale_divergence_loss_weight = multiscale_divergence_loss_weight
1999 |         self.vision_aided_divergence_loss_weight = vision_aided_divergence_loss_weight
2000 |         self.generator_contrastive_loss_weight = generator_contrastive_loss_weight
2001 |         self.matching_awareness_loss_weight = matching_awareness_loss_weight
2002 | 
2003 |         # resize image mode
2004 | 
2005 |         self.resize_image_mode = resize_image_mode
2006 | 
2007 |         # steps
2008 | 
2009 |         self.log_steps_every = log_steps_every
2010 | 
2011 |         self.register_buffer('steps', torch.ones(1, dtype = torch.long))
2012 | 
2013 |         # save and sample
2014 | 
2015 |         self.save_and_sample_every = save_and_sample_every
2016 |         self.early_save_thres_steps = early_save_thres_steps
2017 |         self.early_save_and_sample_every = early_save_and_sample_every
2018 | 
2019 |         self.num_samples = num_samples
2020 | 
2021 |         self.train_dl = None
2022 | 
2023 |         self.sample_upsampler_dl_iter = None
2024 |         if exists(sample_upsampler_dl):
2025 |             self.sample_upsampler_dl_iter = cycle(self.sample_upsampler_dl)
2026 | 
2027 |         self.results_folder = Path(results_folder)
2028 |         self.model_folder = Path(model_folder)
2029 | 
2030 |         mkdir_if_not_exists(self.results_folder)
2031 |         mkdir_if_not_exists(self.model_folder)
2032 | 
2033 |     def save(self, path, overwrite = True):
2034 |         path = Path(path)
2035 |         mkdir_if_not_exists(path.parents[0])
2036 | 
2037 |         assert overwrite or not path.exists()
2038 | 
2039 |         pkg = dict(
2040 |             G = self.unwrapped_G.state_dict(),
2041 |             D = self.unwrapped_D.state_dict(),
2042 |             G_opt = self.G_opt.state_dict(),
2043 |             D_opt = self.D_opt.state_dict(),
2044 |             steps = self.steps.item(),
2045 |             version = __version__
2046 |         )
2047 | 
2048 |         if exists(self.G_opt.scaler):
2049 |             pkg['G_scaler'] = self.G_opt.scaler.state_dict()
2050 | 
2051 |         if exists(self.D_opt.scaler):
2052 |             pkg['D_scaler'] = self.D_opt.scaler.state_dict()
2053 | 
2054 |         if exists(self.VD):
2055 |             pkg['VD'] = self.unwrapped_VD.state_dict()
2056 |             pkg['VD_opt'] = self.VD_opt.state_dict()
2057 | 
2058 |             if exists(self.VD_opt.scaler):
2059 |                 pkg['VD_scaler'] = self.VD_opt.scaler.state_dict()
2060 | 
2061 |         if self.has_ema_generator:
2062 |             pkg['G_ema'] = self.G_ema.state_dict()
2063 | 
2064 |         torch.save(pkg, str(path))
2065 | 
2066 |     def load(self, path, strict = False):
2067 |         path = Path(path)
2068 |         assert path.exists()
2069 | 
2070 |         pkg = torch.load(str(path))
2071 | 
2072 |         if 'version' in pkg and pkg['version'] != __version__:
2073 |             print(f"trying to load from version {pkg['version']}")
2074 | 
2075 |         self.unwrapped_G.load_state_dict(pkg['G'], strict = strict)
2076 |         self.unwrapped_D.load_state_dict(pkg['D'], strict = strict)
2077 | 
2078 |         if exists(self.VD):
2079 |             self.unwrapped_VD.load_state_dict(pkg['VD'], strict = strict)
2080 | 
2081 |         if self.has_ema_generator:
2082 |             self.G_ema.load_state_dict(pkg['G_ema'])
2083 | 
2084 |         if 'steps' in pkg:
2085 |             self.steps.copy_(torch.tensor([pkg['steps']]))
2086 | 
2087 |         if 'G_opt'not in pkg or 'D_opt' not in pkg:
2088 |             return
2089 | 
2090 |         try:
2091 |             self.G_opt.load_state_dict(pkg['G_opt'])
2092 |             self.D_opt.load_state_dict(pkg['D_opt'])
2093 | 
2094 |             if exists(self.VD):
2095 |                 self.VD_opt.load_state_dict(pkg['VD_opt'])
2096 | 
2097 |             if 'G_scaler' in pkg and exists(self.G_opt.scaler):
2098 |                 self.G_opt.scaler.load_state_dict(pkg['G_scaler'])
2099 | 
2100 |             if 'D_scaler' in pkg and exists(self.D_opt.scaler):
2101 |                 self.D_opt.scaler.load_state_dict(pkg['D_scaler'])
2102 | 
2103 |             if 'VD_scaler' in pkg and exists(self.VD_opt.scaler):
2104 |                 self.VD_opt.scaler.load_state_dict(pkg['VD_scaler'])
2105 | 
2106 |         except Exception as e:
2107 |             self.print(f'unable to load optimizers {e.msg}- optimizer states will be reset')
2108 |             pass
2109 | 
2110 |     # accelerate related
2111 | 
2112 |     @property
2113 |     def device(self):
2114 |         return self.accelerator.device
2115 | 
2116 |     @property
2117 |     def unwrapped_G(self):
2118 |         return self.accelerator.unwrap_model(self.G)
2119 | 
2120 |     @property
2121 |     def unwrapped_D(self):
2122 |         return self.accelerator.unwrap_model(self.D)
2123 | 
2124 |     @property
2125 |     def unwrapped_VD(self):
2126 |         return self.accelerator.unwrap_model(self.VD)
2127 | 
2128 |     @property
2129 |     def need_vision_aided_discriminator(self):
2130 |         return exists(self.VD) and self.vision_aided_divergence_loss_weight > 0.
2131 | 
2132 |     @property
2133 |     def need_contrastive_loss(self):
2134 |         return self.generator_contrastive_loss_weight > 0. and not self.unconditional
2135 | 
2136 |     def print(self, msg):
2137 |         self.accelerator.print(msg)
2138 | 
2139 |     @property
2140 |     def is_distributed(self):
2141 |         return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
2142 | 
2143 |     @property
2144 |     def is_main(self):
2145 |         return self.accelerator.is_main_process
2146 | 
2147 |     @property
2148 |     def is_local_main(self):
2149 |         return self.accelerator.is_local_main_process
2150 | 
2151 |     def resize_image_to(self, images, resolution):
2152 |         return F.interpolate(images, resolution, mode = self.resize_image_mode)
2153 | 
2154 |     @beartype
2155 |     def set_dataloader(self, dl: DataLoader):
2156 |         assert not exists(self.train_dl), 'training dataloader has already been set'
2157 | 
2158 |         self.train_dl = dl
2159 |         self.train_dl_batch_size = dl.batch_size
2160 | 
2161 |         self.train_dl = self.accelerator.prepare(self.train_dl)
2162 | 
2163 |     # generate function
2164 | 
2165 |     @torch.inference_mode()
2166 |     def generate(self, *args, **kwargs):
2167 |         model = self.G_ema if self.has_ema_generator else self.G
2168 |         model.eval()
2169 |         return model(*args, **kwargs)
2170 | 
2171 |     # create EMA generator
2172 | 
2173 |     def create_ema_generator(
2174 |         self,
2175 |         update_every = 10,
2176 |         update_after_step = 100,
2177 |         decay = 0.995
2178 |     ):
2179 |         if not self.is_main:
2180 |             return
2181 | 
2182 |         assert not self.has_ema_generator, 'EMA generator has already been created'
2183 | 
2184 |         self.G_ema = EMA(self.unwrapped_G, update_every = update_every, update_after_step = update_after_step, beta = decay)
2185 |         self.has_ema_generator = True
2186 | 
2187 |     def generate_kwargs(self, dl_iter, batch_size):
2188 |         # what to pass into the generator
2189 |         # depends on whether training upsampler or not
2190 | 
2191 |         maybe_text_kwargs = dict()
2192 |         if self.train_upsampler or not self.unconditional:
2193 |             assert exists(dl_iter)
2194 | 
2195 |             if self.unconditional:
2196 |                 real_images = next(dl_iter)
2197 |             else:
2198 |                 result = next(dl_iter)
2199 |                 assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])'
2200 |                 real_images, texts = result
2201 | 
2202 |                 maybe_text_kwargs['texts'] = texts[:batch_size]
2203 | 
2204 |             real_images = real_images.to(self.device)
2205 | 
2206 |         # if training upsample generator, need to downsample real images
2207 | 
2208 |         if self.train_upsampler:
2209 |             size = self.unwrapped_G.input_image_size
2210 |             lowres_real_images = F.interpolate(real_images, (size, size))
2211 | 
2212 |             G_kwargs = dict(lowres_image = lowres_real_images)
2213 |         else:
2214 |             assert exists(batch_size)
2215 | 
2216 |             G_kwargs = dict(batch_size = batch_size)
2217 | 
2218 |         # create noise
2219 | 
2220 |         noise = torch.randn(batch_size, self.unwrapped_G.style_network.dim, device = self.device)
2221 | 
2222 |         G_kwargs.update(noise = noise)
2223 | 
2224 |         return G_kwargs, maybe_text_kwargs
2225 |     
2226 |     @beartype
2227 |     def train_discriminator_step(
2228 |         self,
2229 |         dl_iter: Iterable,
2230 |         grad_accum_every = 1,
2231 |         apply_gradient_penalty = False,
2232 |         calc_multiscale_loss = True
2233 |     ):
2234 |         total_divergence = 0.
2235 |         total_vision_aided_divergence = 0.
2236 | 
2237 |         total_gp_loss = 0.
2238 |         total_aux_loss = 0.
2239 | 
2240 |         total_multiscale_divergence = 0. if calc_multiscale_loss else None
2241 | 
2242 |         has_matching_awareness = not self.unconditional and self.matching_awareness_loss_weight > 0.
2243 | 
2244 |         total_matching_aware_loss = 0.
2245 | 
2246 |         all_texts = []
2247 |         all_fake_images = []
2248 |         all_fake_rgbs = []
2249 |         all_real_images = []
2250 | 
2251 |         self.G.train()
2252 | 
2253 |         self.D.train()
2254 |         self.D_opt.zero_grad()
2255 | 
2256 |         if self.need_vision_aided_discriminator:
2257 |             self.VD.train()
2258 |             self.VD_opt.zero_grad()
2259 | 
2260 |         for _ in range(grad_accum_every):
2261 | 
2262 |             if self.unconditional:
2263 |                 real_images = next(dl_iter)
2264 |             else:
2265 |                 result = next(dl_iter)
2266 |                 assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])'
2267 |                 real_images, texts = result
2268 | 
2269 |                 all_real_images.append(real_images)
2270 |                 all_texts.extend(texts)
2271 | 
2272 |             # requires grad for real images, for gradient penalty
2273 | 
2274 |             real_images = real_images.to(self.device)
2275 |             real_images.requires_grad_()
2276 | 
2277 |             real_images_rgbs = self.unwrapped_D.real_images_to_rgbs(real_images)
2278 | 
2279 |             # diff augment real images
2280 | 
2281 |             if exists(self.diff_augment):
2282 |                 real_images, real_images_rgbs = self.diff_augment(real_images, real_images_rgbs)
2283 | 
2284 |             # batch size
2285 | 
2286 |             batch_size = real_images.shape[0]
2287 | 
2288 |             # for discriminator training, fit upsampler and image synthesis logic under same function
2289 | 
2290 |             G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)
2291 | 
2292 |             # generator
2293 | 
2294 |             with torch.no_grad(), self.accelerator.autocast():
2295 |                 images, rgbs = self.G(
2296 |                     **G_kwargs,
2297 |                     **maybe_text_kwargs,
2298 |                     return_all_rgbs = True
2299 |                 )
2300 | 
2301 |                 all_fake_images.append(images)
2302 |                 all_fake_rgbs.append(rgbs)
2303 | 
2304 |                 # diff augment
2305 | 
2306 |                 if exists(self.diff_augment):
2307 |                     images, rgbs = self.diff_augment(images, rgbs)
2308 | 
2309 |                 # detach output of generator, as training discriminator only
2310 | 
2311 |                 images.detach_()
2312 |                 images.requires_grad_()
2313 | 
2314 |                 for rgb in rgbs:
2315 |                     rgb.detach_()
2316 |                     rgb.requires_grad_()
2317 | 
2318 |             # main divergence loss
2319 | 
2320 |             with self.accelerator.autocast():
2321 | 
2322 |                 fake_logits, fake_multiscale_logits, _ = self.D(
2323 |                     images,
2324 |                     rgbs,
2325 |                     **maybe_text_kwargs,
2326 |                     return_multiscale_outputs = calc_multiscale_loss,
2327 |                     calc_aux_loss = False
2328 |                 )
2329 | 
2330 |                 real_logits, real_multiscale_logits, aux_recon_losses = self.D(
2331 |                     real_images,
2332 |                     real_images_rgbs,
2333 |                     **maybe_text_kwargs,
2334 |                     return_multiscale_outputs = calc_multiscale_loss,
2335 |                     calc_aux_loss = True
2336 |                 )
2337 | 
2338 |                 divergence = discriminator_hinge_loss(real_logits, fake_logits)
2339 |                 total_divergence += (divergence.item() / grad_accum_every)
2340 | 
2341 |                 # handle multi-scale divergence
2342 | 
2343 |                 multiscale_divergence = 0.
2344 | 
2345 |                 if self.multiscale_divergence_loss_weight > 0. and len(fake_multiscale_logits) > 0:
2346 | 
2347 |                     for multiscale_fake, multiscale_real in zip(fake_multiscale_logits, real_multiscale_logits):
2348 |                         multiscale_loss = discriminator_hinge_loss(multiscale_real, multiscale_fake)
2349 |                         multiscale_divergence = multiscale_divergence + multiscale_loss
2350 | 
2351 |                     total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every)
2352 | 
2353 |                 # figure out gradient penalty if needed
2354 | 
2355 |                 gp_loss = 0.
2356 | 
2357 |                 if apply_gradient_penalty:
2358 |                     real_gp_loss = gradient_penalty(
2359 |                         real_images,
2360 |                         outputs = [real_logits, *real_multiscale_logits],
2361 |                         grad_output_weights = [1., *(self.multiscale_divergence_loss_weight,) * len(real_multiscale_logits)],
2362 |                         scaler = self.D_opt.scaler
2363 |                     )
2364 | 
2365 |                     fake_gp_loss = gradient_penalty(
2366 |                         images,
2367 |                         outputs = [fake_logits, *fake_multiscale_logits],
2368 |                         grad_output_weights = [1., *(self.multiscale_divergence_loss_weight,) * len(fake_multiscale_logits)],
2369 |                         scaler = self.D_opt.scaler
2370 |                     )
2371 | 
2372 |                     gp_loss = real_gp_loss + fake_gp_loss
2373 | 
2374 |                     if not torch.isnan(gp_loss):
2375 |                         total_gp_loss += (gp_loss.item() / grad_accum_every)
2376 | 
2377 |                 # handle vision aided discriminator, if needed
2378 | 
2379 |                 vd_loss = 0.
2380 | 
2381 |                 if self.need_vision_aided_discriminator:
2382 | 
2383 |                     fake_vision_aided_logits = self.VD(images, **maybe_text_kwargs)
2384 |                     real_vision_aided_logits, clip_encodings = self.VD(real_images, return_clip_encodings = True, **maybe_text_kwargs)
2385 | 
2386 |                     for fake_logits, real_logits in zip(fake_vision_aided_logits, real_vision_aided_logits):
2387 |                         vd_loss = vd_loss + discriminator_hinge_loss(real_logits, fake_logits)
2388 | 
2389 |                     total_vision_aided_divergence += (vd_loss.item() / grad_accum_every)
2390 | 
2391 |                     # handle gradient penalty for vision aided discriminator
2392 | 
2393 |                     if apply_gradient_penalty:
2394 | 
2395 |                         vd_gp_loss = gradient_penalty(
2396 |                             clip_encodings,
2397 |                             outputs = real_vision_aided_logits,
2398 |                             grad_output_weights = [self.vision_aided_divergence_loss_weight] * len(real_vision_aided_logits),
2399 |                             scaler = self.VD_opt.scaler
2400 |                         )
2401 | 
2402 |                         if not torch.isnan(vd_gp_loss):
2403 |                             gp_loss = gp_loss + vd_gp_loss
2404 | 
2405 |                             total_gp_loss += (vd_gp_loss.item() / grad_accum_every)
2406 | 
2407 |                 # sum up losses
2408 | 
2409 |                 total_loss = divergence + gp_loss
2410 | 
2411 |                 if self.multiscale_divergence_loss_weight > 0.:
2412 |                     total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight
2413 | 
2414 |                 if self.vision_aided_divergence_loss_weight > 0.:
2415 |                     total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight
2416 | 
2417 |                 if self.discr_aux_recon_loss_weight > 0.:
2418 |                     aux_loss = sum(aux_recon_losses)
2419 | 
2420 |                     total_aux_loss += (aux_loss.item() / grad_accum_every)
2421 | 
2422 |                     total_loss = total_loss + aux_loss * self.discr_aux_recon_loss_weight
2423 | 
2424 |             # backwards
2425 | 
2426 |             self.accelerator.backward(total_loss / grad_accum_every)
2427 | 
2428 | 
2429 |         # matching awareness loss
2430 |         # strategy would be to rotate the texts by one and assume batch is shuffled enough for mismatched conditions
2431 | 
2432 |         if has_matching_awareness:
2433 | 
2434 |             # rotate texts
2435 | 
2436 |             all_texts = [*all_texts[1:], all_texts[0]]
2437 |             all_texts = group_by_num_consecutive(texts, batch_size)
2438 | 
2439 |             zipped_data = zip(
2440 |                 all_fake_images,
2441 |                 all_fake_rgbs,
2442 |                 all_real_images,
2443 |                 all_texts
2444 |             )
2445 | 
2446 |             total_loss = 0.
2447 | 
2448 |             for fake_images, fake_rgbs, real_images, texts in zipped_data:
2449 | 
2450 |                 with self.accelerator.autocast():
2451 |                     fake_logits, *_ = self.D(
2452 |                         fake_images,
2453 |                         fake_rgbs,
2454 |                         texts = texts,
2455 |                         return_multiscale_outputs = False,
2456 |                         calc_aux_loss = False
2457 |                     )
2458 | 
2459 |                     real_images_rgbs = self.D.real_images_to_rgbs(real_images)
2460 | 
2461 |                     real_logits, *_ = self.D(
2462 |                         real_images,
2463 |                         real_images_rgbs,
2464 |                         texts = texts,
2465 |                         return_multiscale_outputs = False,
2466 |                         calc_aux_loss = False
2467 |                     )
2468 | 
2469 |                     matching_loss = aux_matching_loss(real_logits, fake_logits)
2470 | 
2471 |                     total_matching_aware_loss = (matching_loss.item() / grad_accum_every)
2472 | 
2473 |                     loss = matching_loss * self.matching_awareness_loss_weight
2474 | 
2475 |                 self.accelerator.backward(loss / grad_accum_every)
2476 | 
2477 |         self.D_opt.step()
2478 | 
2479 |         if self.need_vision_aided_discriminator:
2480 |             self.VD_opt.step()
2481 | 
2482 |         return TrainDiscrLosses(
2483 |             total_divergence,
2484 |             total_multiscale_divergence,
2485 |             total_vision_aided_divergence,
2486 |             total_matching_aware_loss,
2487 |             total_gp_loss,
2488 |             total_aux_loss
2489 |         )
2490 | 
2491 |     def train_generator_step(
2492 |         self,
2493 |         batch_size = None,
2494 |         dl_iter: Iterable | None = None,
2495 |         grad_accum_every = 1,
2496 |         calc_multiscale_loss = True
2497 |     ):
2498 |         total_divergence = 0.
2499 |         total_multiscale_divergence = 0. if calc_multiscale_loss else None
2500 |         total_vd_divergence = 0.
2501 |         contrastive_loss = 0.
2502 | 
2503 |         self.G.train()
2504 |         self.D.train()
2505 | 
2506 |         self.D_opt.zero_grad()
2507 |         self.G_opt.zero_grad()
2508 | 
2509 |         all_images = []
2510 |         all_texts = []
2511 | 
2512 |         for _ in range(grad_accum_every):
2513 | 
2514 |             # generator
2515 |             
2516 |             G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)
2517 | 
2518 |             with self.accelerator.autocast():
2519 |                 images, rgbs = self.G(
2520 |                     **G_kwargs,
2521 |                     **maybe_text_kwargs,
2522 |                     return_all_rgbs = True
2523 |                 )
2524 | 
2525 |                 # diff augment
2526 | 
2527 |                 if exists(self.diff_augment):
2528 |                     images, rgbs = self.diff_augment(images, rgbs)
2529 | 
2530 |                 # accumulate all images and texts for maybe contrastive loss
2531 | 
2532 |                 if self.need_contrastive_loss:
2533 |                     all_images.append(images)
2534 |                     all_texts.extend(maybe_text_kwargs['texts'])
2535 | 
2536 |                 # discriminator
2537 | 
2538 |                 logits, multiscale_logits, _ = self.D(
2539 |                     images,
2540 |                     rgbs,
2541 |                     **maybe_text_kwargs,
2542 |                     return_multiscale_outputs = calc_multiscale_loss,
2543 |                     calc_aux_loss = False
2544 |                 )
2545 | 
2546 |                 # generator hinge loss discriminator and multiscale
2547 | 
2548 |                 divergence = generator_hinge_loss(logits)
2549 | 
2550 |                 total_divergence += (divergence.item() / grad_accum_every)
2551 | 
2552 |                 total_loss = divergence
2553 | 
2554 |                 if self.multiscale_divergence_loss_weight > 0. and len(multiscale_logits) > 0:
2555 |                     multiscale_divergence = 0.
2556 | 
2557 |                     for multiscale_logit in multiscale_logits:
2558 |                         multiscale_divergence = multiscale_divergence + generator_hinge_loss(multiscale_logit)
2559 | 
2560 |                     total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every)
2561 | 
2562 |                     total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight
2563 | 
2564 |                 # vision aided generator hinge loss
2565 | 
2566 |                 if self.need_vision_aided_discriminator:
2567 |                     vd_loss = 0.
2568 | 
2569 |                     logits = self.VD(images, **maybe_text_kwargs)
2570 | 
2571 |                     for logit in logits:
2572 |                         vd_loss = vd_loss + generator_hinge_loss(logit)
2573 | 
2574 |                     total_vd_divergence += (vd_loss.item() / grad_accum_every)
2575 | 
2576 |                     total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight
2577 | 
2578 |             self.accelerator.backward(total_loss / grad_accum_every, retain_graph = self.need_contrastive_loss)
2579 | 
2580 |         # if needs the generator contrastive loss
2581 |         # gather up all images and texts and calculate it
2582 | 
2583 |         if self.need_contrastive_loss:
2584 |             all_images = torch.cat(all_images, dim = 0)
2585 | 
2586 |             contrastive_loss = aux_clip_loss(
2587 |                 clip = self.G.text_encoder.clip,
2588 |                 texts = all_texts,
2589 |                 images = all_images
2590 |             )
2591 | 
2592 |             self.accelerator.backward(contrastive_loss * self.generator_contrastive_loss_weight)
2593 | 
2594 |         # generator optimizer step
2595 | 
2596 |         self.G_opt.step()
2597 | 
2598 |         # update exponentially moving averaged generator
2599 | 
2600 |         self.accelerator.wait_for_everyone()
2601 | 
2602 |         if self.is_main and self.has_ema_generator:
2603 |             self.G_ema.update()
2604 | 
2605 |         return TrainGenLosses(
2606 |             total_divergence,
2607 |             total_multiscale_divergence,
2608 |             total_vd_divergence,
2609 |             contrastive_loss
2610 |         )
2611 | 
2612 |     def sample(self, model, dl_iter, batch_size):
2613 |         G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)
2614 | 
2615 |         with self.accelerator.autocast():
2616 |             generator_output = model(**G_kwargs, **maybe_text_kwargs)
2617 | 
2618 |         if not self.train_upsampler:
2619 |             return generator_output
2620 | 
2621 |         output_size = generator_output.shape[-1]
2622 |         lowres_image = G_kwargs['lowres_image']
2623 |         lowres_image = F.interpolate(lowres_image, (output_size, output_size))
2624 | 
2625 |         return torch.cat([lowres_image, generator_output])
2626 | 
2627 |     @torch.inference_mode()
2628 |     def save_sample(
2629 |         self,
2630 |         batch_size,
2631 |         dl_iter = None
2632 |     ):
2633 |         milestone = self.steps.item() // self.save_and_sample_every
2634 |         nrow_mult = 2 if self.train_upsampler else 1
2635 |         batches = num_to_groups(self.num_samples, batch_size)
2636 | 
2637 |         if self.train_upsampler:
2638 |             dl_iter = default(self.sample_upsampler_dl_iter, dl_iter)
2639 | 
2640 |         assert exists(dl_iter)
2641 | 
2642 |         sample_models_and_output_file_name = [(self.unwrapped_G, f'sample-{milestone}.png')]
2643 | 
2644 |         if self.has_ema_generator:
2645 |             sample_models_and_output_file_name.append((self.G_ema, f'ema-sample-{milestone}.png'))
2646 | 
2647 |         for model, filename in sample_models_and_output_file_name:
2648 |             model.eval()
2649 | 
2650 |             all_images_list = list(map(lambda n: self.sample(model, dl_iter, n), batches))
2651 |             all_images = torch.cat(all_images_list, dim = 0)
2652 | 
2653 |             all_images.clamp_(0., 1.)
2654 | 
2655 |             utils.save_image(
2656 |                 all_images,
2657 |                 str(self.results_folder / filename),
2658 |                 nrow = int(sqrt(self.num_samples)) * nrow_mult
2659 |             )
2660 | 
2661 |         # Possible to do: Include some metric to save if improved, include some sampler dict text entries
2662 |         self.save(str(self.model_folder / f'model-{milestone}.ckpt'))
2663 | 
2664 |     @beartype
2665 |     def forward(
2666 |         self,
2667 |         *,
2668 |         steps,
2669 |         grad_accum_every = 1
2670 |     ):
2671 |         assert exists(self.train_dl), 'you need to set the dataloader by running .set_dataloader(dl: Dataloader)'
2672 | 
2673 |         batch_size = self.train_dl_batch_size
2674 |         dl_iter = cycle(self.train_dl)
2675 | 
2676 |         last_gp_loss = 0.
2677 |         last_multiscale_d_loss = 0.
2678 |         last_multiscale_g_loss = 0.
2679 | 
2680 |         for _ in tqdm(range(steps), initial = self.steps.item()):
2681 |             steps = self.steps.item()
2682 |             is_first_step = steps == 1
2683 | 
2684 |             apply_gradient_penalty = self.apply_gradient_penalty_every > 0 and divisible_by(steps, self.apply_gradient_penalty_every)
2685 |             calc_multiscale_loss =  self.calc_multiscale_loss_every > 0 and divisible_by(steps, self.calc_multiscale_loss_every)
2686 | 
2687 |             (
2688 |                 d_loss,
2689 |                 multiscale_d_loss,
2690 |                 vision_aided_d_loss,
2691 |                 matching_aware_loss,
2692 |                 gp_loss,
2693 |                 recon_loss
2694 |             ) = self.train_discriminator_step(
2695 |                 dl_iter = dl_iter,
2696 |                 grad_accum_every = grad_accum_every,
2697 |                 apply_gradient_penalty = apply_gradient_penalty,
2698 |                 calc_multiscale_loss = calc_multiscale_loss
2699 |             )
2700 | 
2701 |             self.accelerator.wait_for_everyone()
2702 | 
2703 |             (
2704 |                 g_loss,
2705 |                 multiscale_g_loss,
2706 |                 vision_aided_g_loss,
2707 |                 contrastive_loss
2708 |             ) = self.train_generator_step(
2709 |                 dl_iter = dl_iter,
2710 |                 batch_size = batch_size,
2711 |                 grad_accum_every = grad_accum_every,
2712 |                 calc_multiscale_loss = calc_multiscale_loss
2713 |             )
2714 | 
2715 |             if exists(gp_loss):
2716 |                 last_gp_loss = gp_loss
2717 | 
2718 |             if exists(multiscale_d_loss):
2719 |                 last_multiscale_d_loss = multiscale_d_loss
2720 | 
2721 |             if exists(multiscale_g_loss):
2722 |                 last_multiscale_g_loss = multiscale_g_loss
2723 | 
2724 |             if is_first_step or divisible_by(steps, self.log_steps_every):
2725 | 
2726 |                 losses = (
2727 |                     ('G', g_loss),
2728 |                     ('MSG', last_multiscale_g_loss),
2729 |                     ('VG', vision_aided_g_loss),
2730 |                     ('D', d_loss),
2731 |                     ('MSD', last_multiscale_d_loss),
2732 |                     ('VD', vision_aided_d_loss),
2733 |                     ('GP', last_gp_loss),
2734 |                     ('SSL', recon_loss),
2735 |                     ('CL', contrastive_loss),
2736 |                     ('MAL', matching_aware_loss)
2737 |                 )
2738 | 
2739 |                 losses_str = ' | '.join([f'{loss_name}: {loss:.2f}' for loss_name, loss in losses])
2740 | 
2741 |                 self.print(losses_str)
2742 | 
2743 |             self.accelerator.wait_for_everyone()
2744 | 
2745 |             if self.is_main and (is_first_step or divisible_by(steps, self.save_and_sample_every) or (steps <= self.early_save_thres_steps and divisible_by(steps, self.early_save_and_sample_every))):
2746 |                 self.save_sample(batch_size, dl_iter)
2747 |             
2748 |             self.steps += 1
2749 | 
2750 |         self.print(f'complete {steps} training steps')
2751 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/open_clip.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | from torch import nn, einsum
  3 | import torch.nn.functional as F
  4 | import open_clip
  5 | 
  6 | from einops import rearrange
  7 | 
  8 | from beartype import beartype
  9 | from beartype.typing import List, Optional
 10 | 
 11 | def exists(val):
 12 |     return val is not None
 13 | 
 14 | def l2norm(t):
 15 |     return F.normalize(t, dim = -1)
 16 | 
 17 | class OpenClipAdapter(nn.Module):
 18 |     @beartype
 19 |     def __init__(
 20 |         self,
 21 |         name = 'ViT-B/32',
 22 |         pretrained = 'laion400m_e32',
 23 |         tokenizer_name = 'ViT-B-32-quickgelu',
 24 |         eos_id = 49407
 25 |     ):
 26 |         super().__init__()
 27 | 
 28 |         clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
 29 |         tokenizer = open_clip.get_tokenizer(tokenizer_name)
 30 | 
 31 |         self.clip = clip
 32 |         self.tokenizer = tokenizer
 33 |         self.eos_id = eos_id
 34 | 
 35 |         # hook for getting final text representation
 36 | 
 37 |         text_attention_final = self.find_layer('ln_final')
 38 |         self._dim_latent = text_attention_final.weight.shape[0]
 39 |         self.text_handle = text_attention_final.register_forward_hook(self._text_hook)
 40 | 
 41 |         # hook for getting final image representation
 42 |         # this is for vision-aided gan loss
 43 | 
 44 |         self._dim_image_latent = self.find_layer('visual.ln_post').weight.shape[0]
 45 | 
 46 |         num_visual_layers = len(clip.visual.transformer.resblocks)
 47 |         self.image_handles = []
 48 | 
 49 |         for visual_layer in range(num_visual_layers):
 50 |             image_attention_final = self.find_layer(f'visual.transformer.resblocks.{visual_layer}')
 51 | 
 52 |             handle = image_attention_final.register_forward_hook(self._image_hook)
 53 |             self.image_handles.append(handle)
 54 | 
 55 |         # normalize fn
 56 | 
 57 |         self.clip_normalize = preprocess.transforms[-1]
 58 |         self.cleared = False
 59 | 
 60 |     @property
 61 |     def device(self):
 62 |         return next(self.parameters()).device
 63 | 
 64 |     def find_layer(self,  layer):
 65 |         modules = dict([*self.clip.named_modules()])
 66 |         return modules.get(layer, None)
 67 | 
 68 |     def clear(self):
 69 |         if self.cleared:
 70 |             return
 71 | 
 72 |         self.text_handle()
 73 |         self.image_handle()
 74 | 
 75 |     def _text_hook(self, _, inputs, outputs):
 76 |         self.text_encodings = outputs
 77 | 
 78 |     def _image_hook(self, _, inputs, outputs):
 79 |         if not hasattr(self, 'image_encodings'):
 80 |             self.image_encodings = []
 81 | 
 82 |         self.image_encodings.append(outputs)
 83 | 
 84 |     @property
 85 |     def dim_latent(self):
 86 |         return self._dim_latent
 87 | 
 88 |     @property
 89 |     def image_size(self):
 90 |         image_size = self.clip.visual.image_size
 91 |         if isinstance(image_size, tuple):
 92 |             return max(image_size)
 93 |         return image_size
 94 | 
 95 |     @property
 96 |     def image_channels(self):
 97 |         return 3
 98 | 
 99 |     @property
100 |     def max_text_len(self):
101 |         return self.clip.positional_embedding.shape[0]
102 | 
103 |     @beartype
104 |     def embed_texts(
105 |         self,
106 |         texts: List[str]
107 |     ):
108 |         ids = self.tokenizer(texts)
109 |         ids = ids.to(self.device)
110 |         ids = ids[..., :self.max_text_len]
111 | 
112 |         is_eos_id = (ids == self.eos_id)
113 |         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
114 |         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
115 |         text_mask = text_mask & (ids != 0)
116 |         assert not self.cleared
117 | 
118 |         text_embed = self.clip.encode_text(ids)
119 |         text_encodings = self.text_encodings
120 |         text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
121 |         del self.text_encodings
122 |         return l2norm(text_embed.float()), text_encodings.float()
123 | 
124 |     def embed_images(self, images):
125 |         if images.shape[-1] != self.image_size:
126 |             images = F.interpolate(images, self.image_size)
127 | 
128 |         assert not self.cleared
129 |         images = self.clip_normalize(images)
130 |         image_embeds = self.clip.encode_image(images)
131 | 
132 |         image_encodings = rearrange(self.image_encodings, 'l n b d -> l b n d')
133 |         del self.image_encodings
134 | 
135 |         return l2norm(image_embeds.float()), image_encodings.float()
136 | 
137 |     @beartype
138 |     def contrastive_loss(
139 |         self,
140 |         images,
141 |         texts: Optional[List[str]] = None,
142 |         text_embeds: Optional[torch.Tensor] = None
143 |     ):
144 |         assert exists(texts) ^ exists(text_embeds)
145 | 
146 |         if not exists(text_embeds):
147 |             text_embeds, _ = self.embed_texts(texts)
148 | 
149 |         image_embeds, _ = self.embed_images(images)
150 | 
151 |         n = text_embeds.shape[0]
152 | 
153 |         temperature = self.clip.logit_scale.exp()
154 |         sim = einsum('i d, j d -> i j', text_embeds, image_embeds) * temperature
155 | 
156 |         labels = torch.arange(n, device = sim.device)
157 | 
158 |         return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
159 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/optimizer.py:
--------------------------------------------------------------------------------
 1 | from torch.optim import AdamW, Adam
 2 | 
 3 | def separate_weight_decayable_params(params):
 4 |     wd_params, no_wd_params = [], []
 5 |     for param in params:
 6 |         param_list = no_wd_params if param.ndim < 2 else wd_params
 7 |         param_list.append(param)
 8 |     return wd_params, no_wd_params
 9 | 
10 | def get_optimizer(
11 |     params,
12 |     lr = 1e-4,
13 |     wd = 1e-2,
14 |     betas = (0.9, 0.99),
15 |     eps = 1e-8,
16 |     filter_by_requires_grad = True,
17 |     group_wd_params = True,
18 |     **kwargs
19 | ):
20 |     if filter_by_requires_grad:
21 |         params = list(filter(lambda t: t.requires_grad, params))
22 | 
23 |     if group_wd_params and wd > 0:
24 |         wd_params, no_wd_params = separate_weight_decayable_params(params)
25 | 
26 |         params = [
27 |             {'params': wd_params},
28 |             {'params': no_wd_params, 'weight_decay': 0},
29 |         ]
30 | 
31 |     if wd == 0:
32 |         return Adam(params, lr = lr, betas = betas, eps = eps)
33 | 
34 |     return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
35 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/unet_upsampler.py:
--------------------------------------------------------------------------------
  1 | from __future__ import annotations
  2 | 
  3 | from math import log2
  4 | from functools import partial
  5 | from itertools import islice
  6 | 
  7 | import torch
  8 | from torch import nn
  9 | import torch.nn.functional as F
 10 | from torch.nn import Module, ModuleList
 11 | 
 12 | from einops import rearrange, repeat, pack, unpack
 13 | from einops.layers.torch import Rearrange
 14 | 
 15 | from gigagan_pytorch.attend import Attend
 16 | from gigagan_pytorch.gigagan_pytorch import (
 17 |     BaseGenerator,
 18 |     StyleNetwork,
 19 |     AdaptiveConv2DMod,
 20 |     AdaptiveConv1DMod,
 21 |     TextEncoder,
 22 |     CrossAttentionBlock,
 23 |     Upsample,
 24 |     PixelShuffleUpsample,
 25 |     Blur
 26 | )
 27 | 
 28 | from kornia.filters import filter3d, filter2d
 29 | 
 30 | from beartype import beartype
 31 | from beartype.typing import List, Dict, Iterable, Literal
 32 | 
 33 | # helpers functions
 34 | 
 35 | def exists(x):
 36 |     return x is not None
 37 | 
 38 | def default(val, d):
 39 |     if exists(val):
 40 |         return val
 41 |     return d() if callable(d) else d
 42 | 
 43 | def pack_one(t, pattern):
 44 |     return pack([t], pattern)
 45 | 
 46 | def unpack_one(t, ps, pattern):
 47 |     return unpack(t, ps, pattern)[0]
 48 | 
 49 | def cast_tuple(t, length = 1):
 50 |     if isinstance(t, tuple):
 51 |         return t
 52 |     return ((t,) * length)
 53 | 
 54 | def identity(t, *args, **kwargs):
 55 |     return t
 56 | 
 57 | def is_power_of_two(n):
 58 |     return log2(n).is_integer()
 59 | 
 60 | def null_iterator():
 61 |     while True:
 62 |         yield None
 63 | 
 64 | def fold_space_into_batch(x):
 65 |     x = rearrange(x, 'b c t h w -> b h w c t')
 66 |     x, ps = pack_one(x, '* c t')
 67 | 
 68 |     def split_space_from_batch(out):
 69 |         out = unpack_one(x, ps, '* c t')
 70 |         out = rearrange(out, 'b h w c t -> b c t h w')
 71 |         return out
 72 | 
 73 |     return x, split_space_from_batch
 74 | 
 75 | # small helper modules
 76 | 
 77 | def interpolate_1d(x, length, mode = 'bilinear'):
 78 |     x = rearrange(x, 'b c t -> b c t 1')
 79 |     x = F.interpolate(x, (length, 1), mode = mode)
 80 |     return rearrange(x, 'b c t 1 -> b c t')
 81 | 
 82 | class Downsample(Module):
 83 |     def __init__(
 84 |         self,
 85 |         dim,
 86 |         dim_out = None,
 87 |         skip_downsample = False,
 88 |         has_temporal_layers = False
 89 |     ):
 90 |         super().__init__()
 91 |         dim_out = default(dim_out, dim)
 92 | 
 93 |         self.skip_downsample = skip_downsample
 94 | 
 95 |         self.conv2d = nn.Conv2d(dim, dim_out, 3, padding = 1)
 96 | 
 97 |         self.has_temporal_layers = has_temporal_layers
 98 | 
 99 |         if has_temporal_layers:
100 |             self.conv1d = nn.Conv1d(dim_out, dim_out, 3, padding = 1)
101 | 
102 |             nn.init.dirac_(self.conv1d.weight)
103 |             nn.init.zeros_(self.conv1d.bias)
104 | 
105 |         self.register_buffer('filter', torch.Tensor([1., 2., 1.]))
106 | 
107 |     def forward(self, x):
108 |         batch = x.shape[0]
109 |         is_input_video = x.ndim == 5
110 | 
111 |         assert not (is_input_video and not self.has_temporal_layers)
112 | 
113 |         if is_input_video:
114 |             x = rearrange(x, 'b c t h w -> (b t) c h w')
115 | 
116 |         x = self.conv2d(x)
117 | 
118 |         if is_input_video:
119 |             x = rearrange(x, '(b t) c h w -> b h w c t', b = batch)
120 |             x, ps = pack_one(x, '* c t')
121 | 
122 |             x = self.conv1d(x)
123 | 
124 |             x = unpack_one(x, ps, '* c t')
125 |             x = rearrange(x, 'b h w c t -> b c t h w')
126 | 
127 |         # if not downsampling, early return
128 | 
129 |         if self.skip_downsample:
130 |             return x, x[:, 0:0]
131 | 
132 |         # save before blur to subtract out for high frequency fmap skip connection
133 | 
134 |         before_blur_input = x
135 | 
136 |         # blur 2d or 3d, depending
137 | 
138 |         f = self.filter
139 |         N = None
140 | 
141 |         if is_input_video:
142 |             f = f[N, N, :] * f[N, :, N] * f[:, N, N]
143 |             filter_fn = filter3d
144 |             maxpool_fn = F.max_pool3d
145 |         else:
146 |             f = f[N, :] * f[:, N]
147 |             filter_fn = filter2d
148 |             maxpool_fn = F.max_pool2d
149 | 
150 |         blurred = filter_fn(x, f[N, ...], normalized = True)
151 | 
152 |         # get high frequency fmap
153 | 
154 |         high_freq_fmap = before_blur_input - blurred
155 | 
156 |         # max pool 2d or 3d, depending
157 | 
158 |         x = maxpool_fn(x, kernel_size = 2)
159 | 
160 |         return x, high_freq_fmap
161 | 
162 | class TemporalBlur(Module):
163 |     def __init__(self):
164 |         super().__init__()
165 |         f = torch.Tensor([1, 2, 1])
166 |         self.register_buffer('f', f)
167 | 
168 |     def forward(self, x):
169 |         f = repeat(self.f, 't -> 1 t h w', h = 3, w = 3)
170 |         return filter3d(x, f, normalized = True)
171 | 
172 | class TemporalUpsample(Module):
173 |     def __init__(
174 |         self,
175 |         dim,
176 |         dim_out = None
177 |     ):
178 |         super().__init__()
179 |         self.blur = TemporalBlur()
180 | 
181 |     def forward(self, x):
182 |         assert x.ndim == 5
183 |         time = x.shape[2]
184 | 
185 |         x = rearrange(x, 'b c t h w -> b h w c t')
186 |         x, ps = pack_one(x, '* c t')
187 | 
188 |         x = interpolate_1d(x, time * 2, mode = 'bilinear')
189 | 
190 |         x = unpack_one(x, ps, '* c t')
191 |         x = rearrange(x, 'b h w c t -> b c t h w')
192 |         x = self.blur(x)
193 |         return x
194 | 
195 | class PixelShuffleTemporalUpsample(Module):
196 |     def __init__(self, dim, dim_out = None):
197 |         super().__init__()
198 |         dim_out = default(dim_out, dim)
199 | 
200 |         conv = nn.Conv3d(dim, dim_out * 2, 1)
201 | 
202 |         self.net = nn.Sequential(
203 |             conv,
204 |             nn.SiLU(),
205 |             Rearrange('b (c p) t h w -> b c (t p) h w', p = 2)
206 |         )
207 | 
208 |         self.init_conv_(conv)
209 | 
210 |     def init_conv_(self, conv):
211 |         o, i, t, h, w = conv.weight.shape
212 |         conv_weight = torch.empty(o // 2, i, t, h, w)
213 |         nn.init.kaiming_uniform_(conv_weight)
214 |         conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
215 | 
216 |         conv.weight.data.copy_(conv_weight)
217 |         nn.init.zeros_(conv.bias.data)
218 | 
219 |     def forward(self, x):
220 |         return self.net(x)
221 | 
222 | # norm
223 | 
224 | class RMSNorm(Module):
225 |     def __init__(self, dim):
226 |         super().__init__()
227 |         self.scale = dim ** 0.5
228 |         self.gamma = nn.Parameter(torch.ones(dim))
229 | 
230 |     def forward(self, x):
231 |         spatial_dims = ((1,) * (x.ndim - 2))
232 |         gamma = self.gamma.reshape(-1, *spatial_dims)
233 | 
234 |         return F.normalize(x, dim = 1) * gamma * self.scale
235 | 
236 | # building block modules
237 | 
238 | class Block(Module):
239 |     @beartype
240 |     def __init__(
241 |         self,
242 |         dim,
243 |         dim_out,
244 |         num_conv_kernels = 0,
245 |         conv_type: Literal['1d', '2d'] = '2d',
246 |     ):
247 |         super().__init__()
248 | 
249 |         adaptive_conv_klass = AdaptiveConv2DMod if conv_type == '2d' else AdaptiveConv1DMod
250 | 
251 |         self.proj = adaptive_conv_klass(dim, dim_out, kernel = 3, num_conv_kernels = num_conv_kernels)
252 |         self.norm = RMSNorm(dim_out)
253 |         self.act = nn.SiLU()
254 | 
255 |     def forward(
256 |         self,
257 |         x,
258 |         conv_mods_iter: Iterable | None = None
259 |     ):
260 |         conv_mods_iter = default(conv_mods_iter, null_iterator())
261 | 
262 |         x = self.proj(
263 |             x,
264 |             mod = next(conv_mods_iter),
265 |             kernel_mod = next(conv_mods_iter)
266 |         )
267 | 
268 |         x = self.norm(x)
269 |         x = self.act(x)
270 |         return x
271 | 
272 | class ResnetBlock(Module):
273 |     @beartype
274 |     def __init__(
275 |         self,
276 |         dim,
277 |         dim_out,
278 |         *,
279 |         num_conv_kernels = 0,
280 |         conv_type: Literal['1d', '2d'] = '2d',
281 |         style_dims: List[int] = []
282 |     ):
283 |         super().__init__()
284 | 
285 |         mod_dims = [
286 |             dim,
287 |             num_conv_kernels,
288 |             dim_out,
289 |             num_conv_kernels
290 |         ]
291 | 
292 |         style_dims.extend(mod_dims)
293 | 
294 |         self.num_mods = len(mod_dims)
295 | 
296 |         self.block1 = Block(dim, dim_out, num_conv_kernels = num_conv_kernels, conv_type = conv_type)
297 |         self.block2 = Block(dim_out, dim_out, num_conv_kernels = num_conv_kernels, conv_type = conv_type)
298 | 
299 |         conv_klass = nn.Conv2d if conv_type == '2d' else nn.Conv1d
300 |         self.res_conv = conv_klass(dim, dim_out, 1) if dim != dim_out else nn.Identity()
301 | 
302 |     def forward(
303 |         self,
304 |         x,
305 |         conv_mods_iter: Iterable | None = None
306 |     ):
307 |         h = self.block1(x, conv_mods_iter = conv_mods_iter)
308 |         h = self.block2(h, conv_mods_iter = conv_mods_iter)
309 | 
310 |         return h + self.res_conv(x)
311 | 
312 | class LinearAttention(Module):
313 |     def __init__(
314 |         self,
315 |         dim,
316 |         heads = 4,
317 |         dim_head = 32
318 |     ):
319 |         super().__init__()
320 |         self.scale = dim_head ** -0.5
321 |         self.heads = heads
322 |         hidden_dim = dim_head * heads
323 | 
324 |         self.norm = RMSNorm(dim)
325 |         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
326 | 
327 |         self.to_out = nn.Sequential(
328 |             nn.Conv2d(hidden_dim, dim, 1),
329 |             RMSNorm(dim)
330 |         )
331 | 
332 |     def forward(self, x):
333 |         b, c, h, w = x.shape
334 | 
335 |         x = self.norm(x)
336 | 
337 |         qkv = self.to_qkv(x).chunk(3, dim = 1)
338 |         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
339 | 
340 |         q = q.softmax(dim = -2)
341 |         k = k.softmax(dim = -1)
342 | 
343 |         q = q * self.scale
344 | 
345 |         context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
346 | 
347 |         out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
348 |         out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
349 |         return self.to_out(out)
350 | 
351 | class Attention(Module):
352 |     def __init__(
353 |         self,
354 |         dim,
355 |         heads = 4,
356 |         dim_head = 32,
357 |         flash = False
358 |     ):
359 |         super().__init__()
360 |         self.heads = heads
361 |         hidden_dim = dim_head * heads
362 | 
363 |         self.norm = RMSNorm(dim)
364 |         self.attend = Attend(flash = flash)
365 | 
366 |         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
367 |         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
368 | 
369 |     def forward(self, x):
370 |         b, c, h, w = x.shape
371 | 
372 |         x = self.norm(x)
373 | 
374 |         qkv = self.to_qkv(x).chunk(3, dim = 1)
375 |         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
376 | 
377 |         out = self.attend(q, k, v)
378 | 
379 |         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
380 |         return self.to_out(out)
381 | 
382 | # feedforward
383 | 
384 | def FeedForward(dim, mult = 4):
385 |     return nn.Sequential(
386 |         RMSNorm(dim),
387 |         nn.Conv2d(dim, dim * mult, 1),
388 |         nn.GELU(),
389 |         nn.Conv2d(dim * mult, dim, 1)
390 |     )
391 | 
392 | # transformers
393 | 
394 | class Transformer(Module):
395 |     def __init__(
396 |         self,
397 |         dim,
398 |         dim_head = 64,
399 |         heads = 8,
400 |         depth = 1,
401 |         flash_attn = True,
402 |         ff_mult = 4
403 |     ):
404 |         super().__init__()
405 |         self.layers = ModuleList([])
406 | 
407 |         for _ in range(depth):
408 |             self.layers.append(ModuleList([
409 |                 Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash_attn),
410 |                 FeedForward(dim = dim, mult = ff_mult)
411 |             ]))
412 | 
413 |     def forward(self, x):
414 |         for attn, ff in self.layers:
415 |             x = attn(x) + x
416 |             x = ff(x) + x
417 | 
418 |         return x
419 | 
420 | class LinearTransformer(Module):
421 |     def __init__(
422 |         self,
423 |         dim,
424 |         dim_head = 64,
425 |         heads = 8,
426 |         depth = 1,
427 |         ff_mult = 4
428 |     ):
429 |         super().__init__()
430 |         self.layers = ModuleList([])
431 | 
432 |         for _ in range(depth):
433 |             self.layers.append(ModuleList([
434 |                 LinearAttention(dim = dim, dim_head = dim_head, heads = heads),
435 |                 FeedForward(dim = dim, mult = ff_mult)
436 |             ]))
437 | 
438 |     def forward(self, x):
439 |         for attn, ff in self.layers:
440 |             x = attn(x) + x
441 |             x = ff(x) + x
442 | 
443 |         return x
444 | 
445 | # model
446 | 
447 | class UnetUpsampler(BaseGenerator):
448 | 
449 |     @beartype
450 |     def __init__(
451 |         self,
452 |         dim,
453 |         *,
454 |         image_size,
455 |         input_image_size,
456 |         init_dim = None,
457 |         out_dim = None,
458 |         text_encoder: TextEncoder | Dict | None = None,
459 |         style_network: StyleNetwork | Dict | None = None,
460 |         style_network_dim = None,
461 |         dim_mults = (1, 2, 4, 8, 16),
462 |         channels = 3,
463 |         full_attn = (False, False, False, True, True),
464 |         cross_attn = (False, False, False, True, True),
465 |         flash_attn = True,
466 |         self_attn_dim_head = 64,
467 |         self_attn_heads = 8,
468 |         self_attn_dot_product = True,
469 |         self_attn_ff_mult = 4,
470 |         attn_depths = (1, 1, 1, 1, 1),
471 |         temporal_attn_depths = (1, 1, 1, 1, 1),
472 |         cross_attn_dim_head = 64,
473 |         cross_attn_heads = 8,
474 |         cross_ff_mult = 4,
475 |         has_temporal_layers = False,
476 |         mid_attn_depth = 1,
477 |         num_conv_kernels = 2,
478 |         unconditional = True,
479 |         skip_connect_scale = None
480 |     ):
481 |         super().__init__()
482 | 
483 |         # able to upsample video
484 | 
485 |         self.can_upsample_video = has_temporal_layers
486 | 
487 |         # style network
488 | 
489 |         if isinstance(text_encoder, dict):
490 |             text_encoder = TextEncoder(**text_encoder)
491 | 
492 |         self.text_encoder = text_encoder
493 | 
494 |         if isinstance(style_network, dict):
495 |             style_network = StyleNetwork(**style_network)
496 | 
497 |         self.style_network = style_network
498 | 
499 |         assert exists(style_network) ^ exists(style_network_dim), 'either style_network or style_network_dim must be passed in'
500 | 
501 |         # validate text conditioning and style network hparams
502 | 
503 |         self.unconditional = unconditional
504 |         assert unconditional ^ exists(text_encoder), 'if unconditional, text encoder should not be given, and vice versa'
505 |         assert not (unconditional and exists(style_network) and style_network.dim_text_latent > 0)
506 |         assert unconditional or text_encoder.dim == style_network.dim_text_latent, 'the `dim_text_latent` on your StyleNetwork must be equal to the `dim` set for the TextEncoder'
507 | 
508 |         assert is_power_of_two(image_size) and is_power_of_two(input_image_size), 'both output image size and input image size must be power of 2'
509 |         assert input_image_size < image_size, 'input image size must be smaller than the output image size, thus upsampling'
510 | 
511 |         num_layer_no_downsample = int(log2(image_size) - log2(input_image_size))
512 |         assert num_layer_no_downsample <= len(dim_mults), 'you need more stages in this unet for the level of upsampling'
513 | 
514 |         self.image_size = image_size
515 |         self.input_image_size = input_image_size
516 | 
517 |         # setup adaptive conv
518 | 
519 |         style_embed_split_dims = []
520 | 
521 |         # determine dimensions
522 | 
523 |         self.channels = channels
524 |         input_channels = channels
525 | 
526 |         init_dim = default(init_dim, dim)
527 |         self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
528 | 
529 |         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
530 | 
531 |         *_, mid_dim = dims
532 | 
533 |         in_out = list(zip(dims[:-1], dims[1:]))
534 | 
535 |         block_klass = partial(
536 |             ResnetBlock,
537 |             num_conv_kernels = num_conv_kernels,
538 |             style_dims = style_embed_split_dims
539 |         )
540 | 
541 |         # attention
542 | 
543 |         full_attn = cast_tuple(full_attn, length = len(dim_mults))
544 |         assert len(full_attn) == len(dim_mults)
545 | 
546 |         FullAttention = partial(Transformer, flash_attn = flash_attn)
547 | 
548 |         cross_attn = cast_tuple(cross_attn, length = len(dim_mults))
549 |         assert unconditional or len(full_attn) == len(dim_mults)
550 | 
551 |         # skip connection scale
552 | 
553 |         self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)
554 | 
555 |         # layers
556 | 
557 |         self.downs = ModuleList([])
558 |         self.ups = ModuleList([])
559 |         num_resolutions = len(in_out)
560 |         skip_connect_dims = []
561 | 
562 |         for ind, ((dim_in, dim_out), layer_full_attn, layer_cross_attn, layer_attn_depth, layer_temporal_attn_depth) in enumerate(zip(in_out, full_attn, cross_attn, attn_depths, temporal_attn_depths)):
563 | 
564 |             should_not_downsample = ind < num_layer_no_downsample
565 |             has_cross_attn = not self.unconditional and layer_cross_attn
566 | 
567 |             attn_klass = FullAttention if layer_full_attn else LinearTransformer
568 | 
569 |             skip_connect_dims.append(dim_in)
570 |             skip_connect_dims.append(dim_in + (dim_out if not should_not_downsample else 0))
571 | 
572 |             temporal_resnet_block = None
573 |             temporal_attn = None
574 | 
575 |             if has_temporal_layers:
576 |                 temporal_resnet_block = block_klass(dim_in, dim_in, conv_type = '1d')
577 |                 temporal_attn = FullAttention(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_temporal_attn_depth)
578 | 
579 |             # all unet downsample stages
580 | 
581 |             self.downs.append(ModuleList([
582 |                 block_klass(dim_in, dim_in),
583 |                 block_klass(dim_in, dim_in),
584 |                 CrossAttentionBlock(dim_in, dim_context = text_encoder.dim, dim_head = self_attn_dim_head, heads = self_attn_heads, ff_mult = self_attn_ff_mult) if has_cross_attn else None,
585 |                 attn_klass(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_attn_depth),
586 |                 temporal_resnet_block,
587 |                 temporal_attn,
588 |                 Downsample(dim_in, dim_out, skip_downsample = should_not_downsample, has_temporal_layers = has_temporal_layers)
589 |             ]))
590 | 
591 |         self.mid_block1 = block_klass(mid_dim, mid_dim)
592 |         self.mid_attn = FullAttention(mid_dim, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = mid_attn_depth)
593 |         self.mid_block2 = block_klass(mid_dim, mid_dim)
594 |         self.mid_to_rgb = nn.Conv2d(mid_dim, channels, 1)
595 | 
596 |         for ind, ((dim_in, dim_out), layer_cross_attn, layer_full_attn, layer_attn_depth, layer_temporal_attn_depth) in enumerate(zip(reversed(in_out), reversed(full_attn), reversed(cross_attn), reversed(attn_depths), reversed(temporal_attn_depths))):
597 | 
598 |             attn_klass = FullAttention if layer_full_attn else LinearTransformer
599 |             has_cross_attn = not self.unconditional and layer_cross_attn
600 | 
601 |             temporal_upsample = None
602 |             temporal_upsample_rgb = None
603 |             temporal_resnet_block = None
604 |             temporal_attn = None
605 | 
606 |             if has_temporal_layers:
607 |                 temporal_upsample = PixelShuffleTemporalUpsample(dim_in, dim_in)
608 |                 temporal_upsample_rgb = TemporalUpsample(dim_in, dim_in)
609 | 
610 |                 temporal_resnet_block = block_klass(dim_in, dim_in, conv_type = '1d')
611 |                 temporal_attn = FullAttention(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_temporal_attn_depth)
612 | 
613 |             self.ups.append(ModuleList([
614 |                 PixelShuffleUpsample(dim_out, dim_in),
615 |                 Upsample(),
616 |                 temporal_upsample,
617 |                 temporal_upsample_rgb,
618 |                 nn.Conv2d(dim_in, channels, 1),
619 |                 block_klass(dim_in + skip_connect_dims.pop(), dim_in),
620 |                 block_klass(dim_in + skip_connect_dims.pop(), dim_in),
621 |                 CrossAttentionBlock(dim_in, dim_context = text_encoder.dim, dim_head = self_attn_dim_head, heads = self_attn_heads, ff_mult = cross_ff_mult) if has_cross_attn else None,
622 |                 attn_klass(dim_in, dim_head = cross_attn_dim_head, heads = self_attn_heads, depth = layer_attn_depth),
623 |                 temporal_resnet_block,
624 |                 temporal_attn
625 |             ]))
626 | 
627 |         self.out_dim = default(out_dim, channels)
628 | 
629 |         self.final_res_block = block_klass(dim, dim)
630 | 
631 |         self.final_to_rgb = nn.Conv2d(dim, channels, 1)
632 | 
633 |         # determine the projection of the style embedding to convolutional modulation weights (+ adaptive kernel selection weights) for all layers
634 | 
635 |         self.style_to_conv_modulations = nn.Linear(style_network.dim, sum(style_embed_split_dims))
636 |         self.style_embed_split_dims = style_embed_split_dims
637 | 
638 |     @property
639 |     def allowable_rgb_resolutions(self):
640 |         input_res_base = int(log2(self.input_image_size))
641 |         output_res_base = int(log2(self.image_size))
642 |         allowed_rgb_res_base = list(range(input_res_base, output_res_base))
643 |         return [*map(lambda p: 2 ** p, allowed_rgb_res_base)]
644 | 
645 |     @property
646 |     def device(self):
647 |         return next(self.parameters()).device
648 | 
649 |     @property
650 |     def total_params(self):
651 |         return sum([p.numel() for p in self.parameters()])
652 | 
653 |     def resize_to_same_dimensions(self, x, size):
654 |         mode = 'trilinear' if x.ndim == 5 else 'bilinear'
655 |         return F.interpolate(x, tuple(size), mode = mode)
656 | 
657 |     def forward(
658 |         self,
659 |         lowres_image_or_video,
660 |         styles = None,
661 |         noise = None,
662 |         texts: List[str] | None = None,
663 |         global_text_tokens = None,
664 |         fine_text_tokens = None,
665 |         text_mask = None,
666 |         return_all_rgbs = False,
667 |         replace_rgb_with_input_lowres_image = True   # discriminator should also receive the low resolution image the upsampler sees
668 |     ):
669 |         x = lowres_image_or_video
670 |         shape = x.shape
671 |         batch_size = shape[0]
672 | 
673 |         assert shape[-2:] == ((self.input_image_size,) * 2)
674 | 
675 |         # take care of text encodings
676 |         # which requires global text tokens to adaptively select the kernels from the main contribution in the paper
677 |         # and fine text tokens to attend to using cross attention
678 | 
679 |         if not self.unconditional:
680 |             if exists(texts):
681 |                 assert exists(self.text_encoder)
682 |                 global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(texts)
683 |             else:
684 |                 assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))])
685 |         else:
686 |             assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])
687 | 
688 |         # styles
689 | 
690 |         if not exists(styles):
691 |             assert exists(self.style_network)
692 | 
693 |             noise = default(noise, torch.randn((batch_size, self.style_network.dim), device = self.device))
694 |             styles = self.style_network(noise, global_text_tokens)
695 | 
696 |         # project styles to conv modulations
697 | 
698 |         conv_mods = self.style_to_conv_modulations(styles)
699 |         conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
700 |         conv_mods = iter(conv_mods)
701 | 
702 |         # first detect whether input is image or video and handle accordingly
703 | 
704 |         input_is_video = lowres_image_or_video.ndim == 5
705 |         assert not (not self.can_upsample_video and input_is_video), 'this network cannot upsample video unless you set `has_temporal_layers = True`'
706 | 
707 |         fold_time_into_batch = identity
708 |         split_time_from_batch = identity
709 | 
710 |         if input_is_video:
711 |             fold_time_into_batch = lambda t: rearrange(t, 'b c t h w -> (b t) c h w')
712 |             split_time_from_batch = lambda t: rearrange(t, '(b t) c h w -> b c t h w', b = batch_size)
713 | 
714 |         x = fold_time_into_batch(x)
715 | 
716 |         # set lowres_images for final rgb output
717 | 
718 |         lowres_images = x
719 | 
720 |         # initial conv
721 | 
722 |         x = self.init_conv(x)
723 | 
724 |         h = []
725 | 
726 |         # downsample stages
727 | 
728 |         for (
729 |             block1,
730 |             block2,
731 |             cross_attn,
732 |             attn,
733 |             temporal_block,
734 |             temporal_attn,
735 |             downsample,
736 |         ) in self.downs:
737 | 
738 |             x = block1(x, conv_mods_iter = conv_mods)
739 |             h.append(x)
740 | 
741 |             x = block2(x, conv_mods_iter = conv_mods)
742 | 
743 |             x = attn(x)
744 | 
745 |             if exists(cross_attn):
746 |                 x = cross_attn(x, context = fine_text_tokens, mask = text_mask)
747 | 
748 |             if input_is_video:
749 |                 x = split_time_from_batch(x)
750 |                 x, split_space_back = fold_space_into_batch(x)
751 | 
752 |                 x = temporal_block(x, conv_mods_iter = conv_mods)
753 | 
754 |                 x = rearrange(x, 'b c t -> b c t 1')
755 |                 x = temporal_attn(x)
756 |                 x = rearrange(x, 'b c t 1 -> b c t')
757 | 
758 |                 x = split_space_back(x)
759 |                 x = fold_time_into_batch(x)
760 | 
761 |             elif self.can_upsample_video:
762 |                 conv_mods = islice(conv_mods, temporal_block.num_mods, None)
763 | 
764 |             skip_connect = x
765 | 
766 |             # downsample with hf shuttle
767 | 
768 |             x = split_time_from_batch(x)
769 | 
770 |             x, hf_fmap = downsample(x)
771 | 
772 |             x = fold_time_into_batch(x)
773 |             hf_fmap = fold_time_into_batch(hf_fmap)
774 | 
775 |             # add high freq fmap to skip connection as proposed in videogigagan
776 | 
777 |             skip_connect = torch.cat((skip_connect, hf_fmap), dim = 1)
778 | 
779 |             h.append(skip_connect)
780 | 
781 |         x = self.mid_block1(x, conv_mods_iter = conv_mods)
782 |         x = self.mid_attn(x)
783 |         x = self.mid_block2(x, conv_mods_iter = conv_mods)
784 | 
785 |         # rgbs
786 | 
787 |         rgbs = []
788 | 
789 |         init_rgb_shape = list(x.shape)
790 |         init_rgb_shape[1] = self.channels
791 | 
792 |         rgb = self.mid_to_rgb(x)
793 |         rgbs.append(rgb)
794 | 
795 |         # upsample stages
796 | 
797 |         for (
798 |             upsample,
799 |             upsample_rgb,
800 |             temporal_upsample,
801 |             temporal_upsample_rgb,
802 |             to_rgb,
803 |             block1,
804 |             block2,
805 |             cross_attn,
806 |             attn,
807 |             temporal_block,
808 |             temporal_attn,
809 |         ) in self.ups:
810 | 
811 |             x = upsample(x)
812 |             rgb = upsample_rgb(rgb)
813 | 
814 |             if input_is_video:
815 |                 x = split_time_from_batch(x)
816 |                 rgb = split_time_from_batch(rgb)
817 | 
818 |                 x = temporal_upsample(x)
819 |                 rgb = temporal_upsample_rgb(rgb)
820 | 
821 |                 x = fold_time_into_batch(x)
822 |                 rgb = fold_time_into_batch(rgb)
823 | 
824 |             res1 = h.pop() * self.skip_connect_scale
825 |             res2 = h.pop() * self.skip_connect_scale
826 | 
827 |             # handle skip connections not being the same shape
828 | 
829 |             if x.shape[0] != res1.shape[0] or x.shape[2:] != res1.shape[2:]:
830 |                 x = split_time_from_batch(x)
831 |                 res1 = split_time_from_batch(res1)
832 |                 res2 = split_time_from_batch(res2)
833 | 
834 |                 res1 = self.resize_to_same_dimensions(res1, x.shape[2:])
835 |                 res2 = self.resize_to_same_dimensions(res2, x.shape[2:])
836 | 
837 |                 x = fold_time_into_batch(x)
838 |                 res1 = fold_time_into_batch(res1)
839 |                 res2 = fold_time_into_batch(res2)
840 | 
841 |             # concat skip connections
842 | 
843 |             x = torch.cat((x, res1), dim = 1)
844 |             x = block1(x, conv_mods_iter = conv_mods)
845 | 
846 |             x = torch.cat((x, res2), dim = 1)
847 |             x = block2(x, conv_mods_iter = conv_mods)
848 | 
849 |             if exists(cross_attn):
850 |                 x = cross_attn(x, context = fine_text_tokens, mask = text_mask)
851 | 
852 |             x = attn(x)
853 | 
854 |             if input_is_video:
855 |                 x = split_time_from_batch(x)
856 |                 x, split_space_back = fold_space_into_batch(x)
857 | 
858 |                 x = temporal_block(x, conv_mods_iter = conv_mods)
859 | 
860 |                 x = rearrange(x, 'b c t -> b c t 1')
861 |                 x = temporal_attn(x)
862 |                 x = rearrange(x, 'b c t 1 -> b c t')
863 | 
864 |                 x = split_space_back(x)
865 |                 x = fold_time_into_batch(x)
866 | 
867 |             elif self.can_upsample_video:
868 |                 conv_mods = islice(conv_mods, temporal_block.num_mods, None)
869 | 
870 |             rgb = rgb + to_rgb(x)
871 |             rgbs.append(rgb)
872 | 
873 |         x = self.final_res_block(x, conv_mods_iter = conv_mods)
874 | 
875 |         assert len([*conv_mods]) == 0
876 | 
877 |         rgb = rgb + self.final_to_rgb(x)
878 | 
879 |         # handle video input
880 | 
881 |         if input_is_video:
882 |             rgb = split_time_from_batch(rgb)
883 | 
884 |         if not return_all_rgbs:
885 |             return rgb
886 | 
887 |         # only keep those rgbs whose feature map is greater than the input image to be upsampled
888 | 
889 |         rgbs = list(filter(lambda t: t.shape[-1] > shape[-1], rgbs))
890 | 
891 |         # and return the original input image as the smallest rgb
892 | 
893 |         rgbs = [lowres_images, *rgbs]
894 | 
895 |         if input_is_video:
896 |             rgbs = [*map(split_time_from_batch, rgbs)]
897 | 
898 |         return rgb, rgbs
899 | 


--------------------------------------------------------------------------------
/gigagan_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.3.0'
2 | 


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | line-length = 1000
3 | ignore-init-module-imports = true
4 | exclude = ["setup.py"]


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | 
 3 | exec(open('gigagan_pytorch/version.py').read())
 4 | 
 5 | setup(
 6 |   name = 'gigagan-pytorch',
 7 |   packages = find_packages(exclude=[]),
 8 |   version = __version__,
 9 |   license='MIT',
10 |   description = 'GigaGAN - Pytorch',
11 |   author = 'Phil Wang',
12 |   author_email = 'lucidrains@gmail.com',
13 |   long_description_content_type = 'text/markdown',
14 |   url = 'https://github.com/lucidrains/ETSformer-pytorch',
15 |   keywords = [
16 |     'artificial intelligence',
17 |     'deep learning',
18 |     'generative adversarial networks'
19 |   ],
20 |   install_requires=[
21 |     'accelerate',
22 |     'beartype',
23 |     'einops>=0.6',
24 |     'ema-pytorch',
25 |     'kornia',
26 |     'numerize',
27 |     'open-clip-torch>=2.0.0,<3.0.0',
28 |     'pillow',
29 |     'torch>=1.6',
30 |     'torchvision',
31 |     'tqdm'
32 |   ],
33 |   classifiers=[
34 |     'Development Status :: 4 - Beta',
35 |     'Intended Audience :: Developers',
36 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
37 |     'License :: OSI Approved :: MIT License',
38 |     'Programming Language :: Python :: 3.6',
39 |   ],
40 | )
41 | 


--------------------------------------------------------------------------------