├── gigagan_pytorch
├── version.py
├── __init__.py
├── optimizer.py
├── distributed.py
├── attend.py
├── data.py
├── open_clip.py
├── unet_upsampler.py
└── gigagan_pytorch.py
├── gigagan-sample.png
├── gigagan-architecture.png
├── pyproject.toml
├── .pre-commit-config.yaml
├── LICENSE
├── setup.py
├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
└── README.md
/gigagan_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.2.20'
2 |
--------------------------------------------------------------------------------
/gigagan-sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wetdog/gigagan-pytorch/main/gigagan-sample.png
--------------------------------------------------------------------------------
/gigagan-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wetdog/gigagan-pytorch/main/gigagan-architecture.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | line-length = 1000
3 | ignore-init-module-imports = true
4 | exclude = ["setup.py"]
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | repos:
3 | - repo: https://github.com/astral-sh/ruff-pre-commit
4 | rev: v0.0.278
5 | hooks:
6 | - id: ruff
7 | args: [ --fix, --exit-non-zero-on-fix]
8 |
--------------------------------------------------------------------------------
/gigagan_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from gigagan_pytorch.gigagan_pytorch import (
2 | GigaGAN,
3 | Generator,
4 | Discriminator,
5 | VisionAidedDiscriminator,
6 | AdaptiveConv2DMod,
7 | StyleNetwork,
8 | TextEncoder
9 | )
10 |
11 | from gigagan_pytorch.unet_upsampler import UnetUpsampler
12 |
13 | from gigagan_pytorch.data import (
14 | ImageDataset,
15 | TextImageDataset,
16 | MockTextImageDataset
17 | )
18 |
19 | __all__ = [
20 | GigaGAN,
21 | Generator,
22 | Discriminator,
23 | VisionAidedDiscriminator,
24 | AdaptiveConv2DMod,
25 | StyleNetwork,
26 | UnetUpsampler,
27 | TextEncoder,
28 | ImageDataset,
29 | TextImageDataset,
30 | MockTextImageDataset
31 | ]
32 |
--------------------------------------------------------------------------------
/gigagan_pytorch/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW, Adam
2 |
3 | def separate_weight_decayable_params(params):
4 | wd_params, no_wd_params = [], []
5 | for param in params:
6 | param_list = no_wd_params if param.ndim < 2 else wd_params
7 | param_list.append(param)
8 | return wd_params, no_wd_params
9 |
10 | def get_optimizer(
11 | params,
12 | lr = 1e-4,
13 | wd = 1e-2,
14 | betas = (0.9, 0.99),
15 | eps = 1e-8,
16 | filter_by_requires_grad = True,
17 | group_wd_params = True,
18 | **kwargs
19 | ):
20 | if filter_by_requires_grad:
21 | params = list(filter(lambda t: t.requires_grad, params))
22 |
23 | if group_wd_params and wd > 0:
24 | wd_params, no_wd_params = separate_weight_decayable_params(params)
25 |
26 | params = [
27 | {'params': wd_params},
28 | {'params': no_wd_params, 'weight_decay': 0},
29 | ]
30 |
31 | if wd == 0:
32 | return Adam(params, lr = lr, betas = betas, eps = eps)
33 |
34 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | exec(open('gigagan_pytorch/version.py').read())
4 |
5 | setup(
6 | name = 'gigagan-pytorch',
7 | packages = find_packages(exclude=[]),
8 | version = __version__,
9 | license='MIT',
10 | description = 'GigaGAN - Pytorch',
11 | author = 'Phil Wang',
12 | author_email = 'lucidrains@gmail.com',
13 | long_description_content_type = 'text/markdown',
14 | url = 'https://github.com/lucidrains/ETSformer-pytorch',
15 | keywords = [
16 | 'artificial intelligence',
17 | 'deep learning',
18 | 'generative adversarial networks'
19 | ],
20 | install_requires=[
21 | 'accelerate',
22 | 'beartype',
23 | 'einops>=0.6',
24 | 'ema-pytorch',
25 | 'kornia',
26 | 'numerize',
27 | 'open-clip-torch>=2.0.0,<3.0.0',
28 | 'pillow',
29 | 'torch>=1.6',
30 | 'torchvision',
31 | 'tqdm'
32 | ],
33 | classifiers=[
34 | 'Development Status :: 4 - Beta',
35 | 'Intended Audience :: Developers',
36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
37 | 'License :: OSI Approved :: MIT License',
38 | 'Programming Language :: Python :: 3.6',
39 | ],
40 | )
41 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/gigagan_pytorch/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Function
4 | import torch.distributed as dist
5 |
6 | from einops import rearrange
7 |
8 | # helpers
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | def pad_dim_to(t, length, dim = 0):
14 | pad_length = length - t.shape[dim]
15 | zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
16 | return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
17 |
18 | # distributed helpers
19 |
20 | def all_gather_variable_dim(t, dim = 0, sizes = None):
21 | device, world_size = t.device, dist.get_world_size()
22 |
23 | if not exists(sizes):
24 | size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
25 | sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
26 | dist.all_gather(sizes, size)
27 | sizes = torch.stack(sizes)
28 |
29 | max_size = sizes.amax().item()
30 | padded_t = pad_dim_to(t, max_size, dim = dim)
31 |
32 | gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
33 | dist.all_gather(gathered_tensors, padded_t)
34 |
35 | gathered_tensor = torch.cat(gathered_tensors, dim = dim)
36 | seq = torch.arange(max_size, device = device)
37 |
38 | mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
39 | mask = rearrange(mask, 'i j -> (i j)')
40 | seq = torch.arange(mask.shape[-1], device = device)
41 | indices = seq[mask]
42 |
43 | gathered_tensor = gathered_tensor.index_select(dim, indices)
44 |
45 | return gathered_tensor, sizes
46 |
47 | class AllGather(Function):
48 | @staticmethod
49 | def forward(ctx, x, dim, sizes):
50 | is_dist = dist.is_initialized() and dist.get_world_size() > 1
51 | ctx.is_dist = is_dist
52 |
53 | if not is_dist:
54 | return x, None
55 |
56 | x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
57 | ctx.batch_sizes = batch_sizes.tolist()
58 | ctx.dim = dim
59 | return x, batch_sizes
60 |
61 | @staticmethod
62 | def backward(ctx, grads, _):
63 | if not ctx.is_dist:
64 | return grads, None, None
65 |
66 | batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
67 | grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
68 | return grads_by_rank[rank], None, None
69 |
70 | all_gather = AllGather.apply
71 |
--------------------------------------------------------------------------------
/gigagan_pytorch/attend.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from packaging import version
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 |
10 | # constants
11 |
12 | AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
13 |
14 | # helpers
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 | def once(fn):
20 | called = False
21 | @wraps(fn)
22 | def inner(x):
23 | nonlocal called
24 | if called:
25 | return
26 | called = True
27 | return fn(x)
28 | return inner
29 |
30 | print_once = once(print)
31 |
32 | # main class
33 |
34 | class Attend(nn.Module):
35 | def __init__(
36 | self,
37 | dropout = 0.,
38 | flash = False
39 | ):
40 | super().__init__()
41 | self.dropout = dropout
42 | self.attn_dropout = nn.Dropout(dropout)
43 |
44 | self.flash = flash
45 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
46 |
47 | # determine efficient attention configs for cuda and cpu
48 |
49 | self.cpu_config = AttentionConfig(True, True, True)
50 | self.cuda_config = None
51 |
52 | if not torch.cuda.is_available() or not flash:
53 | return
54 |
55 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
56 |
57 | if device_properties.major == 8 and device_properties.minor == 0:
58 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
59 | self.cuda_config = AttentionConfig(True, False, False)
60 | else:
61 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
62 | self.cuda_config = AttentionConfig(False, True, True)
63 |
64 | def flash_attn(self, q, k, v):
65 | is_cuda = q.is_cuda
66 |
67 | q, k, v = map(lambda t: t.contiguous(), (q, k, v))
68 |
69 | # Check if there is a compatible device for flash attention
70 |
71 | config = self.cuda_config if is_cuda else self.cpu_config
72 |
73 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
74 |
75 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
76 | out = F.scaled_dot_product_attention(
77 | q, k, v,
78 | dropout_p = self.dropout if self.training else 0.
79 | )
80 |
81 | return out
82 |
83 | def forward(self, q, k, v):
84 | """
85 | einstein notation
86 | b - batch
87 | h - heads
88 | n, i, j - sequence length (base sequence length, source, target)
89 | d - feature dimension
90 | """
91 |
92 | if self.flash:
93 | return self.flash_attn(q, k, v)
94 |
95 | scale = q.shape[-1] ** -0.5
96 |
97 | # similarity
98 |
99 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
100 |
101 | # attention
102 |
103 | attn = sim.softmax(dim = -1)
104 | attn = self.attn_dropout(attn)
105 |
106 | # aggregate values
107 |
108 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
109 |
110 | return out
111 |
--------------------------------------------------------------------------------
/gigagan_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import Dataset, DataLoader
7 |
8 | from PIL import Image
9 | from torchvision import transforms as T
10 |
11 | from beartype.door import is_bearable
12 | from beartype.typing import Tuple
13 |
14 | # helper functions
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 | def convert_image_to_fn(img_type, image):
20 | if image.mode == img_type:
21 | return image
22 |
23 | return image.convert(img_type)
24 |
25 | # custom collation function
26 | # so dataset can return a str and it will collate into List[str]
27 |
28 | def collate_tensors_or_str(data):
29 | is_one_data = not isinstance(data[0], tuple)
30 |
31 | if is_one_data:
32 | data = torch.stack(data)
33 | return (data,)
34 |
35 | outputs = []
36 | for datum in zip(*data):
37 | if is_bearable(datum, Tuple[str, ...]):
38 | output = list(datum)
39 | else:
40 | output = torch.stack(datum)
41 |
42 | outputs.append(output)
43 |
44 | return tuple(outputs)
45 |
46 | # dataset classes
47 |
48 | class ImageDataset(Dataset):
49 | def __init__(
50 | self,
51 | folder,
52 | image_size,
53 | exts = ['jpg', 'jpeg', 'png', 'tiff'],
54 | augment_horizontal_flip = False,
55 | convert_image_to = None
56 | ):
57 | super().__init__()
58 | self.folder = folder
59 | self.image_size = image_size
60 |
61 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
62 |
63 | assert len(self.paths) > 0, 'your folder contains no images'
64 | assert len(self.paths) > 100, 'you need at least 100 images, 10k for research paper, millions for miraculous results (try Laion-5B)'
65 |
66 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
67 |
68 | self.transform = T.Compose([
69 | T.Lambda(maybe_convert_fn),
70 | T.Resize(image_size),
71 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
72 | T.CenterCrop(image_size),
73 | T.ToTensor()
74 | ])
75 |
76 | def get_dataloader(self, *args, **kwargs):
77 | return DataLoader(self, *args, shuffle = True, drop_last = True, **kwargs)
78 |
79 | def __len__(self):
80 | return len(self.paths)
81 |
82 | def __getitem__(self, index):
83 | path = self.paths[index]
84 | img = Image.open(path)
85 | return self.transform(img)
86 |
87 | class TextImageDataset(Dataset):
88 | def __init__(self):
89 | raise NotImplementedError
90 |
91 | def get_dataloader(self, *args, **kwargs):
92 | return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)
93 |
94 | class MockTextImageDataset(TextImageDataset):
95 | def __init__(
96 | self,
97 | image_size,
98 | length = int(1e5),
99 | channels = 3
100 | ):
101 | self.image_size = image_size
102 | self.channels = channels
103 | self.length = length
104 |
105 | def get_dataloader(self, *args, **kwargs):
106 | return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)
107 |
108 | def __len__(self):
109 | return self.length
110 |
111 | def __getitem__(self, index):
112 | mock_image = torch.randn(self.channels, self.image_size, self.image_size)
113 | return mock_image, 'mock text'
114 |
--------------------------------------------------------------------------------
/gigagan_pytorch/open_clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 | import open_clip
5 |
6 | from einops import rearrange
7 |
8 | from beartype import beartype
9 | from beartype.typing import List, Optional
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def l2norm(t):
15 | return F.normalize(t, dim = -1)
16 |
17 | class OpenClipAdapter(nn.Module):
18 | @beartype
19 | def __init__(
20 | self,
21 | name = 'ViT-B/32',
22 | pretrained = 'laion400m_e32',
23 | tokenizer_name = 'ViT-B-32-quickgelu',
24 | eos_id = 49407
25 | ):
26 | super().__init__()
27 |
28 | clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
29 | tokenizer = open_clip.get_tokenizer(tokenizer_name)
30 |
31 | self.clip = clip
32 | self.tokenizer = tokenizer
33 | self.eos_id = eos_id
34 |
35 | # hook for getting final text representation
36 |
37 | text_attention_final = self.find_layer('ln_final')
38 | self._dim_latent = text_attention_final.weight.shape[0]
39 | self.text_handle = text_attention_final.register_forward_hook(self._text_hook)
40 |
41 | # hook for getting final image representation
42 | # this is for vision-aided gan loss
43 |
44 | self._dim_image_latent = self.find_layer('visual.ln_post').weight.shape[0]
45 |
46 | num_visual_layers = len(clip.visual.transformer.resblocks)
47 | self.image_handles = []
48 |
49 | for visual_layer in range(num_visual_layers):
50 | image_attention_final = self.find_layer(f'visual.transformer.resblocks.{visual_layer}')
51 |
52 | handle = image_attention_final.register_forward_hook(self._image_hook)
53 | self.image_handles.append(handle)
54 |
55 | # normalize fn
56 |
57 | self.clip_normalize = preprocess.transforms[-1]
58 | self.cleared = False
59 |
60 | @property
61 | def device(self):
62 | return next(self.parameters()).device
63 |
64 | def find_layer(self, layer):
65 | modules = dict([*self.clip.named_modules()])
66 | return modules.get(layer, None)
67 |
68 | def clear(self):
69 | if self.cleared:
70 | return
71 |
72 | self.text_handle()
73 | self.image_handle()
74 |
75 | def _text_hook(self, _, inputs, outputs):
76 | self.text_encodings = outputs
77 |
78 | def _image_hook(self, _, inputs, outputs):
79 | if not hasattr(self, 'image_encodings'):
80 | self.image_encodings = []
81 |
82 | self.image_encodings.append(outputs)
83 |
84 | @property
85 | def dim_latent(self):
86 | return self._dim_latent
87 |
88 | @property
89 | def image_size(self):
90 | image_size = self.clip.visual.image_size
91 | if isinstance(image_size, tuple):
92 | return max(image_size)
93 | return image_size
94 |
95 | @property
96 | def image_channels(self):
97 | return 3
98 |
99 | @property
100 | def max_text_len(self):
101 | return self.clip.positional_embedding.shape[0]
102 |
103 | @beartype
104 | def embed_texts(
105 | self,
106 | texts: List[str]
107 | ):
108 | ids = self.tokenizer(texts)
109 | ids = ids.to(self.device)
110 | ids = ids[..., :self.max_text_len]
111 |
112 | is_eos_id = (ids == self.eos_id)
113 | text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
114 | text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
115 | text_mask = text_mask & (ids != 0)
116 | assert not self.cleared
117 |
118 | text_embed = self.clip.encode_text(ids)
119 | text_encodings = self.text_encodings
120 | text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
121 | del self.text_encodings
122 | return l2norm(text_embed.float()), text_encodings.float()
123 |
124 | def embed_images(self, images):
125 | if images.shape[-1] != self.image_size:
126 | images = F.interpolate(images, self.image_size)
127 |
128 | assert not self.cleared
129 | images = self.clip_normalize(images)
130 | image_embeds = self.clip.encode_image(images)
131 |
132 | image_encodings = rearrange(self.image_encodings, 'l n b d -> l b n d')
133 | del self.image_encodings
134 |
135 | return l2norm(image_embeds.float()), image_encodings.float()
136 |
137 | @beartype
138 | def contrastive_loss(
139 | self,
140 | images,
141 | texts: Optional[List[str]] = None,
142 | text_embeds: Optional[torch.Tensor] = None
143 | ):
144 | assert exists(texts) ^ exists(text_embeds)
145 |
146 | if not exists(text_embeds):
147 | text_embeds, _ = self.embed_texts(texts)
148 |
149 | image_embeds, _ = self.embed_images(images)
150 |
151 | n = text_embeds.shape[0]
152 |
153 | temperature = self.clip.logit_scale.exp()
154 | sim = einsum('i d, j d -> i j', text_embeds, image_embeds) * temperature
155 |
156 | labels = torch.arange(n, device = sim.device)
157 |
158 | return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## GigaGAN - Pytorch
6 |
7 | Implementation of GigaGAN (project page), new SOTA GAN out of Adobe.
8 |
9 | I will also add a few findings from lightweight gan, for faster convergence (skip layer excitation) and better stability (reconstruction auxiliary loss in discriminator)
10 |
11 | It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper.
12 |
13 | Please join
if you are interested in helping out with the replication with the LAION community
14 |
15 | ## Appreciation
16 |
17 | - StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
18 |
19 | - 🤗 Huggingface for their accelerate library
20 |
21 | - All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models
22 |
23 | - Xavier for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built!
24 |
25 | - @CerebralSeed for pull requesting the initial sampling code for both the generator and upsampler!
26 |
27 | - Keerth for the code review and pointing out some discrepancies with the paper!
28 |
29 | ## Install
30 |
31 | ```bash
32 | $ pip install gigagan-pytorch
33 | ```
34 |
35 | ## Usage
36 |
37 | Simple unconditional GAN, for starters
38 |
39 | ```python
40 | import torch
41 |
42 | from gigagan_pytorch import (
43 | GigaGAN,
44 | ImageDataset
45 | )
46 |
47 | gan = GigaGAN(
48 | generator = dict(
49 | dim_capacity = 8,
50 | style_network = dict(
51 | dim = 64,
52 | depth = 4
53 | ),
54 | image_size = 256,
55 | dim_max = 512,
56 | num_skip_layers_excite = 4,
57 | unconditional = True
58 | ),
59 | discriminator = dict(
60 | dim_capacity = 16,
61 | dim_max = 512,
62 | image_size = 256,
63 | num_skip_layers_excite = 4,
64 | unconditional = True
65 | ),
66 | amp = True
67 | ).cuda()
68 |
69 | # dataset
70 |
71 | dataset = ImageDataset(
72 | folder = '/path/to/your/data',
73 | image_size = 256
74 | )
75 |
76 | dataloader = dataset.get_dataloader(batch_size = 1)
77 |
78 | # you must then set the dataloader for the GAN before training
79 |
80 | gan.set_dataloader(dataloader)
81 |
82 | # training the discriminator and generator alternating
83 | # for 100 steps in this example, batch size 1, gradient accumulated 8 times
84 |
85 | gan(
86 | steps = 100,
87 | grad_accum_every = 8
88 | )
89 |
90 | # after much training
91 |
92 | images = gan.generate(batch_size = 4) # (4, 3, 256, 256)
93 | ```
94 |
95 | For unconditional Unet Upsampler
96 |
97 | ```python
98 | import torch
99 | from gigagan_pytorch import (
100 | GigaGAN,
101 | ImageDataset
102 | )
103 |
104 | gan = GigaGAN(
105 | train_upsampler = True, # set this to True
106 | generator = dict(
107 | style_network = dict(
108 | dim = 64,
109 | depth = 4
110 | ),
111 | dim = 32,
112 | image_size = 256,
113 | input_image_size = 64,
114 | unconditional = True
115 | ),
116 | discriminator = dict(
117 | dim_capacity = 16,
118 | dim_max = 512,
119 | image_size = 256,
120 | num_skip_layers_excite = 4,
121 | multiscale_input_resolutions = (128,),
122 | unconditional = True
123 | ),
124 | amp = True
125 | ).cuda()
126 |
127 | dataset = ImageDataset(
128 | folder = '/path/to/your/data',
129 | image_size = 256
130 | )
131 |
132 | dataloader = dataset.get_dataloader(batch_size = 1)
133 |
134 | gan.set_dataloader(dataloader)
135 |
136 | # training the discriminator and generator alternating
137 | # for 100 steps in this example, batch size 1, gradient accumulated 8 times
138 |
139 | gan(
140 | steps = 100,
141 | grad_accum_every = 8
142 | )
143 |
144 | # after much training
145 |
146 | lowres = torch.randn(1, 3, 64, 64).cuda()
147 |
148 | images = gan.generate(lowres) # (1, 3, 256, 256)
149 | ```
150 |
151 | ## Losses
152 |
153 | * `G` - Generator
154 | * `MSG` - Multiscale Generator
155 | * `D` - Discriminator
156 | * `MSD` - Multiscale Discriminator
157 | * `GP` - Gradient Penalty
158 | * `SSL` - Auxiliary Reconstruction in Discriminator (from Lightweight GAN)
159 | * `VD` - Vision-aided Discriminator
160 | * `VG` - Vision-aided Generator
161 | * `CL` - Generator Constrastive Loss
162 | * `MAL` - Matching Aware Loss
163 |
164 | A healthy run would have `G`, `MSG`, `D`, `MSD` with values hovering between `0` to `10`, and usually staying pretty constant. If at any time after 1k training steps these values persist at triple digits, that would mean something is wrong. It is ok for generator and discriminator values to occasionally dip negative, but it should swing back up to the range above.
165 |
166 | `GP` and `SSL` should be pushed towards `0`. `GP` can occasionally spike; I like to imagine it as the networks undergoing some epiphany
167 |
168 | ## Multi-GPU Training
169 |
170 | The `GigaGAN` class is now equipped with 🤗 Accelerator. You can easily do multi-gpu training in two steps using their `accelerate` CLI
171 |
172 | At the project root directory, where the training script is, run
173 |
174 | ```python
175 | $ accelerate config
176 | ```
177 |
178 | Then, in the same directory
179 |
180 | ```python
181 | $ accelerate launch train.py
182 | ```
183 |
184 | ## Todo
185 |
186 | - [x] make sure it can be trained unconditionally
187 | - [x] read the relevant papers and knock out all 3 auxiliary losses
188 | - [x] matching aware loss
189 | - [x] clip loss
190 | - [x] vision-aided discriminator loss
191 | - [x] add reconstruction losses on arbitrary stages in the discriminator (lightweight gan)
192 | - [x] figure out how the random projections are used from projected-gan
193 | - [x] vision aided discriminator needs to extract N layers from the vision model in CLIP
194 | - [x] figure out whether to discard CLS token and reshape into image dimensions for convolution, or stick with attention and condition with adaptive layernorm - also turn off vision aided gan in unconditional case
195 | - [x] unet upsampler
196 | - [x] add adaptive conv
197 | - [x] modify latter stage of unet to also output rgb residuals, and pass the rgb into discriminator. make discriminator agnostic to rgb being passed in
198 | - [x] do pixel shuffle upsamples for unet
199 | - [x] get a code review for the multi-scale inputs and outputs, as the paper was a bit vague
200 | - [x] add upsampling network architecture
201 | - [x] make unconditional work for both base generator and upsampler
202 | - [x] make text conditioned training work for both base and upsampler
203 | - [x] make recon more efficient by random sampling patches
204 | - [x] make sure generator and discriminator can also accept pre-encoded CLIP text encodings
205 | - [x] do a review of the auxiliary losses
206 | - [x] add contrastive loss for generator
207 | - [x] add vision aided loss
208 | - [x] add gradient penalty for vision aided discr - make optional
209 | - [x] add matching awareness loss - figure out if rotating text conditions by one is good enough for mismatching (without drawing an additional batch from dataloader)
210 | - [x] make sure gradient accumulation works with matching aware loss
211 | - [x] matching awareness loss runs and is stable
212 | - [x] vision aided trains
213 | - [x] add some differentiable augmentations, proven technique from the old GAN days
214 | - [x] remove any magic being done with automatic rgbs processing, and have it explicitly passed in - offer functions on the discriminator that can process real images into the right multi-scales
215 | - [x] add horizontal flip for starters
216 |
217 | - [ ] move all modulation projections into the adaptive conv2d class
218 | - [ ] add accelerate
219 | - [x] works single machine
220 | - [x] works for mixed precision (make sure gradient penalty is scaled correctly), take care of manual scaler saving and reloading, borrow from imagen-pytorch
221 | - [x] make sure it works multi-GPU for one machine
222 | - [ ] have someone else try multiple machines
223 |
224 | - [ ] clip should be optional for all modules, and managed by `GigaGAN`, with text -> text embeds processed once
225 | - [ ] add ability to select a random subset from multiscale dimension, for efficiency
226 |
227 | - [ ] port over CLI from lightweight|stylegan2-pytorch
228 | - [ ] hook up laion dataset for text-image
229 |
230 | ## Citations
231 |
232 | ```bibtex
233 | @misc{https://doi.org/10.48550/arxiv.2303.05511,
234 | url = {https://arxiv.org/abs/2303.05511},
235 | author = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},
236 | title = {Scaling up GANs for Text-to-Image Synthesis},
237 | publisher = {arXiv},
238 | year = {2023},
239 | copyright = {arXiv.org perpetual, non-exclusive license}
240 | }
241 | ```
242 |
243 | ```bibtex
244 | @article{Liu2021TowardsFA,
245 | title = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
246 | author = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
247 | journal = {ArXiv},
248 | year = {2021},
249 | volume = {abs/2101.04775}
250 | }
251 | ```
252 |
253 | ```bibtex
254 | @inproceedings{dao2022flashattention,
255 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
256 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
257 | booktitle = {Advances in Neural Information Processing Systems},
258 | year = {2022}
259 | }
260 | ```
261 |
262 | ```bibtex
263 | @inproceedings{Karras2020ada,
264 | title = {Training Generative Adversarial Networks with Limited Data},
265 | author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
266 | booktitle = {Proc. NeurIPS},
267 | year = {2020}
268 | }
269 | ```
270 |
271 | ```bibtex
272 | @article{Xu2024VideoGigaGANTD,
273 | title = {VideoGigaGAN: Towards Detail-rich Video Super-Resolution},
274 | author = {Yiran Xu and Taesung Park and Richard Zhang and Yang Zhou and Eli Shechtman and Feng Liu and Jia-Bin Huang and Difan Liu},
275 | journal = {ArXiv},
276 | year = {2024},
277 | volume = {abs/2404.12388},
278 | url ={https://api.semanticscholar.org/CorpusID:269214195}
279 | }
280 | ```
281 |
--------------------------------------------------------------------------------
/gigagan_pytorch/unet_upsampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from math import log2
4 | from functools import partial
5 | from itertools import islice
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from torch.nn import Module, ModuleList
11 |
12 | from einops import rearrange, repeat, pack, unpack
13 | from einops.layers.torch import Rearrange
14 |
15 | from gigagan_pytorch.attend import Attend
16 | from gigagan_pytorch.gigagan_pytorch import (
17 | BaseGenerator,
18 | StyleNetwork,
19 | AdaptiveConv2DMod,
20 | AdaptiveConv1DMod,
21 | TextEncoder,
22 | CrossAttentionBlock,
23 | Upsample,
24 | PixelShuffleUpsample,
25 | Blur
26 | )
27 |
28 | from kornia.filters import filter3d, filter2d
29 |
30 | from beartype import beartype
31 | from beartype.typing import List, Dict, Iterable, Literal
32 |
33 | # helpers functions
34 |
35 | def exists(x):
36 | return x is not None
37 |
38 | def default(val, d):
39 | if exists(val):
40 | return val
41 | return d() if callable(d) else d
42 |
43 | def pack_one(t, pattern):
44 | return pack([t], pattern)
45 |
46 | def unpack_one(t, ps, pattern):
47 | return unpack(t, ps, pattern)[0]
48 |
49 | def cast_tuple(t, length = 1):
50 | if isinstance(t, tuple):
51 | return t
52 | return ((t,) * length)
53 |
54 | def identity(t, *args, **kwargs):
55 | return t
56 |
57 | def is_power_of_two(n):
58 | return log2(n).is_integer()
59 |
60 | def null_iterator():
61 | while True:
62 | yield None
63 |
64 | def fold_space_into_batch(x):
65 | x = rearrange(x, 'b c t h w -> b h w c t')
66 | x, ps = pack_one(x, '* c t')
67 |
68 | def split_space_from_batch(out):
69 | out = unpack_one(x, ps, '* c t')
70 | out = rearrange(out, 'b h w c t -> b c t h w')
71 | return out
72 |
73 | return x, split_space_from_batch
74 |
75 | # small helper modules
76 |
77 | def interpolate_1d(x, length, mode = 'bilinear'):
78 | x = rearrange(x, 'b c t -> b c t 1')
79 | x = F.interpolate(x, (length, 1), mode = mode)
80 | return rearrange(x, 'b c t 1 -> b c t')
81 |
82 | class Downsample(Module):
83 | def __init__(
84 | self,
85 | dim,
86 | dim_out = None,
87 | skip_downsample = False,
88 | has_temporal_layers = False
89 | ):
90 | super().__init__()
91 | dim_out = default(dim_out, dim)
92 |
93 | self.skip_downsample = skip_downsample
94 |
95 | self.conv2d = nn.Conv2d(dim, dim_out, 3, padding = 1)
96 |
97 | self.has_temporal_layers = has_temporal_layers
98 |
99 | if has_temporal_layers:
100 | self.conv1d = nn.Conv1d(dim_out, dim_out, 3, padding = 1)
101 |
102 | nn.init.dirac_(self.conv1d.weight)
103 | nn.init.zeros_(self.conv1d.bias)
104 |
105 | self.register_buffer('filter', torch.Tensor([1., 2., 1.]))
106 |
107 | def forward(self, x):
108 | batch = x.shape[0]
109 | is_input_video = x.ndim == 5
110 |
111 | assert not (is_input_video and not self.has_temporal_layers)
112 |
113 | if is_input_video:
114 | x = rearrange(x, 'b c t h w -> (b t) c h w')
115 |
116 | x = self.conv2d(x)
117 |
118 | if is_input_video:
119 | x = rearrange(x, '(b t) c h w -> b h w c t', b = batch)
120 | x, ps = pack_one(x, '* c t')
121 |
122 | x = self.conv1d(x)
123 |
124 | x = unpack_one(x, ps, '* c t')
125 | x = rearrange(x, 'b h w c t -> b c t h w')
126 |
127 | # if not downsampling, early return
128 |
129 | if self.skip_downsample:
130 | return x, x[:, 0:0]
131 |
132 | # save before blur to subtract out for high frequency fmap skip connection
133 |
134 | before_blur_input = x
135 |
136 | # blur 2d or 3d, depending
137 |
138 | f = self.filter
139 | N = None
140 |
141 | if is_input_video:
142 | f = f[N, N, :] * f[N, :, N] * f[:, N, N]
143 | filter_fn = filter3d
144 | maxpool_fn = F.max_pool3d
145 | else:
146 | f = f[N, :] * f[:, N]
147 | filter_fn = filter2d
148 | maxpool_fn = F.max_pool2d
149 |
150 | blurred = filter_fn(x, f[N, ...], normalized = True)
151 |
152 | # get high frequency fmap
153 |
154 | high_freq_fmap = before_blur_input - blurred
155 |
156 | # max pool 2d or 3d, depending
157 |
158 | x = maxpool_fn(x, kernel_size = 2)
159 |
160 | return x, high_freq_fmap
161 |
162 | class TemporalBlur(Module):
163 | def __init__(self):
164 | super().__init__()
165 | f = torch.Tensor([1, 2, 1])
166 | self.register_buffer('f', f)
167 |
168 | def forward(self, x):
169 | f = repeat(self.f, 't -> 1 t h w', h = 3, w = 3)
170 | return filter3d(x, f, normalized = True)
171 |
172 | class TemporalUpsample(Module):
173 | def __init__(
174 | self,
175 | dim,
176 | dim_out = None
177 | ):
178 | super().__init__()
179 | self.blur = TemporalBlur()
180 |
181 | def forward(self, x):
182 | assert x.ndim == 5
183 | time = x.shape[2]
184 |
185 | x = rearrange(x, 'b c t h w -> b h w c t')
186 | x, ps = pack_one(x, '* c t')
187 |
188 | x = interpolate_1d(x, time * 2, mode = 'bilinear')
189 |
190 | x = unpack_one(x, ps, '* c t')
191 | x = rearrange(x, 'b h w c t -> b c t h w')
192 | x = self.blur(x)
193 | return x
194 |
195 | class PixelShuffleTemporalUpsample(Module):
196 | def __init__(self, dim, dim_out = None):
197 | super().__init__()
198 | dim_out = default(dim_out, dim)
199 |
200 | conv = nn.Conv3d(dim, dim_out * 2, 1)
201 |
202 | self.net = nn.Sequential(
203 | conv,
204 | nn.SiLU(),
205 | Rearrange('b (c p) t h w -> b c (t p) h w', p = 2)
206 | )
207 |
208 | self.init_conv_(conv)
209 |
210 | def init_conv_(self, conv):
211 | o, i, t, h, w = conv.weight.shape
212 | conv_weight = torch.empty(o // 2, i, t, h, w)
213 | nn.init.kaiming_uniform_(conv_weight)
214 | conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
215 |
216 | conv.weight.data.copy_(conv_weight)
217 | nn.init.zeros_(conv.bias.data)
218 |
219 | def forward(self, x):
220 | return self.net(x)
221 |
222 | # norm
223 |
224 | class RMSNorm(Module):
225 | def __init__(self, dim):
226 | super().__init__()
227 | self.scale = dim ** 0.5
228 | self.gamma = nn.Parameter(torch.ones(dim))
229 |
230 | def forward(self, x):
231 | spatial_dims = ((1,) * (x.ndim - 2))
232 | gamma = self.gamma.reshape(-1, *spatial_dims)
233 |
234 | return F.normalize(x, dim = 1) * gamma * self.scale
235 |
236 | # building block modules
237 |
238 | class Block(Module):
239 | @beartype
240 | def __init__(
241 | self,
242 | dim,
243 | dim_out,
244 | num_conv_kernels = 0,
245 | conv_type: Literal['1d', '2d'] = '2d',
246 | ):
247 | super().__init__()
248 |
249 | adaptive_conv_klass = AdaptiveConv2DMod if conv_type == '2d' else AdaptiveConv1DMod
250 |
251 | self.proj = adaptive_conv_klass(dim, dim_out, kernel = 3, num_conv_kernels = num_conv_kernels)
252 | self.norm = RMSNorm(dim_out)
253 | self.act = nn.SiLU()
254 |
255 | def forward(
256 | self,
257 | x,
258 | conv_mods_iter: Iterable | None = None
259 | ):
260 | conv_mods_iter = default(conv_mods_iter, null_iterator())
261 |
262 | x = self.proj(
263 | x,
264 | mod = next(conv_mods_iter),
265 | kernel_mod = next(conv_mods_iter)
266 | )
267 |
268 | x = self.norm(x)
269 | x = self.act(x)
270 | return x
271 |
272 | class ResnetBlock(Module):
273 | @beartype
274 | def __init__(
275 | self,
276 | dim,
277 | dim_out,
278 | *,
279 | num_conv_kernels = 0,
280 | conv_type: Literal['1d', '2d'] = '2d',
281 | style_dims: List[int] = []
282 | ):
283 | super().__init__()
284 |
285 | mod_dims = [
286 | dim,
287 | num_conv_kernels,
288 | dim_out,
289 | num_conv_kernels
290 | ]
291 |
292 | style_dims.extend(mod_dims)
293 |
294 | self.num_mods = len(mod_dims)
295 |
296 | self.block1 = Block(dim, dim_out, num_conv_kernels = num_conv_kernels, conv_type = conv_type)
297 | self.block2 = Block(dim_out, dim_out, num_conv_kernels = num_conv_kernels, conv_type = conv_type)
298 |
299 | conv_klass = nn.Conv2d if conv_type == '2d' else nn.Conv1d
300 | self.res_conv = conv_klass(dim, dim_out, 1) if dim != dim_out else nn.Identity()
301 |
302 | def forward(
303 | self,
304 | x,
305 | conv_mods_iter: Iterable | None = None
306 | ):
307 | h = self.block1(x, conv_mods_iter = conv_mods_iter)
308 | h = self.block2(h, conv_mods_iter = conv_mods_iter)
309 |
310 | return h + self.res_conv(x)
311 |
312 | class LinearAttention(Module):
313 | def __init__(
314 | self,
315 | dim,
316 | heads = 4,
317 | dim_head = 32
318 | ):
319 | super().__init__()
320 | self.scale = dim_head ** -0.5
321 | self.heads = heads
322 | hidden_dim = dim_head * heads
323 |
324 | self.norm = RMSNorm(dim)
325 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
326 |
327 | self.to_out = nn.Sequential(
328 | nn.Conv2d(hidden_dim, dim, 1),
329 | RMSNorm(dim)
330 | )
331 |
332 | def forward(self, x):
333 | b, c, h, w = x.shape
334 |
335 | x = self.norm(x)
336 |
337 | qkv = self.to_qkv(x).chunk(3, dim = 1)
338 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
339 |
340 | q = q.softmax(dim = -2)
341 | k = k.softmax(dim = -1)
342 |
343 | q = q * self.scale
344 |
345 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
346 |
347 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
348 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
349 | return self.to_out(out)
350 |
351 | class Attention(Module):
352 | def __init__(
353 | self,
354 | dim,
355 | heads = 4,
356 | dim_head = 32,
357 | flash = False
358 | ):
359 | super().__init__()
360 | self.heads = heads
361 | hidden_dim = dim_head * heads
362 |
363 | self.norm = RMSNorm(dim)
364 | self.attend = Attend(flash = flash)
365 |
366 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
367 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
368 |
369 | def forward(self, x):
370 | b, c, h, w = x.shape
371 |
372 | x = self.norm(x)
373 |
374 | qkv = self.to_qkv(x).chunk(3, dim = 1)
375 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
376 |
377 | out = self.attend(q, k, v)
378 |
379 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
380 | return self.to_out(out)
381 |
382 | # feedforward
383 |
384 | def FeedForward(dim, mult = 4):
385 | return nn.Sequential(
386 | RMSNorm(dim),
387 | nn.Conv2d(dim, dim * mult, 1),
388 | nn.GELU(),
389 | nn.Conv2d(dim * mult, dim, 1)
390 | )
391 |
392 | # transformers
393 |
394 | class Transformer(Module):
395 | def __init__(
396 | self,
397 | dim,
398 | dim_head = 64,
399 | heads = 8,
400 | depth = 1,
401 | flash_attn = True,
402 | ff_mult = 4
403 | ):
404 | super().__init__()
405 | self.layers = ModuleList([])
406 |
407 | for _ in range(depth):
408 | self.layers.append(ModuleList([
409 | Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash_attn),
410 | FeedForward(dim = dim, mult = ff_mult)
411 | ]))
412 |
413 | def forward(self, x):
414 | for attn, ff in self.layers:
415 | x = attn(x) + x
416 | x = ff(x) + x
417 |
418 | return x
419 |
420 | class LinearTransformer(Module):
421 | def __init__(
422 | self,
423 | dim,
424 | dim_head = 64,
425 | heads = 8,
426 | depth = 1,
427 | ff_mult = 4
428 | ):
429 | super().__init__()
430 | self.layers = ModuleList([])
431 |
432 | for _ in range(depth):
433 | self.layers.append(ModuleList([
434 | LinearAttention(dim = dim, dim_head = dim_head, heads = heads),
435 | FeedForward(dim = dim, mult = ff_mult)
436 | ]))
437 |
438 | def forward(self, x):
439 | for attn, ff in self.layers:
440 | x = attn(x) + x
441 | x = ff(x) + x
442 |
443 | return x
444 |
445 | # model
446 |
447 | class UnetUpsampler(BaseGenerator):
448 |
449 | @beartype
450 | def __init__(
451 | self,
452 | dim,
453 | *,
454 | image_size,
455 | input_image_size,
456 | init_dim = None,
457 | out_dim = None,
458 | text_encoder: TextEncoder | Dict | None = None,
459 | style_network: StyleNetwork | Dict | None = None,
460 | style_network_dim = None,
461 | dim_mults = (1, 2, 4, 8, 16),
462 | channels = 3,
463 | full_attn = (False, False, False, True, True),
464 | cross_attn = (False, False, False, True, True),
465 | flash_attn = True,
466 | self_attn_dim_head = 64,
467 | self_attn_heads = 8,
468 | self_attn_dot_product = True,
469 | self_attn_ff_mult = 4,
470 | attn_depths = (1, 1, 1, 1, 1),
471 | temporal_attn_depths = (1, 1, 1, 1, 1),
472 | cross_attn_dim_head = 64,
473 | cross_attn_heads = 8,
474 | cross_ff_mult = 4,
475 | has_temporal_layers = False,
476 | mid_attn_depth = 1,
477 | num_conv_kernels = 2,
478 | unconditional = True,
479 | skip_connect_scale = None
480 | ):
481 | super().__init__()
482 |
483 | # able to upsample video
484 |
485 | self.can_upsample_video = has_temporal_layers
486 |
487 | # style network
488 |
489 | if isinstance(text_encoder, dict):
490 | text_encoder = TextEncoder(**text_encoder)
491 |
492 | self.text_encoder = text_encoder
493 |
494 | if isinstance(style_network, dict):
495 | style_network = StyleNetwork(**style_network)
496 |
497 | self.style_network = style_network
498 |
499 | assert exists(style_network) ^ exists(style_network_dim), 'either style_network or style_network_dim must be passed in'
500 |
501 | # validate text conditioning and style network hparams
502 |
503 | self.unconditional = unconditional
504 | assert unconditional ^ exists(text_encoder), 'if unconditional, text encoder should not be given, and vice versa'
505 | assert not (unconditional and exists(style_network) and style_network.dim_text_latent > 0)
506 | assert unconditional or text_encoder.dim == style_network.dim_text_latent, 'the `dim_text_latent` on your StyleNetwork must be equal to the `dim` set for the TextEncoder'
507 |
508 | assert is_power_of_two(image_size) and is_power_of_two(input_image_size), 'both output image size and input image size must be power of 2'
509 | assert input_image_size < image_size, 'input image size must be smaller than the output image size, thus upsampling'
510 |
511 | num_layer_no_downsample = int(log2(image_size) - log2(input_image_size))
512 | assert num_layer_no_downsample <= len(dim_mults), 'you need more stages in this unet for the level of upsampling'
513 |
514 | self.image_size = image_size
515 | self.input_image_size = input_image_size
516 |
517 | # setup adaptive conv
518 |
519 | style_embed_split_dims = []
520 |
521 | # determine dimensions
522 |
523 | self.channels = channels
524 | input_channels = channels
525 |
526 | init_dim = default(init_dim, dim)
527 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
528 |
529 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
530 |
531 | *_, mid_dim = dims
532 |
533 | in_out = list(zip(dims[:-1], dims[1:]))
534 |
535 | block_klass = partial(
536 | ResnetBlock,
537 | num_conv_kernels = num_conv_kernels,
538 | style_dims = style_embed_split_dims
539 | )
540 |
541 | # attention
542 |
543 | full_attn = cast_tuple(full_attn, length = len(dim_mults))
544 | assert len(full_attn) == len(dim_mults)
545 |
546 | FullAttention = partial(Transformer, flash_attn = flash_attn)
547 |
548 | cross_attn = cast_tuple(cross_attn, length = len(dim_mults))
549 | assert unconditional or len(full_attn) == len(dim_mults)
550 |
551 | # skip connection scale
552 |
553 | self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)
554 |
555 | # layers
556 |
557 | self.downs = ModuleList([])
558 | self.ups = ModuleList([])
559 | num_resolutions = len(in_out)
560 | skip_connect_dims = []
561 |
562 | for ind, ((dim_in, dim_out), layer_full_attn, layer_cross_attn, layer_attn_depth, layer_temporal_attn_depth) in enumerate(zip(in_out, full_attn, cross_attn, attn_depths, temporal_attn_depths)):
563 |
564 | should_not_downsample = ind < num_layer_no_downsample
565 | has_cross_attn = not self.unconditional and layer_cross_attn
566 |
567 | attn_klass = FullAttention if layer_full_attn else LinearTransformer
568 |
569 | skip_connect_dims.append(dim_in)
570 | skip_connect_dims.append(dim_in + (dim_out if not should_not_downsample else 0))
571 |
572 | temporal_resnet_block = None
573 | temporal_attn = None
574 |
575 | if has_temporal_layers:
576 | temporal_resnet_block = block_klass(dim_in, dim_in, conv_type = '1d')
577 | temporal_attn = FullAttention(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_temporal_attn_depth)
578 |
579 | # all unet downsample stages
580 |
581 | self.downs.append(ModuleList([
582 | block_klass(dim_in, dim_in),
583 | block_klass(dim_in, dim_in),
584 | CrossAttentionBlock(dim_in, dim_context = text_encoder.dim, dim_head = self_attn_dim_head, heads = self_attn_heads, ff_mult = self_attn_ff_mult) if has_cross_attn else None,
585 | attn_klass(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_attn_depth),
586 | temporal_resnet_block,
587 | temporal_attn,
588 | Downsample(dim_in, dim_out, skip_downsample = should_not_downsample, has_temporal_layers = has_temporal_layers)
589 | ]))
590 |
591 | self.mid_block1 = block_klass(mid_dim, mid_dim)
592 | self.mid_attn = FullAttention(mid_dim, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = mid_attn_depth)
593 | self.mid_block2 = block_klass(mid_dim, mid_dim)
594 | self.mid_to_rgb = nn.Conv2d(mid_dim, channels, 1)
595 |
596 | for ind, ((dim_in, dim_out), layer_cross_attn, layer_full_attn, layer_attn_depth, layer_temporal_attn_depth) in enumerate(zip(reversed(in_out), reversed(full_attn), reversed(cross_attn), reversed(attn_depths), reversed(temporal_attn_depths))):
597 |
598 | attn_klass = FullAttention if layer_full_attn else LinearTransformer
599 | has_cross_attn = not self.unconditional and layer_cross_attn
600 |
601 | temporal_upsample = None
602 | temporal_upsample_rgb = None
603 | temporal_resnet_block = None
604 | temporal_attn = None
605 |
606 | if has_temporal_layers:
607 | temporal_upsample = PixelShuffleTemporalUpsample(dim_in, dim_in)
608 | temporal_upsample_rgb = TemporalUpsample(dim_in, dim_in)
609 |
610 | temporal_resnet_block = block_klass(dim_in, dim_in, conv_type = '1d')
611 | temporal_attn = FullAttention(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_temporal_attn_depth)
612 |
613 | self.ups.append(ModuleList([
614 | PixelShuffleUpsample(dim_out, dim_in),
615 | Upsample(),
616 | temporal_upsample,
617 | temporal_upsample_rgb,
618 | nn.Conv2d(dim_in, channels, 1),
619 | block_klass(dim_in + skip_connect_dims.pop(), dim_in),
620 | block_klass(dim_in + skip_connect_dims.pop(), dim_in),
621 | CrossAttentionBlock(dim_in, dim_context = text_encoder.dim, dim_head = self_attn_dim_head, heads = self_attn_heads, ff_mult = cross_ff_mult) if has_cross_attn else None,
622 | attn_klass(dim_in, dim_head = cross_attn_dim_head, heads = self_attn_heads, depth = layer_attn_depth),
623 | temporal_resnet_block,
624 | temporal_attn
625 | ]))
626 |
627 | self.out_dim = default(out_dim, channels)
628 |
629 | self.final_res_block = block_klass(dim, dim)
630 |
631 | self.final_to_rgb = nn.Conv2d(dim, channels, 1)
632 |
633 | # determine the projection of the style embedding to convolutional modulation weights (+ adaptive kernel selection weights) for all layers
634 |
635 | self.style_to_conv_modulations = nn.Linear(style_network.dim, sum(style_embed_split_dims))
636 | self.style_embed_split_dims = style_embed_split_dims
637 |
638 | @property
639 | def allowable_rgb_resolutions(self):
640 | input_res_base = int(log2(self.input_image_size))
641 | output_res_base = int(log2(self.image_size))
642 | allowed_rgb_res_base = list(range(input_res_base, output_res_base))
643 | return [*map(lambda p: 2 ** p, allowed_rgb_res_base)]
644 |
645 | @property
646 | def device(self):
647 | return next(self.parameters()).device
648 |
649 | @property
650 | def total_params(self):
651 | return sum([p.numel() for p in self.parameters()])
652 |
653 | def resize_to_same_dimensions(self, x, size):
654 | mode = 'trilinear' if x.ndim == 5 else 'bilinear'
655 | return F.interpolate(x, tuple(size), mode = mode)
656 |
657 | def forward(
658 | self,
659 | lowres_image_or_video,
660 | styles = None,
661 | noise = None,
662 | texts: List[str] | None = None,
663 | global_text_tokens = None,
664 | fine_text_tokens = None,
665 | text_mask = None,
666 | return_all_rgbs = False,
667 | replace_rgb_with_input_lowres_image = True # discriminator should also receive the low resolution image the upsampler sees
668 | ):
669 | x = lowres_image_or_video
670 | shape = x.shape
671 | batch_size = shape[0]
672 |
673 | assert shape[-2:] == ((self.input_image_size,) * 2)
674 |
675 | # take care of text encodings
676 | # which requires global text tokens to adaptively select the kernels from the main contribution in the paper
677 | # and fine text tokens to attend to using cross attention
678 |
679 | if not self.unconditional:
680 | if exists(texts):
681 | assert exists(self.text_encoder)
682 | global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(texts)
683 | else:
684 | assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))])
685 | else:
686 | assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])
687 |
688 | # styles
689 |
690 | if not exists(styles):
691 | assert exists(self.style_network)
692 |
693 | noise = default(noise, torch.randn((batch_size, self.style_network.dim), device = self.device))
694 | styles = self.style_network(noise, global_text_tokens)
695 |
696 | # project styles to conv modulations
697 |
698 | conv_mods = self.style_to_conv_modulations(styles)
699 | conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
700 | conv_mods = iter(conv_mods)
701 |
702 | # first detect whether input is image or video and handle accordingly
703 |
704 | input_is_video = lowres_image_or_video.ndim == 5
705 | assert not (not self.can_upsample_video and input_is_video), 'this network cannot upsample video unless you set `has_temporal_layers = True`'
706 |
707 | fold_time_into_batch = identity
708 | split_time_from_batch = identity
709 |
710 | if input_is_video:
711 | fold_time_into_batch = lambda t: rearrange(t, 'b c t h w -> (b t) c h w')
712 | split_time_from_batch = lambda t: rearrange(t, '(b t) c h w -> b c t h w', b = batch_size)
713 |
714 | x = fold_time_into_batch(x)
715 |
716 | # set lowres_images for final rgb output
717 |
718 | lowres_images = x
719 |
720 | # initial conv
721 |
722 | x = self.init_conv(x)
723 |
724 | h = []
725 |
726 | # downsample stages
727 |
728 | for (
729 | block1,
730 | block2,
731 | cross_attn,
732 | attn,
733 | temporal_block,
734 | temporal_attn,
735 | downsample,
736 | ) in self.downs:
737 |
738 | x = block1(x, conv_mods_iter = conv_mods)
739 | h.append(x)
740 |
741 | x = block2(x, conv_mods_iter = conv_mods)
742 |
743 | x = attn(x)
744 |
745 | if exists(cross_attn):
746 | x = cross_attn(x, context = fine_text_tokens, mask = text_mask)
747 |
748 | if input_is_video:
749 | x = split_time_from_batch(x)
750 | x, split_space_back = fold_space_into_batch(x)
751 |
752 | x = temporal_block(x, conv_mods_iter = conv_mods)
753 |
754 | x = rearrange(x, 'b c t -> b c t 1')
755 | x = temporal_attn(x)
756 | x = rearrange(x, 'b c t 1 -> b c t')
757 |
758 | x = split_space_back(x)
759 | x = fold_time_into_batch(x)
760 |
761 | elif self.can_upsample_video:
762 | conv_mods = islice(conv_mods, temporal_block.num_mods, None)
763 |
764 | skip_connect = x
765 |
766 | # downsample with hf shuttle
767 |
768 | x = split_time_from_batch(x)
769 |
770 | x, hf_fmap = downsample(x)
771 |
772 | x = fold_time_into_batch(x)
773 | hf_fmap = fold_time_into_batch(hf_fmap)
774 |
775 | # add high freq fmap to skip connection as proposed in videogigagan
776 |
777 | skip_connect = torch.cat((skip_connect, hf_fmap), dim = 1)
778 |
779 | h.append(skip_connect)
780 |
781 | x = self.mid_block1(x, conv_mods_iter = conv_mods)
782 | x = self.mid_attn(x)
783 | x = self.mid_block2(x, conv_mods_iter = conv_mods)
784 |
785 | # rgbs
786 |
787 | rgbs = []
788 |
789 | init_rgb_shape = list(x.shape)
790 | init_rgb_shape[1] = self.channels
791 |
792 | rgb = self.mid_to_rgb(x)
793 | rgbs.append(rgb)
794 |
795 | # upsample stages
796 |
797 | for (
798 | upsample,
799 | upsample_rgb,
800 | temporal_upsample,
801 | temporal_upsample_rgb,
802 | to_rgb,
803 | block1,
804 | block2,
805 | cross_attn,
806 | attn,
807 | temporal_block,
808 | temporal_attn,
809 | ) in self.ups:
810 |
811 | x = upsample(x)
812 | rgb = upsample_rgb(rgb)
813 |
814 | if input_is_video:
815 | x = split_time_from_batch(x)
816 | rgb = split_time_from_batch(rgb)
817 |
818 | x = temporal_upsample(x)
819 | rgb = temporal_upsample_rgb(rgb)
820 |
821 | x = fold_time_into_batch(x)
822 | rgb = fold_time_into_batch(rgb)
823 |
824 | res1 = h.pop() * self.skip_connect_scale
825 | res2 = h.pop() * self.skip_connect_scale
826 |
827 | # handle skip connections not being the same shape
828 |
829 | if x.shape[0] != res1.shape[0] or x.shape[2:] != res1.shape[2:]:
830 | x = split_time_from_batch(x)
831 | res1 = split_time_from_batch(res1)
832 | res2 = split_time_from_batch(res2)
833 |
834 | res1 = self.resize_to_same_dimensions(res1, x.shape[2:])
835 | res2 = self.resize_to_same_dimensions(res2, x.shape[2:])
836 |
837 | x = fold_time_into_batch(x)
838 | res1 = fold_time_into_batch(res1)
839 | res2 = fold_time_into_batch(res2)
840 |
841 | # concat skip connections
842 |
843 | x = torch.cat((x, res1), dim = 1)
844 | x = block1(x, conv_mods_iter = conv_mods)
845 |
846 | x = torch.cat((x, res2), dim = 1)
847 | x = block2(x, conv_mods_iter = conv_mods)
848 |
849 | if exists(cross_attn):
850 | x = cross_attn(x, context = fine_text_tokens, mask = text_mask)
851 |
852 | x = attn(x)
853 |
854 | if input_is_video:
855 | x = split_time_from_batch(x)
856 | x, split_space_back = fold_space_into_batch(x)
857 |
858 | x = temporal_block(x, conv_mods_iter = conv_mods)
859 |
860 | x = rearrange(x, 'b c t -> b c t 1')
861 | x = temporal_attn(x)
862 | x = rearrange(x, 'b c t 1 -> b c t')
863 |
864 | x = split_space_back(x)
865 | x = fold_time_into_batch(x)
866 |
867 | elif self.can_upsample_video:
868 | conv_mods = islice(conv_mods, temporal_block.num_mods, None)
869 |
870 | rgb = rgb + to_rgb(x)
871 | rgbs.append(rgb)
872 |
873 | x = self.final_res_block(x, conv_mods_iter = conv_mods)
874 |
875 | assert len([*conv_mods]) == 0
876 |
877 | rgb = rgb + self.final_to_rgb(x)
878 |
879 | # handle video input
880 |
881 | if input_is_video:
882 | rgb = split_time_from_batch(rgb)
883 |
884 | if not return_all_rgbs:
885 | return rgb
886 |
887 | # only keep those rgbs whose feature map is greater than the input image to be upsampled
888 |
889 | rgbs = list(filter(lambda t: t.shape[-1] > shape[-1], rgbs))
890 |
891 | # and return the original input image as the smallest rgb
892 |
893 | rgbs = [lowres_images, *rgbs]
894 |
895 | if input_is_video:
896 | rgbs = [*map(split_time_from_batch, rgbs)]
897 |
898 | return rgb, rgbs
899 |
--------------------------------------------------------------------------------
/gigagan_pytorch/gigagan_pytorch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import namedtuple
4 | from pathlib import Path
5 | from math import log2, sqrt
6 | from random import random
7 | from functools import partial
8 |
9 | from torchvision import utils
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn, einsum, Tensor
14 | from torch.autograd import grad as torch_grad
15 | from torch.utils.data import DataLoader
16 | from torch.cuda.amp import GradScaler
17 |
18 | from beartype import beartype
19 | from beartype.typing import List, Tuple, Dict, Iterable
20 |
21 | from einops import rearrange, pack, unpack, repeat, reduce
22 | from einops.layers.torch import Rearrange, Reduce
23 |
24 | from kornia.filters import filter2d
25 |
26 | from ema_pytorch import EMA
27 |
28 | from gigagan_pytorch.version import __version__
29 | from gigagan_pytorch.open_clip import OpenClipAdapter
30 | from gigagan_pytorch.optimizer import get_optimizer
31 | from gigagan_pytorch.distributed import all_gather
32 |
33 | from tqdm import tqdm
34 |
35 | from numerize import numerize
36 |
37 | from accelerate import Accelerator, DistributedType
38 | from accelerate.utils import DistributedDataParallelKwargs
39 |
40 | # helpers
41 |
42 | def exists(val):
43 | return val is not None
44 |
45 | @beartype
46 | def is_empty(arr: Iterable):
47 | return len(arr) == 0
48 |
49 | def default(*vals):
50 | for val in vals:
51 | if exists(val):
52 | return val
53 | return None
54 |
55 | def cast_tuple(t, length = 1):
56 | return t if isinstance(t, tuple) else ((t,) * length)
57 |
58 | def is_power_of_two(n):
59 | return log2(n).is_integer()
60 |
61 | def safe_unshift(arr):
62 | if len(arr) == 0:
63 | return None
64 | return arr.pop(0)
65 |
66 | def divisible_by(numer, denom):
67 | return (numer % denom) == 0
68 |
69 | def group_by_num_consecutive(arr, num):
70 | out = []
71 | for ind, el in enumerate(arr):
72 | if ind > 0 and divisible_by(ind, num):
73 | yield out
74 | out = []
75 |
76 | out.append(el)
77 |
78 | if len(out) > 0:
79 | yield out
80 |
81 | def is_unique(arr):
82 | return len(set(arr)) == len(arr)
83 |
84 | def cycle(dl):
85 | while True:
86 | for data in dl:
87 | yield data
88 |
89 | def num_to_groups(num, divisor):
90 | groups, remainder = divmod(num, divisor)
91 | arr = [divisor] * groups
92 | if remainder > 0:
93 | arr.append(remainder)
94 | return arr
95 |
96 | def mkdir_if_not_exists(path):
97 | path.mkdir(exist_ok = True, parents = True)
98 |
99 | @beartype
100 | def set_requires_grad_(
101 | m: nn.Module,
102 | requires_grad: bool
103 | ):
104 | for p in m.parameters():
105 | p.requires_grad = requires_grad
106 |
107 | # activation functions
108 |
109 | def leaky_relu(neg_slope = 0.2):
110 | return nn.LeakyReLU(neg_slope)
111 |
112 | def conv2d_3x3(dim_in, dim_out):
113 | return nn.Conv2d(dim_in, dim_out, 3, padding = 1)
114 |
115 | # tensor helpers
116 |
117 | def log(t, eps = 1e-20):
118 | return t.clamp(min = eps).log()
119 |
120 | def gradient_penalty(
121 | images,
122 | outputs,
123 | grad_output_weights = None,
124 | weight = 10,
125 | scaler: GradScaler | None = None,
126 | eps = 1e-4
127 | ):
128 | if not isinstance(outputs, (list, tuple)):
129 | outputs = [outputs]
130 |
131 | if exists(scaler):
132 | outputs = [*map(scaler.scale, outputs)]
133 |
134 | if not exists(grad_output_weights):
135 | grad_output_weights = (1,) * len(outputs)
136 |
137 | maybe_scaled_gradients, *_ = torch_grad(
138 | outputs = outputs,
139 | inputs = images,
140 | grad_outputs = [(torch.ones_like(output) * weight) for output, weight in zip(outputs, grad_output_weights)],
141 | create_graph = True,
142 | retain_graph = True,
143 | only_inputs = True
144 | )
145 |
146 | gradients = maybe_scaled_gradients
147 |
148 | if exists(scaler):
149 | scale = scaler.get_scale()
150 | inv_scale = 1. / max(scale, eps)
151 | gradients = maybe_scaled_gradients * inv_scale
152 |
153 | gradients = rearrange(gradients, 'b ... -> b (...)')
154 | return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
155 |
156 | # hinge gan losses
157 |
158 | def generator_hinge_loss(fake):
159 | return fake.mean()
160 |
161 | def discriminator_hinge_loss(real, fake):
162 | return (F.relu(1 + real) + F.relu(1 - fake)).mean()
163 |
164 | # auxiliary losses
165 |
166 | def aux_matching_loss(real, fake):
167 | """
168 | making logits negative, as in this framework, discriminator is 0 for real, high value for fake. GANs can have this arbitrarily swapped, as it only matters if the generator and discriminator are opposites
169 | """
170 | return (log(1 + (-real).exp()) + log(1 + (-fake).exp())).mean()
171 |
172 | @beartype
173 | def aux_clip_loss(
174 | clip: OpenClipAdapter,
175 | images: Tensor,
176 | texts: List[str] | None = None,
177 | text_embeds: Tensor | None = None
178 | ):
179 | assert exists(texts) ^ exists(text_embeds)
180 |
181 | images, batch_sizes = all_gather(images, 0, None)
182 |
183 | if exists(texts):
184 | text_embeds, _ = clip.embed_texts(texts)
185 | text_embeds, _ = all_gather(text_embeds, 0, batch_sizes)
186 |
187 | return clip.contrastive_loss(images = images, text_embeds = text_embeds)
188 |
189 | # differentiable augmentation - Karras et al. stylegan-ada
190 | # start with horizontal flip
191 |
192 | class DiffAugment(nn.Module):
193 | def __init__(
194 | self,
195 | *,
196 | prob,
197 | horizontal_flip,
198 | horizontal_flip_prob = 0.5
199 | ):
200 | super().__init__()
201 | self.prob = prob
202 | assert 0 <= prob <= 1.
203 |
204 | self.horizontal_flip = horizontal_flip
205 | self.horizontal_flip_prob = horizontal_flip_prob
206 |
207 | def forward(
208 | self,
209 | images,
210 | rgbs: List[Tensor]
211 | ):
212 | if random() >= self.prob:
213 | return images, rgbs
214 |
215 | if random() < self.horizontal_flip_prob:
216 | images = torch.flip(images, (-1,))
217 | rgbs = [torch.flip(rgb, (-1,)) for rgb in rgbs]
218 |
219 | return images, rgbs
220 |
221 | # rmsnorm (newer papers show mean-centering in layernorm not necessary)
222 |
223 | class ChannelRMSNorm(nn.Module):
224 | def __init__(self, dim):
225 | super().__init__()
226 | self.scale = dim ** 0.5
227 | self.gamma = nn.Parameter(torch.ones(dim, 1, 1))
228 |
229 | def forward(self, x):
230 | normed = F.normalize(x, dim = 1)
231 | return normed * self.scale * self.gamma
232 |
233 | class RMSNorm(nn.Module):
234 | def __init__(self, dim):
235 | super().__init__()
236 | self.scale = dim ** 0.5
237 | self.gamma = nn.Parameter(torch.ones(dim))
238 |
239 | def forward(self, x):
240 | normed = F.normalize(x, dim = -1)
241 | return normed * self.scale * self.gamma
242 |
243 | # down and upsample
244 |
245 | class Blur(nn.Module):
246 | def __init__(self):
247 | super().__init__()
248 | f = torch.Tensor([1, 2, 1])
249 | self.register_buffer('f', f)
250 |
251 | def forward(self, x):
252 | f = self.f
253 | f = f[None, None, :] * f[None, :, None]
254 | return filter2d(x, f, normalized = True)
255 |
256 | def Upsample(*args):
257 | return nn.Sequential(
258 | nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),
259 | Blur()
260 | )
261 |
262 | class PixelShuffleUpsample(nn.Module):
263 | def __init__(self, dim, dim_out = None):
264 | super().__init__()
265 | dim_out = default(dim_out, dim)
266 | conv = nn.Conv2d(dim, dim_out * 4, 1)
267 |
268 | self.net = nn.Sequential(
269 | conv,
270 | nn.SiLU(),
271 | nn.PixelShuffle(2)
272 | )
273 |
274 | self.init_conv_(conv)
275 |
276 | def init_conv_(self, conv):
277 | o, i, h, w = conv.weight.shape
278 | conv_weight = torch.empty(o // 4, i, h, w)
279 | nn.init.kaiming_uniform_(conv_weight)
280 | conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
281 |
282 | conv.weight.data.copy_(conv_weight)
283 | nn.init.zeros_(conv.bias.data)
284 |
285 | def forward(self, x):
286 | return self.net(x)
287 |
288 | def Downsample(dim):
289 | return nn.Sequential(
290 | Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
291 | nn.Conv2d(dim * 4, dim, 1)
292 | )
293 |
294 | # skip layer excitation
295 |
296 | def SqueezeExcite(dim, dim_out, reduction = 4, dim_min = 32):
297 | dim_hidden = max(dim_out // reduction, dim_min)
298 |
299 | return nn.Sequential(
300 | Reduce('b c h w -> b c', 'mean'),
301 | nn.Linear(dim, dim_hidden),
302 | nn.SiLU(),
303 | nn.Linear(dim_hidden, dim_out),
304 | nn.Sigmoid(),
305 | Rearrange('b c -> b c 1 1')
306 | )
307 |
308 | # adaptive conv
309 | # the main novelty of the paper - they propose to learn a softmax weighted sum of N convolutional kernels, depending on the text embedding
310 |
311 | def get_same_padding(size, kernel, dilation, stride):
312 | return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
313 |
314 | class AdaptiveConv2DMod(nn.Module):
315 | def __init__(
316 | self,
317 | dim,
318 | dim_out,
319 | kernel,
320 | *,
321 | demod = True,
322 | stride = 1,
323 | dilation = 1,
324 | eps = 1e-8,
325 | num_conv_kernels = 1 # set this to be greater than 1 for adaptive
326 | ):
327 | super().__init__()
328 | self.eps = eps
329 |
330 | self.dim_out = dim_out
331 |
332 | self.kernel = kernel
333 | self.stride = stride
334 | self.dilation = dilation
335 | self.adaptive = num_conv_kernels > 1
336 |
337 | self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)))
338 |
339 | self.demod = demod
340 |
341 | nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
342 |
343 | def forward(
344 | self,
345 | fmap,
346 | mod: Tensor,
347 | kernel_mod: Tensor | None = None
348 | ):
349 | """
350 | notation
351 |
352 | b - batch
353 | n - convs
354 | o - output
355 | i - input
356 | k - kernel
357 | """
358 |
359 | b, h = fmap.shape[0], fmap.shape[-2]
360 |
361 | # account for feature map that has been expanded by the scale in the first dimension
362 | # due to multiscale inputs and outputs
363 |
364 | if mod.shape[0] != b:
365 | mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])
366 |
367 | if exists(kernel_mod):
368 | kernel_mod_has_el = kernel_mod.numel() > 0
369 |
370 | assert self.adaptive or not kernel_mod_has_el
371 |
372 | if kernel_mod_has_el and kernel_mod.shape[0] != b:
373 | kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
374 |
375 | # prepare weights for modulation
376 |
377 | weights = self.weights
378 |
379 | if self.adaptive:
380 | weights = repeat(weights, '... -> b ...', b = b)
381 |
382 | # determine an adaptive weight and 'select' the kernel to use with softmax
383 |
384 | assert exists(kernel_mod) and kernel_mod.numel() > 0
385 |
386 | kernel_attn = kernel_mod.softmax(dim = -1)
387 | kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')
388 |
389 | weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum')
390 |
391 | # do the modulation, demodulation, as done in stylegan2
392 |
393 | mod = rearrange(mod, 'b i -> b 1 i 1 1')
394 |
395 | weights = weights * (mod + 1)
396 |
397 | if self.demod:
398 | inv_norm = reduce(weights ** 2, 'b o i k1 k2 -> b o 1 1 1', 'sum').clamp(min = self.eps).rsqrt()
399 | weights = weights * inv_norm
400 |
401 | fmap = rearrange(fmap, 'b c h w -> 1 (b c) h w')
402 |
403 | weights = rearrange(weights, 'b o ... -> (b o) ...')
404 |
405 | padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
406 | fmap = F.conv2d(fmap, weights, padding = padding, groups = b)
407 |
408 | return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
409 |
410 | class AdaptiveConv1DMod(nn.Module):
411 | """ 1d version of adaptive conv, for time dimension in videogigagan """
412 |
413 | def __init__(
414 | self,
415 | dim,
416 | dim_out,
417 | kernel,
418 | *,
419 | demod = True,
420 | stride = 1,
421 | dilation = 1,
422 | eps = 1e-8,
423 | num_conv_kernels = 1 # set this to be greater than 1 for adaptive
424 | ):
425 | super().__init__()
426 | self.eps = eps
427 |
428 | self.dim_out = dim_out
429 |
430 | self.kernel = kernel
431 | self.stride = stride
432 | self.dilation = dilation
433 | self.adaptive = num_conv_kernels > 1
434 |
435 | self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel)))
436 |
437 | self.demod = demod
438 |
439 | nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
440 |
441 | def forward(
442 | self,
443 | fmap,
444 | mod: Tensor,
445 | kernel_mod: Tensor | None = None
446 | ):
447 | """
448 | notation
449 |
450 | b - batch
451 | n - convs
452 | o - output
453 | i - input
454 | k - kernel
455 | """
456 |
457 | b, t = fmap.shape[0], fmap.shape[-1]
458 |
459 | # account for feature map that has been expanded by the scale in the first dimension
460 | # due to multiscale inputs and outputs
461 |
462 | if mod.shape[0] != b:
463 | mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])
464 |
465 | if exists(kernel_mod):
466 | kernel_mod_has_el = kernel_mod.numel() > 0
467 |
468 | assert self.adaptive or not kernel_mod_has_el
469 |
470 | if kernel_mod_has_el and kernel_mod.shape[0] != b:
471 | kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])
472 |
473 | # prepare weights for modulation
474 |
475 | weights = self.weights
476 |
477 | if self.adaptive:
478 | weights = repeat(weights, '... -> b ...', b = b)
479 |
480 | # determine an adaptive weight and 'select' the kernel to use with softmax
481 |
482 | assert exists(kernel_mod) and kernel_mod.numel() > 0
483 |
484 | kernel_attn = kernel_mod.softmax(dim = -1)
485 | kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1')
486 |
487 | weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum')
488 |
489 | # do the modulation, demodulation, as done in stylegan2
490 |
491 | mod = rearrange(mod, 'b i -> b 1 i 1')
492 |
493 | weights = weights * (mod + 1)
494 |
495 | if self.demod:
496 | inv_norm = reduce(weights ** 2, 'b o i k -> b o 1 1', 'sum').clamp(min = self.eps).rsqrt()
497 | weights = weights * inv_norm
498 |
499 | fmap = rearrange(fmap, 'b c t -> 1 (b c) t')
500 |
501 | weights = rearrange(weights, 'b o ... -> (b o) ...')
502 |
503 | padding = get_same_padding(t, self.kernel, self.dilation, self.stride)
504 | fmap = F.conv1d(fmap, weights, padding = padding, groups = b)
505 |
506 | return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
507 |
508 | # attention
509 | # they use an attention with a better Lipchitz constant - l2 distance similarity instead of dot product - also shared query / key space - shown in vitgan to be more stable
510 | # not sure what they did about token attention to self, so masking out, as done in some other papers using shared query / key space
511 |
512 | class SelfAttention(nn.Module):
513 | def __init__(
514 | self,
515 | dim,
516 | dim_head = 64,
517 | heads = 8,
518 | dot_product = False
519 | ):
520 | super().__init__()
521 | self.heads = heads
522 | self.scale = dim_head ** -0.5
523 | dim_inner = dim_head * heads
524 |
525 | self.dot_product = dot_product
526 |
527 | self.norm = ChannelRMSNorm(dim)
528 |
529 | self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
530 | self.to_k = nn.Conv2d(dim, dim_inner, 1, bias = False) if dot_product else None
531 | self.to_v = nn.Conv2d(dim, dim_inner, 1, bias = False)
532 |
533 | self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))
534 |
535 | self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)
536 |
537 | def forward(self, fmap):
538 | """
539 | einstein notation
540 |
541 | b - batch
542 | h - heads
543 | x - height
544 | y - width
545 | d - dimension
546 | i - source seq (attend from)
547 | j - target seq (attend to)
548 | """
549 | batch = fmap.shape[0]
550 |
551 | fmap = self.norm(fmap)
552 |
553 | x, y = fmap.shape[-2:]
554 |
555 | h = self.heads
556 |
557 | q, v = self.to_q(fmap), self.to_v(fmap)
558 |
559 | k = self.to_k(fmap) if exists(self.to_k) else q
560 |
561 | q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = self.heads), (q, k, v))
562 |
563 | # add a null key / value, so network can choose to pay attention to nothing
564 |
565 | nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)
566 |
567 | k = torch.cat((nk, k), dim = -2)
568 | v = torch.cat((nv, v), dim = -2)
569 |
570 | # l2 distance or dot product
571 |
572 | if self.dot_product:
573 | sim = einsum('b i d, b j d -> b i j', q, k)
574 | else:
575 | # using pytorch cdist leads to nans in lightweight gan training framework, at least
576 | q_squared = (q * q).sum(dim = -1)
577 | k_squared = (k * k).sum(dim = -1)
578 | l2dist_squared = rearrange(q_squared, 'b i -> b i 1') + rearrange(k_squared, 'b j -> b 1 j') - 2 * einsum('b i d, b j d -> b i j', q, k) # hope i'm mathing right
579 | sim = -l2dist_squared
580 |
581 | # scale
582 |
583 | sim = sim * self.scale
584 |
585 | # attention
586 |
587 | attn = sim.softmax(dim = -1)
588 |
589 | out = einsum('b i j, b j d -> b i d', attn, v)
590 |
591 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)
592 |
593 | return self.to_out(out)
594 |
595 | class CrossAttention(nn.Module):
596 | def __init__(
597 | self,
598 | dim,
599 | dim_context,
600 | dim_head = 64,
601 | heads = 8
602 | ):
603 | super().__init__()
604 | self.heads = heads
605 | self.scale = dim_head ** -0.5
606 | dim_inner = dim_head * heads
607 | kv_input_dim = default(dim_context, dim)
608 |
609 | self.norm = ChannelRMSNorm(dim)
610 | self.norm_context = RMSNorm(kv_input_dim)
611 |
612 | self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
613 | self.to_kv = nn.Linear(kv_input_dim, dim_inner * 2, bias = False)
614 | self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)
615 |
616 | def forward(self, fmap, context, mask = None):
617 | """
618 | einstein notation
619 |
620 | b - batch
621 | h - heads
622 | x - height
623 | y - width
624 | d - dimension
625 | i - source seq (attend from)
626 | j - target seq (attend to)
627 | """
628 |
629 | fmap = self.norm(fmap)
630 | context = self.norm_context(context)
631 |
632 | x, y = fmap.shape[-2:]
633 |
634 | h = self.heads
635 |
636 | q, k, v = (self.to_q(fmap), *self.to_kv(context).chunk(2, dim = -1))
637 |
638 | k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (k, v))
639 |
640 | q = rearrange(q, 'b (h d) x y -> (b h) (x y) d', h = self.heads)
641 |
642 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
643 |
644 | if exists(mask):
645 | mask = repeat(mask, 'b j -> (b h) 1 j', h = self.heads)
646 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
647 |
648 | attn = sim.softmax(dim = -1)
649 |
650 | out = einsum('b i j, b j d -> b i d', attn, v)
651 |
652 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)
653 |
654 | return self.to_out(out)
655 |
656 | # classic transformer attention, stick with l2 distance
657 |
658 | class TextAttention(nn.Module):
659 | def __init__(
660 | self,
661 | dim,
662 | dim_head = 64,
663 | heads = 8
664 | ):
665 | super().__init__()
666 | self.heads = heads
667 | self.scale = dim_head ** -0.5
668 | dim_inner = dim_head * heads
669 |
670 | self.norm = RMSNorm(dim)
671 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
672 |
673 | self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))
674 |
675 | self.to_out = nn.Linear(dim_inner, dim, bias = False)
676 |
677 | def forward(self, encodings, mask = None):
678 | """
679 | einstein notation
680 |
681 | b - batch
682 | h - heads
683 | x - height
684 | y - width
685 | d - dimension
686 | i - source seq (attend from)
687 | j - target seq (attend to)
688 | """
689 | batch = encodings.shape[0]
690 |
691 | encodings = self.norm(encodings)
692 |
693 | h = self.heads
694 |
695 | q, k, v = self.to_qkv(encodings).chunk(3, dim = -1)
696 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))
697 |
698 | # add a null key / value, so network can choose to pay attention to nothing
699 |
700 | nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)
701 |
702 | k = torch.cat((nk, k), dim = -2)
703 | v = torch.cat((nv, v), dim = -2)
704 |
705 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
706 |
707 | # key padding mask
708 |
709 | if exists(mask):
710 | mask = F.pad(mask, (1, 0), value = True)
711 | mask = repeat(mask, 'b n -> (b h) 1 n', h = h)
712 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
713 |
714 | # attention
715 |
716 | attn = sim.softmax(dim = -1)
717 | out = einsum('b i j, b j d -> b i d', attn, v)
718 |
719 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
720 |
721 | return self.to_out(out)
722 |
723 | # feedforward
724 |
725 | def FeedForward(
726 | dim,
727 | mult = 4,
728 | channel_first = False
729 | ):
730 | dim_hidden = int(dim * mult)
731 | norm_klass = ChannelRMSNorm if channel_first else RMSNorm
732 | proj = partial(nn.Conv2d, kernel_size = 1) if channel_first else nn.Linear
733 |
734 | return nn.Sequential(
735 | norm_klass(dim),
736 | proj(dim, dim_hidden),
737 | nn.GELU(),
738 | proj(dim_hidden, dim)
739 | )
740 |
741 | # different types of transformer blocks or transformers (multiple blocks)
742 |
743 | class SelfAttentionBlock(nn.Module):
744 | def __init__(
745 | self,
746 | dim,
747 | dim_head = 64,
748 | heads = 8,
749 | ff_mult = 4,
750 | dot_product = False
751 | ):
752 | super().__init__()
753 | self.attn = SelfAttention(dim = dim, dim_head = dim_head, heads = heads, dot_product = dot_product)
754 | self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)
755 |
756 | def forward(self, x):
757 | x = self.attn(x) + x
758 | x = self.ff(x) + x
759 | return x
760 |
761 | class CrossAttentionBlock(nn.Module):
762 | def __init__(
763 | self,
764 | dim,
765 | dim_context,
766 | dim_head = 64,
767 | heads = 8,
768 | ff_mult = 4
769 | ):
770 | super().__init__()
771 | self.attn = CrossAttention(dim = dim, dim_context = dim_context, dim_head = dim_head, heads = heads)
772 | self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)
773 |
774 | def forward(self, x, context, mask = None):
775 | x = self.attn(x, context = context, mask = mask) + x
776 | x = self.ff(x) + x
777 | return x
778 |
779 | class Transformer(nn.Module):
780 | def __init__(
781 | self,
782 | dim,
783 | depth,
784 | dim_head = 64,
785 | heads = 8,
786 | ff_mult = 4
787 | ):
788 | super().__init__()
789 | self.layers = nn.ModuleList([])
790 | for _ in range(depth):
791 | self.layers.append(nn.ModuleList([
792 | TextAttention(dim = dim, dim_head = dim_head, heads = heads),
793 | FeedForward(dim = dim, mult = ff_mult)
794 | ]))
795 |
796 | self.norm = RMSNorm(dim)
797 |
798 | def forward(self, x, mask = None):
799 | for attn, ff in self.layers:
800 | x = attn(x, mask = mask) + x
801 | x = ff(x) + x
802 |
803 | return self.norm(x)
804 |
805 | # text encoder
806 |
807 | class TextEncoder(nn.Module):
808 | @beartype
809 | def __init__(
810 | self,
811 | *,
812 | dim,
813 | depth,
814 | clip: OpenClipAdapter | None = None,
815 | dim_head = 64,
816 | heads = 8,
817 | ):
818 | super().__init__()
819 | self.dim = dim
820 |
821 | if not exists(clip):
822 | clip = OpenClipAdapter()
823 |
824 | self.clip = clip
825 | set_requires_grad_(clip, False)
826 |
827 | self.learned_global_token = nn.Parameter(torch.randn(dim))
828 |
829 | self.project_in = nn.Linear(clip.dim_latent, dim) if clip.dim_latent != dim else nn.Identity()
830 |
831 | self.transformer = Transformer(
832 | dim = dim,
833 | depth = depth,
834 | dim_head = dim_head,
835 | heads = heads
836 | )
837 |
838 | @beartype
839 | def forward(
840 | self,
841 | texts: List[str] | None = None,
842 | text_encodings: Tensor | None = None
843 | ):
844 | assert exists(texts) ^ exists(text_encodings)
845 |
846 | if not exists(text_encodings):
847 | with torch.no_grad():
848 | self.clip.eval()
849 | _, text_encodings = self.clip.embed_texts(texts)
850 |
851 | mask = (text_encodings != 0.).any(dim = -1)
852 |
853 | text_encodings = self.project_in(text_encodings)
854 |
855 | mask_with_global = F.pad(mask, (1, 0), value = True)
856 |
857 | batch = text_encodings.shape[0]
858 | global_tokens = repeat(self.learned_global_token, 'd -> b d', b = batch)
859 |
860 | text_encodings, ps = pack([global_tokens, text_encodings], 'b * d')
861 |
862 | text_encodings = self.transformer(text_encodings, mask = mask_with_global)
863 |
864 | global_tokens, text_encodings = unpack(text_encodings, ps, 'b * d')
865 |
866 | return global_tokens, text_encodings, mask
867 |
868 | # style mapping network
869 |
870 | class EqualLinear(nn.Module):
871 | def __init__(
872 | self,
873 | dim,
874 | dim_out,
875 | lr_mul = 1,
876 | bias = True
877 | ):
878 | super().__init__()
879 | self.weight = nn.Parameter(torch.randn(dim_out, dim))
880 | if bias:
881 | self.bias = nn.Parameter(torch.zeros(dim_out))
882 |
883 | self.lr_mul = lr_mul
884 |
885 | def forward(self, input):
886 | return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
887 |
888 | class StyleNetwork(nn.Module):
889 | def __init__(
890 | self,
891 | dim,
892 | depth,
893 | lr_mul = 0.1,
894 | dim_text_latent = 0
895 | ):
896 | super().__init__()
897 | self.dim = dim
898 | self.dim_text_latent = dim_text_latent
899 |
900 | layers = []
901 | for i in range(depth):
902 | is_first = i == 0
903 | dim_in = (dim + dim_text_latent) if is_first else dim
904 |
905 | layers.extend([EqualLinear(dim_in, dim, lr_mul), leaky_relu()])
906 |
907 | self.net = nn.Sequential(*layers)
908 |
909 | def forward(
910 | self,
911 | x,
912 | text_latent = None
913 | ):
914 | x = F.normalize(x, dim = 1)
915 |
916 | if self.dim_text_latent > 0:
917 | assert exists(text_latent)
918 | x = torch.cat((x, text_latent), dim = -1)
919 |
920 | return self.net(x)
921 |
922 | # noise
923 |
924 | class Noise(nn.Module):
925 | def __init__(self, dim):
926 | super().__init__()
927 | self.weight = nn.Parameter(torch.zeros(dim, 1, 1))
928 |
929 | def forward(
930 | self,
931 | x,
932 | noise = None
933 | ):
934 | b, _, h, w, device = *x.shape, x.device
935 |
936 | if not exists(noise):
937 | noise = torch.randn(b, 1, h, w, device = device)
938 |
939 | return x + self.weight * noise
940 |
941 | # generator
942 |
943 | class BaseGenerator(nn.Module):
944 | pass
945 |
946 | class Generator(BaseGenerator):
947 | @beartype
948 | def __init__(
949 | self,
950 | *,
951 | image_size,
952 | dim_capacity = 16,
953 | dim_max = 2048,
954 | channels = 3,
955 | style_network: StyleNetwork | Dict | None = None,
956 | style_network_dim = None,
957 | text_encoder: TextEncoder | Dict | None = None,
958 | dim_latent = 512,
959 | self_attn_resolutions: Tuple[int, ...] = (32, 16),
960 | self_attn_dim_head = 64,
961 | self_attn_heads = 8,
962 | self_attn_dot_product = True,
963 | self_attn_ff_mult = 4,
964 | cross_attn_resolutions: Tuple[int, ...] = (32, 16),
965 | cross_attn_dim_head = 64,
966 | cross_attn_heads = 8,
967 | cross_attn_ff_mult = 4,
968 | num_conv_kernels = 2, # the number of adaptive conv kernels
969 | num_skip_layers_excite = 0,
970 | unconditional = False,
971 | pixel_shuffle_upsample = False
972 | ):
973 | super().__init__()
974 | self.channels = channels
975 |
976 | if isinstance(style_network, dict):
977 | style_network = StyleNetwork(**style_network)
978 |
979 | self.style_network = style_network
980 |
981 | assert exists(style_network) ^ exists(style_network_dim), 'style_network_dim must be given to the generator if StyleNetwork not passed in as style_network'
982 |
983 | if not exists(style_network_dim):
984 | style_network_dim = style_network.dim
985 |
986 | self.style_network_dim = style_network_dim
987 |
988 | if isinstance(text_encoder, dict):
989 | text_encoder = TextEncoder(**text_encoder)
990 |
991 | self.text_encoder = text_encoder
992 |
993 | self.unconditional = unconditional
994 |
995 | assert not (unconditional and exists(text_encoder))
996 | assert not (unconditional and exists(style_network) and style_network.dim_text_latent > 0)
997 | assert unconditional or (exists(text_encoder) and text_encoder.dim == style_network.dim_text_latent), 'the `dim_text_latent` on your StyleNetwork must be equal to the `dim` set for the TextEncoder'
998 |
999 | assert is_power_of_two(image_size)
1000 | num_layers = int(log2(image_size) - 1)
1001 | self.num_layers = num_layers
1002 |
1003 | # generator requires convolutions conditioned by the style vector
1004 | # and also has N convolutional kernels adaptively selected (one of the only novelties of the paper)
1005 |
1006 | is_adaptive = num_conv_kernels > 1
1007 | dim_kernel_mod = num_conv_kernels if is_adaptive else 0
1008 |
1009 | style_embed_split_dims = []
1010 |
1011 | adaptive_conv = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels)
1012 |
1013 | # initial 4x4 block and conv
1014 |
1015 | self.init_block = nn.Parameter(torch.randn(dim_latent, 4, 4))
1016 | self.init_conv = adaptive_conv(dim_latent, dim_latent)
1017 |
1018 | style_embed_split_dims.extend([
1019 | dim_latent,
1020 | dim_kernel_mod
1021 | ])
1022 |
1023 | # main network
1024 |
1025 | num_layers = int(log2(image_size) - 1)
1026 | self.num_layers = num_layers
1027 |
1028 | resolutions = image_size / ((2 ** torch.arange(num_layers).flip(0)))
1029 | resolutions = resolutions.long().tolist()
1030 |
1031 | dim_layers = (2 ** (torch.arange(num_layers) + 1)) * dim_capacity
1032 | dim_layers.clamp_(max = dim_max)
1033 |
1034 | dim_layers = torch.flip(dim_layers, (0,))
1035 | dim_layers = F.pad(dim_layers, (1, 0), value = dim_latent)
1036 |
1037 | dim_layers = dim_layers.tolist()
1038 |
1039 | dim_pairs = list(zip(dim_layers[:-1], dim_layers[1:]))
1040 |
1041 | self.num_skip_layers_excite = num_skip_layers_excite
1042 |
1043 | self.layers = nn.ModuleList([])
1044 |
1045 | # go through layers and construct all parameters
1046 |
1047 | for ind, ((dim_in, dim_out), resolution) in enumerate(zip(dim_pairs, resolutions)):
1048 | is_last = (ind + 1) == len(dim_pairs)
1049 | is_first = ind == 0
1050 |
1051 | should_upsample = not is_first
1052 | should_upsample_rgb = not is_last
1053 | should_skip_layer_excite = num_skip_layers_excite > 0 and (ind + num_skip_layers_excite) < len(dim_pairs)
1054 |
1055 | has_self_attn = resolution in self_attn_resolutions
1056 | has_cross_attn = resolution in cross_attn_resolutions and not unconditional
1057 |
1058 | skip_squeeze_excite = None
1059 | if should_skip_layer_excite:
1060 | dim_skip_in, _ = dim_pairs[ind + num_skip_layers_excite]
1061 | skip_squeeze_excite = SqueezeExcite(dim_in, dim_skip_in)
1062 |
1063 | resnet_block = nn.ModuleList([
1064 | adaptive_conv(dim_in, dim_out),
1065 | Noise(dim_out),
1066 | leaky_relu(),
1067 | adaptive_conv(dim_out, dim_out),
1068 | Noise(dim_out),
1069 | leaky_relu()
1070 | ])
1071 |
1072 | to_rgb = AdaptiveConv2DMod(dim_out, channels, 1, num_conv_kernels = 1, demod = False)
1073 |
1074 | self_attn = cross_attn = rgb_upsample = upsample = None
1075 |
1076 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
1077 |
1078 | upsample = upsample_klass(dim_in) if should_upsample else None
1079 | rgb_upsample = upsample_klass(channels) if should_upsample_rgb else None
1080 |
1081 | if has_self_attn:
1082 | self_attn = SelfAttentionBlock(
1083 | dim_out,
1084 | dim_head = self_attn_dim_head,
1085 | heads = self_attn_heads,
1086 | ff_mult = self_attn_ff_mult,
1087 | dot_product = self_attn_dot_product
1088 | )
1089 |
1090 | if has_cross_attn:
1091 | cross_attn = CrossAttentionBlock(
1092 | dim_out,
1093 | dim_context = text_encoder.dim,
1094 | dim_head = cross_attn_dim_head,
1095 | heads = cross_attn_heads,
1096 | ff_mult = cross_attn_ff_mult,
1097 | )
1098 |
1099 | style_embed_split_dims.extend([
1100 | dim_in, # for first conv in resnet block
1101 | dim_kernel_mod, # first conv kernel selection
1102 | dim_out, # second conv in resnet block
1103 | dim_kernel_mod, # second conv kernel selection
1104 | dim_out, # to RGB conv
1105 | 0, # RGB conv kernel selection
1106 | ])
1107 |
1108 | self.layers.append(nn.ModuleList([
1109 | skip_squeeze_excite,
1110 | resnet_block,
1111 | to_rgb,
1112 | self_attn,
1113 | cross_attn,
1114 | upsample,
1115 | rgb_upsample
1116 | ]))
1117 |
1118 | # determine the projection of the style embedding to convolutional modulation weights (+ adaptive kernel selection weights) for all layers
1119 |
1120 | self.style_to_conv_modulations = nn.Linear(style_network_dim, sum(style_embed_split_dims))
1121 | self.style_embed_split_dims = style_embed_split_dims
1122 |
1123 | self.apply(self.init_)
1124 | nn.init.normal_(self.init_block, std = 0.02)
1125 |
1126 | def init_(self, m):
1127 | if type(m) in {nn.Conv2d, nn.Linear}:
1128 | nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
1129 |
1130 | @property
1131 | def total_params(self):
1132 | return sum([p.numel() for p in self.parameters() if p.requires_grad])
1133 |
1134 | @property
1135 | def device(self):
1136 | return next(self.parameters()).device
1137 |
1138 | @beartype
1139 | def forward(
1140 | self,
1141 | styles = None,
1142 | noise = None,
1143 | texts: List[str] | None = None,
1144 | text_encodings: Tensor | None = None,
1145 | global_text_tokens = None,
1146 | fine_text_tokens = None,
1147 | text_mask = None,
1148 | batch_size = 1,
1149 | return_all_rgbs = False
1150 | ):
1151 | # take care of text encodings
1152 | # which requires global text tokens to adaptively select the kernels from the main contribution in the paper
1153 | # and fine text tokens to attend to using cross attention
1154 |
1155 | if not self.unconditional:
1156 | if exists(texts) or exists(text_encodings):
1157 | assert exists(texts) ^ exists(text_encodings), 'either raw texts as List[str] or text_encodings (from clip) as Tensor is passed in, but not both'
1158 | assert exists(self.text_encoder)
1159 |
1160 | if exists(texts):
1161 | text_encoder_kwargs = dict(texts = texts)
1162 | elif exists(text_encodings):
1163 | text_encoder_kwargs = dict(text_encodings = text_encodings)
1164 |
1165 | global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(**text_encoder_kwargs)
1166 | else:
1167 | assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))]), 'raw text or text embeddings were not passed in for conditional training'
1168 | else:
1169 | assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])
1170 |
1171 | # determine styles
1172 |
1173 | if not exists(styles):
1174 | assert exists(self.style_network)
1175 |
1176 | if not exists(noise):
1177 | noise = torch.randn((batch_size, self.style_network_dim), device = self.device)
1178 |
1179 | styles = self.style_network(noise, global_text_tokens)
1180 |
1181 | # project styles to conv modulations
1182 |
1183 | conv_mods = self.style_to_conv_modulations(styles)
1184 | conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
1185 | conv_mods = iter(conv_mods)
1186 |
1187 | # prepare initial block
1188 |
1189 | batch_size = styles.shape[0]
1190 |
1191 | x = repeat(self.init_block, 'c h w -> b c h w', b = batch_size)
1192 | x = self.init_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1193 |
1194 | rgb = torch.zeros((batch_size, self.channels, 4, 4), device = self.device, dtype = x.dtype)
1195 |
1196 | # skip layer squeeze excitations
1197 |
1198 | excitations = [None] * self.num_skip_layers_excite
1199 |
1200 | # all the rgb's of each layer of the generator is to be saved for multi-resolution input discrimination
1201 |
1202 | rgbs = []
1203 |
1204 | # main network
1205 |
1206 | for squeeze_excite, (resnet_conv1, noise1, act1, resnet_conv2, noise2, act2), to_rgb_conv, self_attn, cross_attn, upsample, upsample_rgb in self.layers:
1207 |
1208 | if exists(upsample):
1209 | x = upsample(x)
1210 |
1211 | if exists(squeeze_excite):
1212 | skip_excite = squeeze_excite(x)
1213 | excitations.append(skip_excite)
1214 |
1215 | excite = safe_unshift(excitations)
1216 | if exists(excite):
1217 | x = x * excite
1218 |
1219 | x = resnet_conv1(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1220 | x = noise1(x)
1221 | x = act1(x)
1222 |
1223 | x = resnet_conv2(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1224 | x = noise2(x)
1225 | x = act2(x)
1226 |
1227 | if exists(self_attn):
1228 | x = self_attn(x)
1229 |
1230 | if exists(cross_attn):
1231 | x = cross_attn(x, context = fine_text_tokens, mask = text_mask)
1232 |
1233 | layer_rgb = to_rgb_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
1234 |
1235 | rgb = rgb + layer_rgb
1236 |
1237 | rgbs.append(rgb)
1238 |
1239 | if exists(upsample_rgb):
1240 | rgb = upsample_rgb(rgb)
1241 |
1242 | # sanity check
1243 |
1244 | assert is_empty([*conv_mods]), 'convolutions were incorrectly modulated'
1245 |
1246 | if return_all_rgbs:
1247 | return rgb, rgbs
1248 |
1249 | return rgb
1250 |
1251 | # discriminator
1252 |
1253 | @beartype
1254 | class SimpleDecoder(nn.Module):
1255 | def __init__(
1256 | self,
1257 | dim,
1258 | *,
1259 | dims: Tuple[int, ...],
1260 | patch_dim: int = 1,
1261 | frac_patches: float = 1.,
1262 | dropout: float = 0.5
1263 | ):
1264 | super().__init__()
1265 | assert 0 < frac_patches <= 1.
1266 |
1267 | self.patch_dim = patch_dim
1268 | self.frac_patches = frac_patches
1269 |
1270 | self.dropout = nn.Dropout(dropout)
1271 |
1272 | dims = [dim, *dims]
1273 |
1274 | layers = [conv2d_3x3(dim, dim)]
1275 |
1276 | for dim_in, dim_out in zip(dims[:-1], dims[1:]):
1277 | layers.append(nn.Sequential(
1278 | Upsample(dim_in),
1279 | conv2d_3x3(dim_in, dim_out),
1280 | leaky_relu()
1281 | ))
1282 |
1283 | self.net = nn.Sequential(*layers)
1284 |
1285 | @property
1286 | def device(self):
1287 | return next(self.parameters()).device
1288 |
1289 | def forward(
1290 | self,
1291 | fmap,
1292 | orig_image
1293 | ):
1294 | fmap = self.dropout(fmap)
1295 |
1296 | if self.frac_patches < 1.:
1297 | batch, patch_dim = fmap.shape[0], self.patch_dim
1298 | fmap_size, img_size = fmap.shape[-1], orig_image.shape[-1]
1299 |
1300 | assert divisible_by(fmap_size, patch_dim), f'feature map dimensions are {fmap_size}, but the patch dim was designated to be {patch_dim}'
1301 | assert divisible_by(img_size, patch_dim), f'image size is {img_size} but the patch dim was specified to be {patch_dim}'
1302 |
1303 | fmap, orig_image = map(lambda t: rearrange(t, 'b c (p1 h) (p2 w) -> b (p1 p2) c h w', p1 = patch_dim, p2 = patch_dim), (fmap, orig_image))
1304 |
1305 | total_patches = patch_dim ** 2
1306 | num_patches_recon = max(int(self.frac_patches * total_patches), 1)
1307 |
1308 | batch_arange = torch.arange(batch, device = self.device)[..., None]
1309 | batch_randperm = torch.randn((batch, total_patches)).sort(dim = -1).indices
1310 | patch_indices = batch_randperm[..., :num_patches_recon]
1311 |
1312 | fmap, orig_image = map(lambda t: t[batch_arange, patch_indices], (fmap, orig_image))
1313 | fmap, orig_image = map(lambda t: rearrange(t, 'b p ... -> (b p) ...'), (fmap, orig_image))
1314 |
1315 | recon = self.net(fmap)
1316 | return F.mse_loss(recon, orig_image)
1317 |
1318 | class RandomFixedProjection(nn.Module):
1319 | def __init__(
1320 | self,
1321 | dim,
1322 | dim_out,
1323 | channel_first = True
1324 | ):
1325 | super().__init__()
1326 | weights = torch.randn(dim, dim_out)
1327 | nn.init.kaiming_normal_(weights, mode = 'fan_out', nonlinearity = 'linear')
1328 |
1329 | self.channel_first = channel_first
1330 | self.register_buffer('fixed_weights', weights)
1331 |
1332 | def forward(self, x):
1333 | if not self.channel_first:
1334 | return x @ self.fixed_weights
1335 |
1336 | return einsum('b c ..., c d -> b d ...', x, self.fixed_weights)
1337 |
1338 | class VisionAidedDiscriminator(nn.Module):
1339 | """ the vision-aided gan loss """
1340 |
1341 | @beartype
1342 | def __init__(
1343 | self,
1344 | *,
1345 | depth = 2,
1346 | dim_head = 64,
1347 | heads = 8,
1348 | clip: OpenClipAdapter | None = None,
1349 | layer_indices = (-1, -2, -3),
1350 | conv_dim = None,
1351 | text_dim = None,
1352 | unconditional = False,
1353 | num_conv_kernels = 2
1354 | ):
1355 | super().__init__()
1356 |
1357 | if not exists(clip):
1358 | clip = OpenClipAdapter()
1359 |
1360 | self.clip = clip
1361 | dim = clip._dim_image_latent
1362 |
1363 | self.unconditional = unconditional
1364 | text_dim = default(text_dim, dim)
1365 | conv_dim = default(conv_dim, dim)
1366 |
1367 | self.layer_discriminators = nn.ModuleList([])
1368 | self.layer_indices = layer_indices
1369 |
1370 | conv_klass = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels) if not unconditional else conv2d_3x3
1371 |
1372 | for _ in layer_indices:
1373 | self.layer_discriminators.append(nn.ModuleList([
1374 | RandomFixedProjection(dim, conv_dim),
1375 | conv_klass(conv_dim, conv_dim),
1376 | nn.Linear(text_dim, conv_dim) if not unconditional else None,
1377 | nn.Linear(text_dim, num_conv_kernels) if not unconditional else None,
1378 | nn.Sequential(
1379 | conv2d_3x3(conv_dim, 1),
1380 | Rearrange('b 1 ... -> b ...')
1381 | )
1382 | ]))
1383 |
1384 | def parameters(self):
1385 | return self.layer_discriminators.parameters()
1386 |
1387 | @property
1388 | def total_params(self):
1389 | return sum([p.numel() for p in self.parameters()])
1390 |
1391 | @beartype
1392 | def forward(
1393 | self,
1394 | images,
1395 | texts: List[str] | None = None,
1396 | text_embeds: Tensor | None = None,
1397 | return_clip_encodings = False
1398 | ):
1399 |
1400 | assert self.unconditional or (exists(text_embeds) ^ exists(texts))
1401 |
1402 | with torch.no_grad():
1403 | if not self.unconditional and exists(texts):
1404 | self.clip.eval()
1405 | text_embeds = self.clip.embed_texts
1406 |
1407 | _, image_encodings = self.clip.embed_images(images)
1408 |
1409 | logits = []
1410 |
1411 | for layer_index, (rand_proj, conv, to_conv_mod, to_conv_kernel_mod, to_logits) in zip(self.layer_indices, self.layer_discriminators):
1412 | image_encoding = image_encodings[layer_index]
1413 |
1414 | cls_token, rest_tokens = image_encoding[:, :1], image_encoding[:, 1:]
1415 | height_width = int(sqrt(rest_tokens.shape[-2])) # assume square
1416 |
1417 | img_fmap = rearrange(rest_tokens, 'b (h w) d -> b d h w', h = height_width)
1418 |
1419 | img_fmap = img_fmap + rearrange(cls_token, 'b 1 d -> b d 1 1 ') # pool the cls token into the rest of the tokens
1420 |
1421 | img_fmap = rand_proj(img_fmap)
1422 |
1423 | if self.unconditional:
1424 | img_fmap = conv(img_fmap)
1425 | else:
1426 | assert exists(text_embeds)
1427 |
1428 | img_fmap = conv(
1429 | img_fmap,
1430 | mod = to_conv_mod(text_embeds),
1431 | kernel_mod = to_conv_kernel_mod(text_embeds)
1432 | )
1433 |
1434 | layer_logits = to_logits(img_fmap)
1435 |
1436 | logits.append(layer_logits)
1437 |
1438 | if not return_clip_encodings:
1439 | return logits
1440 |
1441 | return logits, image_encodings
1442 |
1443 | class Predictor(nn.Module):
1444 | def __init__(
1445 | self,
1446 | dim,
1447 | depth = 4,
1448 | num_conv_kernels = 2,
1449 | unconditional = False
1450 | ):
1451 | super().__init__()
1452 | self.unconditional = unconditional
1453 | self.residual_fn = nn.Conv2d(dim, dim, 1)
1454 | self.residual_scale = 2 ** -0.5
1455 |
1456 | self.layers = nn.ModuleList([])
1457 |
1458 | klass = nn.Conv2d if unconditional else partial(AdaptiveConv2DMod, num_conv_kernels = num_conv_kernels)
1459 | klass_kwargs = dict(padding = 1) if unconditional else dict()
1460 |
1461 | for ind in range(depth):
1462 | self.layers.append(nn.ModuleList([
1463 | klass(dim, dim, 3, **klass_kwargs),
1464 | leaky_relu(),
1465 | klass(dim, dim, 3, **klass_kwargs),
1466 | leaky_relu()
1467 | ]))
1468 |
1469 | self.to_logits = nn.Conv2d(dim, 1, 1)
1470 |
1471 | def forward(
1472 | self,
1473 | x,
1474 | mod = None,
1475 | kernel_mod = None
1476 | ):
1477 | residual = self.residual_fn(x)
1478 |
1479 | kwargs = dict()
1480 |
1481 | if not self.unconditional:
1482 | kwargs = dict(mod = mod, kernel_mod = kernel_mod)
1483 |
1484 | for conv1, activation, conv2, activation in self.layers:
1485 |
1486 | inner_residual = x
1487 |
1488 | x = conv1(x, **kwargs)
1489 | x = activation(x)
1490 | x = conv2(x, **kwargs)
1491 | x = activation(x)
1492 |
1493 | x = x + inner_residual
1494 | x = x * self.residual_scale
1495 |
1496 | x = x + residual
1497 | return self.to_logits(x)
1498 |
1499 | class Discriminator(nn.Module):
1500 | @beartype
1501 | def __init__(
1502 | self,
1503 | *,
1504 | dim_capacity = 16,
1505 | image_size,
1506 | dim_max = 2048,
1507 | channels = 3,
1508 | attn_resolutions: Tuple[int, ...] = (32, 16),
1509 | attn_dim_head = 64,
1510 | attn_heads = 8,
1511 | self_attn_dot_product = False,
1512 | ff_mult = 4,
1513 | text_encoder: TextEncoder | Dict | None = None,
1514 | text_dim = None,
1515 | filter_input_resolutions: bool = True,
1516 | multiscale_input_resolutions: Tuple[int, ...] = (64, 32, 16, 8),
1517 | multiscale_output_skip_stages: int = 1,
1518 | aux_recon_resolutions: Tuple[int, ...] = (8,),
1519 | aux_recon_patch_dims: Tuple[int, ...] = (2,),
1520 | aux_recon_frac_patches: Tuple[float, ...] = (0.25,),
1521 | aux_recon_fmap_dropout: float = 0.5,
1522 | resize_mode = 'bilinear',
1523 | num_conv_kernels = 2,
1524 | num_skip_layers_excite = 0,
1525 | unconditional = False,
1526 | predictor_depth = 2
1527 | ):
1528 | super().__init__()
1529 | self.unconditional = unconditional
1530 | assert not (unconditional and exists(text_encoder))
1531 |
1532 | assert is_power_of_two(image_size)
1533 | assert all([*map(is_power_of_two, attn_resolutions)])
1534 |
1535 | if filter_input_resolutions:
1536 | multiscale_input_resolutions = [*filter(lambda t: t < image_size, multiscale_input_resolutions)]
1537 |
1538 | assert is_unique(multiscale_input_resolutions)
1539 | assert all([*map(is_power_of_two, multiscale_input_resolutions)])
1540 | assert all([*map(lambda t: t < image_size, multiscale_input_resolutions)])
1541 |
1542 | self.multiscale_input_resolutions = multiscale_input_resolutions
1543 |
1544 | assert multiscale_output_skip_stages > 0
1545 | multiscale_output_resolutions = [resolution // (2 ** multiscale_output_skip_stages) for resolution in multiscale_input_resolutions]
1546 |
1547 | assert all([*map(lambda t: t >= 4, multiscale_output_resolutions)])
1548 |
1549 | assert all([*map(lambda t: t < image_size, multiscale_output_resolutions)])
1550 |
1551 | if len(multiscale_input_resolutions) > 0 and len(multiscale_output_resolutions) > 0:
1552 | assert max(multiscale_input_resolutions) > max(multiscale_output_resolutions)
1553 | assert min(multiscale_input_resolutions) > min(multiscale_output_resolutions)
1554 |
1555 | self.multiscale_output_resolutions = multiscale_output_resolutions
1556 |
1557 | assert all([*map(is_power_of_two, aux_recon_resolutions)])
1558 | assert len(aux_recon_resolutions) == len(aux_recon_patch_dims) == len(aux_recon_frac_patches)
1559 |
1560 | self.aux_recon_resolutions_to_patches = {resolution: (patch_dim, frac_patches) for resolution, patch_dim, frac_patches in zip(aux_recon_resolutions, aux_recon_patch_dims, aux_recon_frac_patches)}
1561 |
1562 | self.resize_mode = resize_mode
1563 |
1564 | num_layers = int(log2(image_size) - 1)
1565 | self.num_layers = num_layers
1566 | self.image_size = image_size
1567 |
1568 | resolutions = image_size / ((2 ** torch.arange(num_layers)))
1569 | resolutions = resolutions.long().tolist()
1570 |
1571 | dim_layers = (2 ** (torch.arange(num_layers) + 1)) * dim_capacity
1572 | dim_layers = F.pad(dim_layers, (1, 0), value = channels)
1573 | dim_layers.clamp_(max = dim_max)
1574 |
1575 | dim_layers = dim_layers.tolist()
1576 | dim_last = dim_layers[-1]
1577 | dim_pairs = list(zip(dim_layers[:-1], dim_layers[1:]))
1578 |
1579 | self.num_skip_layers_excite = num_skip_layers_excite
1580 |
1581 | self.residual_scale = 2 ** -0.5
1582 | self.layers = nn.ModuleList([])
1583 |
1584 | upsample_dims = []
1585 | predictor_dims = []
1586 | dim_kernel_attn = (num_conv_kernels if num_conv_kernels > 1 else 0)
1587 |
1588 | for ind, ((dim_in, dim_out), resolution) in enumerate(zip(dim_pairs, resolutions)):
1589 | is_first = ind == 0
1590 | is_last = (ind + 1) == len(dim_pairs)
1591 | should_downsample = not is_last
1592 | should_skip_layer_excite = not is_first and num_skip_layers_excite > 0 and (ind + num_skip_layers_excite) < len(dim_pairs)
1593 |
1594 | has_attn = resolution in attn_resolutions
1595 | has_multiscale_output = resolution in multiscale_output_resolutions
1596 |
1597 | has_aux_recon_decoder = resolution in aux_recon_resolutions
1598 | upsample_dims.insert(0, dim_in)
1599 |
1600 | skip_squeeze_excite = None
1601 | if should_skip_layer_excite:
1602 | dim_skip_in, _ = dim_pairs[ind + num_skip_layers_excite]
1603 | skip_squeeze_excite = SqueezeExcite(dim_in, dim_skip_in)
1604 |
1605 | # multi-scale rgb input to feature dimension
1606 |
1607 | from_rgb = nn.Conv2d(channels, dim_in, 7, padding = 3)
1608 |
1609 | # residual convolution
1610 |
1611 | residual_conv = nn.Conv2d(dim_in, dim_out, 1, stride = (2 if should_downsample else 1))
1612 |
1613 | # main resnet block
1614 |
1615 | resnet_block = nn.Sequential(
1616 | conv2d_3x3(dim_in, dim_out),
1617 | leaky_relu(),
1618 | conv2d_3x3(dim_out, dim_out),
1619 | leaky_relu()
1620 | )
1621 |
1622 | # multi-scale output
1623 |
1624 | multiscale_output_predictor = None
1625 |
1626 | if has_multiscale_output:
1627 | multiscale_output_predictor = Predictor(dim_out, num_conv_kernels = num_conv_kernels, depth = 2, unconditional = unconditional)
1628 | predictor_dims.extend([dim_out, dim_kernel_attn])
1629 |
1630 | aux_recon_decoder = None
1631 |
1632 | if has_aux_recon_decoder:
1633 | patch_dim, frac_patches = self.aux_recon_resolutions_to_patches[resolution]
1634 |
1635 | aux_recon_decoder = SimpleDecoder(
1636 | dim_out,
1637 | dims = tuple(upsample_dims),
1638 | patch_dim = patch_dim,
1639 | frac_patches = frac_patches,
1640 | dropout = aux_recon_fmap_dropout
1641 | )
1642 |
1643 | self.layers.append(nn.ModuleList([
1644 | skip_squeeze_excite,
1645 | from_rgb,
1646 | resnet_block,
1647 | residual_conv,
1648 | SelfAttentionBlock(dim_out, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult, dot_product = self_attn_dot_product) if has_attn else None,
1649 | multiscale_output_predictor,
1650 | aux_recon_decoder,
1651 | Downsample(dim_out) if should_downsample else None,
1652 | ]))
1653 |
1654 | self.to_logits = nn.Sequential(
1655 | conv2d_3x3(dim_last, dim_last),
1656 | Rearrange('b c h w -> b (c h w)'),
1657 | nn.Linear(dim_last * (4 ** 2), 1),
1658 | Rearrange('b 1 -> b')
1659 | )
1660 |
1661 | # take care of text conditioning in the multiscale predictor branches
1662 |
1663 | assert unconditional or (exists(text_dim) ^ exists(text_encoder))
1664 |
1665 | if not unconditional:
1666 | if isinstance(text_encoder, dict):
1667 | text_encoder = TextEncoder(**text_encoder)
1668 |
1669 | self.text_dim = default(text_dim, text_encoder.dim)
1670 |
1671 | self.predictor_dims = predictor_dims
1672 | self.text_to_conv_conditioning = nn.Linear(self.text_dim, sum(predictor_dims)) if exists(self.text_dim) else None
1673 |
1674 | self.text_encoder = text_encoder
1675 |
1676 | self.apply(self.init_)
1677 |
1678 | def init_(self, m):
1679 | if type(m) in {nn.Conv2d, nn.Linear}:
1680 | nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
1681 |
1682 | def resize_image_to(self, images, resolution):
1683 | return F.interpolate(images, resolution, mode = self.resize_mode)
1684 |
1685 | def real_images_to_rgbs(self, images):
1686 | return [self.resize_image_to(images, resolution) for resolution in self.multiscale_input_resolutions]
1687 |
1688 | @property
1689 | def total_params(self):
1690 | return sum([p.numel() for p in self.parameters()])
1691 |
1692 | @property
1693 | def device(self):
1694 | return next(self.parameters()).device
1695 |
1696 | @beartype
1697 | def forward(
1698 | self,
1699 | images,
1700 | rgbs: List[Tensor], # multi-resolution inputs (rgbs) from the generator
1701 | texts: List[str] | None = None,
1702 | text_encodings: Tensor | None = None,
1703 | text_embeds = None,
1704 | real_images = None, # if this were passed in, the network will automatically append the real to the presumably generated images passed in as the first argument, and generate all intermediate resolutions through resizing and concat appropriately
1705 | return_multiscale_outputs = True, # can force it not to return multi-scale logits
1706 | calc_aux_loss = True
1707 | ):
1708 | if not self.unconditional:
1709 | assert (exists(texts) ^ exists(text_encodings)) ^ exists(text_embeds), 'either texts as List[str] is passed in, or clip text_encodings as Tensor'
1710 |
1711 | if exists(texts):
1712 | assert exists(self.text_encoder)
1713 | text_embeds, *_ = self.text_encoder(texts = texts)
1714 |
1715 | elif exists(text_encodings):
1716 | assert exists(self.text_encoder)
1717 | text_embeds, *_ = self.text_encoder(text_encodings = text_encodings)
1718 |
1719 | assert exists(text_embeds), 'raw text or text embeddings were not passed into discriminator for conditional training'
1720 |
1721 | conv_mods = self.text_to_conv_conditioning(text_embeds).split(self.predictor_dims, dim = -1)
1722 | conv_mods = iter(conv_mods)
1723 |
1724 | else:
1725 | assert not any([*map(exists, (texts, text_embeds))])
1726 |
1727 | x = images
1728 |
1729 | image_size = (self.image_size, self.image_size)
1730 |
1731 | assert x.shape[-2:] == image_size
1732 |
1733 | batch = x.shape[0]
1734 |
1735 | # index the rgbs by resolution
1736 |
1737 | rgbs_index = {t.shape[-1]: t for t in rgbs} if exists(rgbs) else {}
1738 |
1739 | # assert that the necessary resolutions are there
1740 |
1741 | assert is_empty(set(self.multiscale_input_resolutions) - set(rgbs_index.keys())), f'rgbs of necessary resolution {self.multiscale_input_resolutions} were not passed in'
1742 |
1743 | # hold multiscale outputs
1744 |
1745 | multiscale_outputs = []
1746 |
1747 | # hold auxiliary recon losses
1748 |
1749 | aux_recon_losses = []
1750 |
1751 | # excitations
1752 |
1753 | excitations = [None] * (self.num_skip_layers_excite + 1) # +1 since first image in pixel space is not excited
1754 |
1755 | for squeeze_excite, from_rgb, block, residual_fn, attn, predictor, recon_decoder, downsample in self.layers:
1756 | resolution = x.shape[-1]
1757 |
1758 | if exists(squeeze_excite):
1759 | skip_excite = squeeze_excite(x)
1760 | excitations.append(skip_excite)
1761 |
1762 | excite = safe_unshift(excitations)
1763 |
1764 | if exists(excite):
1765 | excite = repeat(excite, 'b ... -> (s b) ...', s = x.shape[0] // excite.shape[0])
1766 | x = x * excite
1767 |
1768 | batch_prev_stage = x.shape[0]
1769 | has_multiscale_input = resolution in self.multiscale_input_resolutions
1770 |
1771 | if has_multiscale_input:
1772 | rgb = rgbs_index.get(resolution, None)
1773 |
1774 | # multi-scale input features
1775 |
1776 | multi_scale_input_feats = from_rgb(rgb)
1777 |
1778 | # expand multi-scale input features, as could include extra scales from previous stage
1779 |
1780 | multi_scale_input_feats = repeat(multi_scale_input_feats, 'b ... -> (s b) ...', s = x.shape[0] // rgb.shape[0])
1781 |
1782 | # add the multi-scale input features to the current hidden state from main stem
1783 |
1784 | x = x + multi_scale_input_feats
1785 |
1786 | # and also concat for scale invariance
1787 |
1788 | x = torch.cat((x, multi_scale_input_feats), dim = 0)
1789 |
1790 | residual = residual_fn(x)
1791 | x = block(x)
1792 |
1793 | if exists(attn):
1794 | x = attn(x)
1795 |
1796 | if exists(predictor):
1797 | pred_kwargs = dict()
1798 | if not self.unconditional:
1799 | pred_kwargs = dict(mod = next(conv_mods), kernel_mod = next(conv_mods))
1800 |
1801 | if return_multiscale_outputs:
1802 | predictor_input = x[:batch_prev_stage]
1803 | multiscale_outputs.append(predictor(predictor_input, **pred_kwargs))
1804 |
1805 | if exists(downsample):
1806 | x = downsample(x)
1807 |
1808 | x = x + residual
1809 | x = x * self.residual_scale
1810 |
1811 | if exists(recon_decoder) and calc_aux_loss:
1812 |
1813 | recon_output = x[:batch_prev_stage]
1814 | recon_output = rearrange(x, '(s b) ... -> s b ...', b = batch)
1815 |
1816 | aux_recon_target = images
1817 |
1818 | # only use the input real images for aux recon
1819 |
1820 | recon_output = recon_output[0]
1821 |
1822 | # only reconstruct a fraction of images across batch and scale
1823 | # for efficiency
1824 |
1825 | aux_recon_loss = recon_decoder(recon_output, aux_recon_target)
1826 | aux_recon_losses.append(aux_recon_loss)
1827 |
1828 | # sanity check
1829 |
1830 | assert self.unconditional or is_empty([*conv_mods]), 'convolutions were incorrectly modulated'
1831 |
1832 | # to logits
1833 |
1834 | logits = self.to_logits(x)
1835 | logits = rearrange(logits, '(s b) ... -> s b ...', b = batch)
1836 |
1837 | return logits, multiscale_outputs, aux_recon_losses
1838 |
1839 | # gan
1840 |
1841 | TrainDiscrLosses = namedtuple('TrainDiscrLosses', [
1842 | 'divergence',
1843 | 'multiscale_divergence',
1844 | 'vision_aided_divergence',
1845 | 'total_matching_aware_loss',
1846 | 'gradient_penalty',
1847 | 'aux_reconstruction'
1848 | ])
1849 |
1850 | TrainGenLosses = namedtuple('TrainGenLosses', [
1851 | 'divergence',
1852 | 'multiscale_divergence',
1853 | 'total_vd_divergence',
1854 | 'contrastive_loss'
1855 | ])
1856 |
1857 | class GigaGAN(nn.Module):
1858 | @beartype
1859 | def __init__(
1860 | self,
1861 | *,
1862 | generator: BaseGenerator | Dict,
1863 | discriminator: Discriminator | Dict,
1864 | vision_aided_discriminator: VisionAidedDiscriminator | Dict | None = None,
1865 | diff_augment: DiffAugment | Dict | None = None,
1866 | learning_rate = 2e-4,
1867 | betas = (0.5, 0.9),
1868 | weight_decay = 0.,
1869 | discr_aux_recon_loss_weight = 1.,
1870 | multiscale_divergence_loss_weight = 0.1,
1871 | vision_aided_divergence_loss_weight = 0.5,
1872 | generator_contrastive_loss_weight = 0.1,
1873 | matching_awareness_loss_weight = 0.1,
1874 | calc_multiscale_loss_every = 1,
1875 | apply_gradient_penalty_every = 4,
1876 | resize_image_mode = 'bilinear',
1877 | train_upsampler = False,
1878 | log_steps_every = 20,
1879 | create_ema_generator_at_init = True,
1880 | save_and_sample_every = 1000,
1881 | early_save_thres_steps = 2500,
1882 | early_save_and_sample_every = 100,
1883 | num_samples = 25,
1884 | model_folder = './gigagan-models',
1885 | results_folder = './gigagan-results',
1886 | sample_upsampler_dl: DataLoader | None = None,
1887 | accelerator: Accelerator | None = None,
1888 | accelerate_kwargs: dict = {},
1889 | find_unused_parameters = True,
1890 | amp = False,
1891 | mixed_precision_type = 'fp16'
1892 | ):
1893 | super().__init__()
1894 |
1895 | # create accelerator
1896 |
1897 | if accelerator:
1898 | self.accelerator = accelerator
1899 | assert is_empty(accelerate_kwargs)
1900 | else:
1901 | kwargs = DistributedDataParallelKwargs(find_unused_parameters = find_unused_parameters)
1902 |
1903 | self.accelerator = Accelerator(
1904 | kwargs_handlers = [kwargs],
1905 | mixed_precision = mixed_precision_type if amp else 'no',
1906 | **accelerate_kwargs
1907 | )
1908 |
1909 | # whether to train upsampler or not
1910 |
1911 | self.train_upsampler = train_upsampler
1912 |
1913 | if train_upsampler:
1914 | from gigagan_pytorch.unet_upsampler import UnetUpsampler
1915 | generator_klass = UnetUpsampler
1916 | else:
1917 | generator_klass = Generator
1918 |
1919 | # gradient penalty and auxiliary recon loss
1920 |
1921 | self.apply_gradient_penalty_every = apply_gradient_penalty_every
1922 | self.calc_multiscale_loss_every = calc_multiscale_loss_every
1923 |
1924 | if isinstance(generator, dict):
1925 | generator = generator_klass(**generator)
1926 |
1927 | if isinstance(discriminator, dict):
1928 | discriminator = Discriminator(**discriminator)
1929 |
1930 | if exists(vision_aided_discriminator) and isinstance(vision_aided_discriminator, dict):
1931 | vision_aided_discriminator = VisionAidedDiscriminator(**vision_aided_discriminator)
1932 |
1933 | assert isinstance(generator, generator_klass)
1934 |
1935 | # diff augment
1936 |
1937 | if isinstance(diff_augment, dict):
1938 | diff_augment = DiffAugment(**diff_augment)
1939 |
1940 | self.diff_augment = diff_augment
1941 |
1942 | # use _base to designate unwrapped models
1943 |
1944 | self.G = generator
1945 | self.D = discriminator
1946 | self.VD = vision_aided_discriminator
1947 |
1948 | # validate multiscale input resolutions
1949 |
1950 | if train_upsampler:
1951 | assert is_empty(set(discriminator.multiscale_input_resolutions) - set(generator.allowable_rgb_resolutions)), f'only multiscale input resolutions of {generator.allowable_rgb_resolutions} is allowed based on the unet input and output image size. simply do Discriminator(multiscale_input_resolutions = unet.allowable_rgb_resolutions) to resolve this error'
1952 |
1953 | # ema
1954 |
1955 | self.has_ema_generator = False
1956 |
1957 | if self.is_main and create_ema_generator_at_init:
1958 | self.create_ema_generator()
1959 |
1960 | # print number of parameters
1961 |
1962 | self.print('\n')
1963 |
1964 | self.print(f'Generator: {numerize.numerize(generator.total_params)}')
1965 | self.print(f'Discriminator: {numerize.numerize(discriminator.total_params)}')
1966 |
1967 | if exists(self.VD):
1968 | self.print(f'Vision Discriminator: {numerize.numerize(vision_aided_discriminator.total_params)}')
1969 |
1970 | self.print('\n')
1971 |
1972 | # text encoder
1973 |
1974 | assert generator.unconditional == discriminator.unconditional
1975 | assert not exists(vision_aided_discriminator) or vision_aided_discriminator.unconditional == generator.unconditional
1976 |
1977 | self.unconditional = generator.unconditional
1978 |
1979 | # optimizers
1980 |
1981 | self.G_opt = get_optimizer(self.G.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
1982 | self.D_opt = get_optimizer(self.D.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
1983 |
1984 | # prepare for distributed
1985 |
1986 | self.G, self.D, self.G_opt, self.D_opt = self.accelerator.prepare(self.G, self.D, self.G_opt, self.D_opt)
1987 |
1988 | # vision aided discriminator optimizer
1989 |
1990 | if exists(self.VD):
1991 | self.VD_opt = get_optimizer(self.VD.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay)
1992 | self.VD_opt = self.accelerator.prepare(self.VD_opt)
1993 |
1994 | # loss related
1995 |
1996 | self.discr_aux_recon_loss_weight = discr_aux_recon_loss_weight
1997 | self.multiscale_divergence_loss_weight = multiscale_divergence_loss_weight
1998 | self.vision_aided_divergence_loss_weight = vision_aided_divergence_loss_weight
1999 | self.generator_contrastive_loss_weight = generator_contrastive_loss_weight
2000 | self.matching_awareness_loss_weight = matching_awareness_loss_weight
2001 |
2002 | # resize image mode
2003 |
2004 | self.resize_image_mode = resize_image_mode
2005 |
2006 | # steps
2007 |
2008 | self.log_steps_every = log_steps_every
2009 |
2010 | self.register_buffer('steps', torch.ones(1, dtype = torch.long))
2011 |
2012 | # save and sample
2013 |
2014 | self.save_and_sample_every = save_and_sample_every
2015 | self.early_save_thres_steps = early_save_thres_steps
2016 | self.early_save_and_sample_every = early_save_and_sample_every
2017 |
2018 | self.num_samples = num_samples
2019 |
2020 | self.train_dl = None
2021 |
2022 | self.sample_upsampler_dl_iter = None
2023 | if exists(sample_upsampler_dl):
2024 | self.sample_upsampler_dl_iter = cycle(self.sample_upsampler_dl)
2025 |
2026 | self.results_folder = Path(results_folder)
2027 | self.model_folder = Path(model_folder)
2028 |
2029 | mkdir_if_not_exists(self.results_folder)
2030 | mkdir_if_not_exists(self.model_folder)
2031 |
2032 | def save(self, path, overwrite = True):
2033 | path = Path(path)
2034 | mkdir_if_not_exists(path.parents[0])
2035 |
2036 | assert overwrite or not path.exists()
2037 |
2038 | pkg = dict(
2039 | G = self.unwrapped_G.state_dict(),
2040 | D = self.unwrapped_D.state_dict(),
2041 | G_opt = self.G_opt.state_dict(),
2042 | D_opt = self.D_opt.state_dict(),
2043 | steps = self.steps.item(),
2044 | version = __version__
2045 | )
2046 |
2047 | if exists(self.G_opt.scaler):
2048 | pkg['G_scaler'] = self.G_opt.scaler.state_dict()
2049 |
2050 | if exists(self.D_opt.scaler):
2051 | pkg['D_scaler'] = self.D_opt.scaler.state_dict()
2052 |
2053 | if exists(self.VD):
2054 | pkg['VD'] = self.unwrapped_VD.state_dict()
2055 | pkg['VD_opt'] = self.VD_opt.state_dict()
2056 |
2057 | if exists(self.VD_opt.scaler):
2058 | pkg['VD_scaler'] = self.VD_opt.scaler.state_dict()
2059 |
2060 | if self.has_ema_generator:
2061 | pkg['G_ema'] = self.G_ema.state_dict()
2062 |
2063 | torch.save(pkg, str(path))
2064 |
2065 | def load(self, path, strict = False):
2066 | path = Path(path)
2067 | assert path.exists()
2068 |
2069 | pkg = torch.load(str(path))
2070 |
2071 | if 'version' in pkg and pkg['version'] != __version__:
2072 | print(f"trying to load from version {pkg['version']}")
2073 |
2074 | self.unwrapped_G.load_state_dict(pkg['G'], strict = strict)
2075 | self.unwrapped_D.load_state_dict(pkg['D'], strict = strict)
2076 |
2077 | if exists(self.VD):
2078 | self.unwrapped_VD.load_state_dict(pkg['VD'], strict = strict)
2079 |
2080 | if self.has_ema_generator:
2081 | self.G_ema.load_state_dict(pkg['G_ema'])
2082 |
2083 | if 'steps' in pkg:
2084 | self.steps.copy_(torch.tensor([pkg['steps']]))
2085 |
2086 | if 'G_opt'not in pkg or 'D_opt' not in pkg:
2087 | return
2088 |
2089 | try:
2090 | self.G_opt.load_state_dict(pkg['G_opt'])
2091 | self.D_opt.load_state_dict(pkg['D_opt'])
2092 |
2093 | if exists(self.VD):
2094 | self.VD_opt.load_state_dict(pkg['VD_opt'])
2095 |
2096 | if 'G_scaler' in pkg and exists(self.G_opt.scaler):
2097 | self.G_opt.scaler.load_state_dict(pkg['G_scaler'])
2098 |
2099 | if 'D_scaler' in pkg and exists(self.D_opt.scaler):
2100 | self.D_opt.scaler.load_state_dict(pkg['D_scaler'])
2101 |
2102 | if 'VD_scaler' in pkg and exists(self.VD_opt.scaler):
2103 | self.VD_opt.scaler.load_state_dict(pkg['VD_scaler'])
2104 |
2105 | except Exception as e:
2106 | self.print(f'unable to load optimizers {e.msg}- optimizer states will be reset')
2107 | pass
2108 |
2109 | # accelerate related
2110 |
2111 | @property
2112 | def device(self):
2113 | return self.accelerator.device
2114 |
2115 | @property
2116 | def unwrapped_G(self):
2117 | return self.accelerator.unwrap_model(self.G)
2118 |
2119 | @property
2120 | def unwrapped_D(self):
2121 | return self.accelerator.unwrap_model(self.D)
2122 |
2123 | @property
2124 | def unwrapped_VD(self):
2125 | return self.accelerator.unwrap_model(self.VD)
2126 |
2127 | @property
2128 | def need_vision_aided_discriminator(self):
2129 | return exists(self.VD) and self.vision_aided_divergence_loss_weight > 0.
2130 |
2131 | @property
2132 | def need_contrastive_loss(self):
2133 | return self.generator_contrastive_loss_weight > 0. and not self.unconditional
2134 |
2135 | def print(self, msg):
2136 | self.accelerator.print(msg)
2137 |
2138 | @property
2139 | def is_distributed(self):
2140 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
2141 |
2142 | @property
2143 | def is_main(self):
2144 | return self.accelerator.is_main_process
2145 |
2146 | @property
2147 | def is_local_main(self):
2148 | return self.accelerator.is_local_main_process
2149 |
2150 | def resize_image_to(self, images, resolution):
2151 | return F.interpolate(images, resolution, mode = self.resize_image_mode)
2152 |
2153 | @beartype
2154 | def set_dataloader(self, dl: DataLoader):
2155 | assert not exists(self.train_dl), 'training dataloader has already been set'
2156 |
2157 | self.train_dl = dl
2158 | self.train_dl_batch_size = dl.batch_size
2159 |
2160 | self.train_dl = self.accelerator.prepare(self.train_dl)
2161 |
2162 | # generate function
2163 |
2164 | @torch.inference_mode()
2165 | def generate(self, *args, **kwargs):
2166 | model = self.G_ema if self.has_ema_generator else self.G
2167 | model.eval()
2168 | return model(*args, **kwargs)
2169 |
2170 | # create EMA generator
2171 |
2172 | def create_ema_generator(
2173 | self,
2174 | update_every = 10,
2175 | update_after_step = 100,
2176 | decay = 0.995
2177 | ):
2178 | if not self.is_main:
2179 | return
2180 |
2181 | assert not self.has_ema_generator, 'EMA generator has already been created'
2182 |
2183 | self.G_ema = EMA(self.unwrapped_G, update_every = update_every, update_after_step = update_after_step, beta = decay)
2184 | self.has_ema_generator = True
2185 |
2186 | def generate_kwargs(self, dl_iter, batch_size):
2187 | # what to pass into the generator
2188 | # depends on whether training upsampler or not
2189 |
2190 | maybe_text_kwargs = dict()
2191 | if self.train_upsampler or not self.unconditional:
2192 | assert exists(dl_iter)
2193 |
2194 | if self.unconditional:
2195 | real_images = next(dl_iter)
2196 | else:
2197 | result = next(dl_iter)
2198 | assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])'
2199 | real_images, texts = result
2200 |
2201 | maybe_text_kwargs['texts'] = texts[:batch_size]
2202 |
2203 | real_images = real_images.to(self.device)
2204 |
2205 | # if training upsample generator, need to downsample real images
2206 |
2207 | if self.train_upsampler:
2208 | size = self.unwrapped_G.input_image_size
2209 | lowres_real_images = F.interpolate(real_images, (size, size))
2210 |
2211 | G_kwargs = dict(lowres_image = lowres_real_images)
2212 | else:
2213 | assert exists(batch_size)
2214 |
2215 | G_kwargs = dict(batch_size = batch_size)
2216 |
2217 | # create noise
2218 |
2219 | noise = torch.randn(batch_size, self.unwrapped_G.style_network.dim, device = self.device)
2220 |
2221 | G_kwargs.update(noise = noise)
2222 |
2223 | return G_kwargs, maybe_text_kwargs
2224 |
2225 | @beartype
2226 | def train_discriminator_step(
2227 | self,
2228 | dl_iter: Iterable,
2229 | grad_accum_every = 1,
2230 | apply_gradient_penalty = False,
2231 | calc_multiscale_loss = True
2232 | ):
2233 | total_divergence = 0.
2234 | total_vision_aided_divergence = 0.
2235 |
2236 | total_gp_loss = 0.
2237 | total_aux_loss = 0.
2238 |
2239 | total_multiscale_divergence = 0. if calc_multiscale_loss else None
2240 |
2241 | has_matching_awareness = not self.unconditional and self.matching_awareness_loss_weight > 0.
2242 |
2243 | total_matching_aware_loss = 0.
2244 |
2245 | all_texts = []
2246 | all_fake_images = []
2247 | all_fake_rgbs = []
2248 | all_real_images = []
2249 |
2250 | self.G.train()
2251 |
2252 | self.D.train()
2253 | self.D_opt.zero_grad()
2254 |
2255 | if self.need_vision_aided_discriminator:
2256 | self.VD.train()
2257 | self.VD_opt.zero_grad()
2258 |
2259 | for _ in range(grad_accum_every):
2260 |
2261 | if self.unconditional:
2262 | real_images = next(dl_iter)
2263 | else:
2264 | result = next(dl_iter)
2265 | assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])'
2266 | real_images, texts = result
2267 |
2268 | all_real_images.append(real_images)
2269 | all_texts.extend(texts)
2270 |
2271 | # requires grad for real images, for gradient penalty
2272 |
2273 | real_images = real_images.to(self.device)
2274 | real_images.requires_grad_()
2275 |
2276 | real_images_rgbs = self.unwrapped_D.real_images_to_rgbs(real_images)
2277 |
2278 | # diff augment real images
2279 |
2280 | if exists(self.diff_augment):
2281 | real_images, real_images_rgbs = self.diff_augment(real_images, real_images_rgbs)
2282 |
2283 | # batch size
2284 |
2285 | batch_size = real_images.shape[0]
2286 |
2287 | # for discriminator training, fit upsampler and image synthesis logic under same function
2288 |
2289 | G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)
2290 |
2291 | # generator
2292 |
2293 | with torch.no_grad(), self.accelerator.autocast():
2294 | images, rgbs = self.G(
2295 | **G_kwargs,
2296 | **maybe_text_kwargs,
2297 | return_all_rgbs = True
2298 | )
2299 |
2300 | all_fake_images.append(images)
2301 | all_fake_rgbs.append(rgbs)
2302 |
2303 | # diff augment
2304 |
2305 | if exists(self.diff_augment):
2306 | images, rgbs = self.diff_augment(images, rgbs)
2307 |
2308 | # detach output of generator, as training discriminator only
2309 |
2310 | images.detach_()
2311 |
2312 | for rgb in rgbs:
2313 | rgb.detach_()
2314 |
2315 | # main divergence loss
2316 |
2317 | with self.accelerator.autocast():
2318 |
2319 | fake_logits, fake_multiscale_logits, _ = self.D(
2320 | images,
2321 | rgbs,
2322 | **maybe_text_kwargs,
2323 | return_multiscale_outputs = calc_multiscale_loss,
2324 | calc_aux_loss = False
2325 | )
2326 |
2327 | real_logits, real_multiscale_logits, aux_recon_losses = self.D(
2328 | real_images,
2329 | real_images_rgbs,
2330 | **maybe_text_kwargs,
2331 | return_multiscale_outputs = calc_multiscale_loss,
2332 | calc_aux_loss = True
2333 | )
2334 |
2335 | divergence = discriminator_hinge_loss(real_logits, fake_logits)
2336 | total_divergence += (divergence.item() / grad_accum_every)
2337 |
2338 | # handle multi-scale divergence
2339 |
2340 | multiscale_divergence = 0.
2341 |
2342 | if self.multiscale_divergence_loss_weight > 0. and len(fake_multiscale_logits) > 0:
2343 |
2344 | for multiscale_fake, multiscale_real in zip(fake_multiscale_logits, real_multiscale_logits):
2345 | multiscale_loss = discriminator_hinge_loss(multiscale_real, multiscale_fake)
2346 | multiscale_divergence = multiscale_divergence + multiscale_loss
2347 |
2348 | total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every)
2349 |
2350 | # figure out gradient penalty if needed
2351 |
2352 | gp_loss = 0.
2353 |
2354 | if apply_gradient_penalty:
2355 | gp_loss = gradient_penalty(
2356 | real_images,
2357 | outputs = [real_logits, *real_multiscale_logits],
2358 | grad_output_weights = [1., *(self.multiscale_divergence_loss_weight,) * len(real_multiscale_logits)],
2359 | scaler = self.D_opt.scaler
2360 | )
2361 |
2362 | if not torch.isnan(gp_loss):
2363 | total_gp_loss += (gp_loss.item() / grad_accum_every)
2364 |
2365 | # handle vision aided discriminator, if needed
2366 |
2367 | vd_loss = 0.
2368 |
2369 | if self.need_vision_aided_discriminator:
2370 |
2371 | fake_vision_aided_logits = self.VD(images, **maybe_text_kwargs)
2372 | real_vision_aided_logits, clip_encodings = self.VD(real_images, return_clip_encodings = True, **maybe_text_kwargs)
2373 |
2374 | for fake_logits, real_logits in zip(fake_vision_aided_logits, real_vision_aided_logits):
2375 | vd_loss = vd_loss + discriminator_hinge_loss(real_logits, fake_logits)
2376 |
2377 | total_vision_aided_divergence += (vd_loss.item() / grad_accum_every)
2378 |
2379 | # handle gradient penalty for vision aided discriminator
2380 |
2381 | if apply_gradient_penalty:
2382 |
2383 | vd_gp_loss = gradient_penalty(
2384 | clip_encodings,
2385 | outputs = real_vision_aided_logits,
2386 | grad_output_weights = [self.vision_aided_divergence_loss_weight] * len(real_vision_aided_logits),
2387 | scaler = self.VD_opt.scaler
2388 | )
2389 |
2390 | if not torch.isnan(vd_gp_loss):
2391 | gp_loss = gp_loss + vd_gp_loss
2392 |
2393 | total_gp_loss += (vd_gp_loss.item() / grad_accum_every)
2394 |
2395 | # sum up losses
2396 |
2397 | total_loss = divergence + gp_loss
2398 |
2399 | if self.multiscale_divergence_loss_weight > 0.:
2400 | total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight
2401 |
2402 | if self.vision_aided_divergence_loss_weight > 0.:
2403 | total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight
2404 |
2405 | if self.discr_aux_recon_loss_weight > 0.:
2406 | aux_loss = sum(aux_recon_losses)
2407 |
2408 | total_aux_loss += (aux_loss.item() / grad_accum_every)
2409 |
2410 | total_loss = total_loss + aux_loss * self.discr_aux_recon_loss_weight
2411 |
2412 | # backwards
2413 |
2414 | self.accelerator.backward(total_loss / grad_accum_every)
2415 |
2416 |
2417 | # matching awareness loss
2418 | # strategy would be to rotate the texts by one and assume batch is shuffled enough for mismatched conditions
2419 |
2420 | if has_matching_awareness:
2421 |
2422 | # rotate texts
2423 |
2424 | all_texts = [*all_texts[1:], all_texts[0]]
2425 | all_texts = group_by_num_consecutive(texts, batch_size)
2426 |
2427 | zipped_data = zip(
2428 | all_fake_images,
2429 | all_fake_rgbs,
2430 | all_real_images,
2431 | all_texts
2432 | )
2433 |
2434 | total_loss = 0.
2435 |
2436 | for fake_images, fake_rgbs, real_images, texts in zipped_data:
2437 |
2438 | with self.accelerator.autocast():
2439 | fake_logits, *_ = self.D(
2440 | fake_images,
2441 | fake_rgbs,
2442 | texts = texts,
2443 | return_multiscale_outputs = False,
2444 | calc_aux_loss = False
2445 | )
2446 |
2447 | real_images_rgbs = self.D.real_images_to_rgbs(real_images)
2448 |
2449 | real_logits, *_ = self.D(
2450 | real_images,
2451 | real_images_rgbs,
2452 | texts = texts,
2453 | return_multiscale_outputs = False,
2454 | calc_aux_loss = False
2455 | )
2456 |
2457 | matching_loss = aux_matching_loss(real_logits, fake_logits)
2458 |
2459 | total_matching_aware_loss = (matching_loss.item() / grad_accum_every)
2460 |
2461 | loss = matching_loss * self.matching_awareness_loss_weight
2462 |
2463 | self.accelerator.backward(loss / grad_accum_every)
2464 |
2465 | self.D_opt.step()
2466 |
2467 | if self.need_vision_aided_discriminator:
2468 | self.VD_opt.step()
2469 |
2470 | return TrainDiscrLosses(
2471 | total_divergence,
2472 | total_multiscale_divergence,
2473 | total_vision_aided_divergence,
2474 | total_matching_aware_loss,
2475 | total_gp_loss,
2476 | total_aux_loss
2477 | )
2478 |
2479 | def train_generator_step(
2480 | self,
2481 | batch_size = None,
2482 | dl_iter: Iterable | None = None,
2483 | grad_accum_every = 1,
2484 | calc_multiscale_loss = True
2485 | ):
2486 | total_divergence = 0.
2487 | total_multiscale_divergence = 0. if calc_multiscale_loss else None
2488 | total_vd_divergence = 0.
2489 | contrastive_loss = 0.
2490 |
2491 | self.G.train()
2492 | self.D.train()
2493 |
2494 | self.D_opt.zero_grad()
2495 | self.G_opt.zero_grad()
2496 |
2497 | all_images = []
2498 | all_texts = []
2499 |
2500 | for _ in range(grad_accum_every):
2501 |
2502 | # generator
2503 |
2504 | G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)
2505 |
2506 | with self.accelerator.autocast():
2507 | images, rgbs = self.G(
2508 | **G_kwargs,
2509 | **maybe_text_kwargs,
2510 | return_all_rgbs = True
2511 | )
2512 |
2513 | # diff augment
2514 |
2515 | if exists(self.diff_augment):
2516 | images, rgbs = self.diff_augment(images, rgbs)
2517 |
2518 | # accumulate all images and texts for maybe contrastive loss
2519 |
2520 | if self.need_contrastive_loss:
2521 | all_images.append(images)
2522 | all_texts.extend(maybe_text_kwargs['texts'])
2523 |
2524 | # discriminator
2525 |
2526 | logits, multiscale_logits, _ = self.D(
2527 | images,
2528 | rgbs,
2529 | **maybe_text_kwargs,
2530 | return_multiscale_outputs = calc_multiscale_loss,
2531 | calc_aux_loss = False
2532 | )
2533 |
2534 | # generator hinge loss discriminator and multiscale
2535 |
2536 | divergence = generator_hinge_loss(logits)
2537 |
2538 | total_divergence += (divergence.item() / grad_accum_every)
2539 |
2540 | total_loss = divergence
2541 |
2542 | if self.multiscale_divergence_loss_weight > 0. and len(multiscale_logits) > 0:
2543 | multiscale_divergence = 0.
2544 |
2545 | for multiscale_logit in multiscale_logits:
2546 | multiscale_divergence = multiscale_divergence + generator_hinge_loss(multiscale_logit)
2547 |
2548 | total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every)
2549 |
2550 | total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight
2551 |
2552 | # vision aided generator hinge loss
2553 |
2554 | if self.need_vision_aided_discriminator:
2555 | vd_loss = 0.
2556 |
2557 | logits = self.VD(images, **maybe_text_kwargs)
2558 |
2559 | for logit in logits:
2560 | vd_loss = vd_loss + generator_hinge_loss(logit)
2561 |
2562 | total_vd_divergence += (vd_loss.item() / grad_accum_every)
2563 |
2564 | total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight
2565 |
2566 | self.accelerator.backward(total_loss / grad_accum_every, retain_graph = self.need_contrastive_loss)
2567 |
2568 | # if needs the generator contrastive loss
2569 | # gather up all images and texts and calculate it
2570 |
2571 | if self.need_contrastive_loss:
2572 | all_images = torch.cat(all_images, dim = 0)
2573 |
2574 | contrastive_loss = aux_clip_loss(
2575 | clip = self.G.text_encoder.clip,
2576 | texts = all_texts,
2577 | images = all_images
2578 | )
2579 |
2580 | self.accelerator.backward(contrastive_loss * self.generator_contrastive_loss_weight)
2581 |
2582 | # generator optimizer step
2583 |
2584 | self.G_opt.step()
2585 |
2586 | # update exponentially moving averaged generator
2587 |
2588 | self.accelerator.wait_for_everyone()
2589 |
2590 | if self.is_main and self.has_ema_generator:
2591 | self.G_ema.update()
2592 |
2593 | return TrainGenLosses(
2594 | total_divergence,
2595 | total_multiscale_divergence,
2596 | total_vd_divergence,
2597 | contrastive_loss
2598 | )
2599 |
2600 | def sample(self, model, dl_iter, batch_size):
2601 | G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)
2602 |
2603 | with self.accelerator.autocast():
2604 | generator_output = model(**G_kwargs, **maybe_text_kwargs)
2605 |
2606 | if not self.train_upsampler:
2607 | return generator_output
2608 |
2609 | output_size = generator_output.shape[-1]
2610 | lowres_image = G_kwargs['lowres_image']
2611 | lowres_image = F.interpolate(lowres_image, (output_size, output_size))
2612 |
2613 | return torch.cat([lowres_image, generator_output])
2614 |
2615 | @torch.inference_mode()
2616 | def save_sample(
2617 | self,
2618 | batch_size,
2619 | dl_iter = None
2620 | ):
2621 | milestone = self.steps.item() // self.save_and_sample_every
2622 | nrow_mult = 2 if self.train_upsampler else 1
2623 | batches = num_to_groups(self.num_samples, batch_size)
2624 |
2625 | if self.train_upsampler:
2626 | dl_iter = default(self.sample_upsampler_dl_iter, dl_iter)
2627 |
2628 | assert exists(dl_iter)
2629 |
2630 | sample_models_and_output_file_name = [(self.unwrapped_G, f'sample-{milestone}.png')]
2631 |
2632 | if self.has_ema_generator:
2633 | sample_models_and_output_file_name.append((self.G_ema, f'ema-sample-{milestone}.png'))
2634 |
2635 | for model, filename in sample_models_and_output_file_name:
2636 | model.eval()
2637 |
2638 | all_images_list = list(map(lambda n: self.sample(model, dl_iter, n), batches))
2639 | all_images = torch.cat(all_images_list, dim = 0)
2640 |
2641 | all_images.clamp_(0., 1.)
2642 |
2643 | utils.save_image(
2644 | all_images,
2645 | str(self.results_folder / filename),
2646 | nrow = int(sqrt(self.num_samples)) * nrow_mult
2647 | )
2648 |
2649 | # Possible to do: Include some metric to save if improved, include some sampler dict text entries
2650 | self.save(str(self.model_folder / f'model-{milestone}.ckpt'))
2651 |
2652 | @beartype
2653 | def forward(
2654 | self,
2655 | *,
2656 | steps,
2657 | grad_accum_every = 1
2658 | ):
2659 | assert exists(self.train_dl), 'you need to set the dataloader by running .set_dataloader(dl: Dataloader)'
2660 |
2661 | batch_size = self.train_dl_batch_size
2662 | dl_iter = cycle(self.train_dl)
2663 |
2664 | last_gp_loss = 0.
2665 | last_multiscale_d_loss = 0.
2666 | last_multiscale_g_loss = 0.
2667 |
2668 | for _ in tqdm(range(steps), initial = self.steps.item()):
2669 | steps = self.steps.item()
2670 | is_first_step = steps == 1
2671 |
2672 | apply_gradient_penalty = self.apply_gradient_penalty_every > 0 and divisible_by(steps, self.apply_gradient_penalty_every)
2673 | calc_multiscale_loss = self.calc_multiscale_loss_every > 0 and divisible_by(steps, self.calc_multiscale_loss_every)
2674 |
2675 | (
2676 | d_loss,
2677 | multiscale_d_loss,
2678 | vision_aided_d_loss,
2679 | matching_aware_loss,
2680 | gp_loss,
2681 | recon_loss
2682 | ) = self.train_discriminator_step(
2683 | dl_iter = dl_iter,
2684 | grad_accum_every = grad_accum_every,
2685 | apply_gradient_penalty = apply_gradient_penalty,
2686 | calc_multiscale_loss = calc_multiscale_loss
2687 | )
2688 |
2689 | self.accelerator.wait_for_everyone()
2690 |
2691 | (
2692 | g_loss,
2693 | multiscale_g_loss,
2694 | vision_aided_g_loss,
2695 | contrastive_loss
2696 | ) = self.train_generator_step(
2697 | dl_iter = dl_iter,
2698 | batch_size = batch_size,
2699 | grad_accum_every = grad_accum_every,
2700 | calc_multiscale_loss = calc_multiscale_loss
2701 | )
2702 |
2703 | if exists(gp_loss):
2704 | last_gp_loss = gp_loss
2705 |
2706 | if exists(multiscale_d_loss):
2707 | last_multiscale_d_loss = multiscale_d_loss
2708 |
2709 | if exists(multiscale_g_loss):
2710 | last_multiscale_g_loss = multiscale_g_loss
2711 |
2712 | if is_first_step or divisible_by(steps, self.log_steps_every):
2713 |
2714 | losses = (
2715 | ('G', g_loss),
2716 | ('MSG', last_multiscale_g_loss),
2717 | ('VG', vision_aided_g_loss),
2718 | ('D', d_loss),
2719 | ('MSD', last_multiscale_d_loss),
2720 | ('VD', vision_aided_d_loss),
2721 | ('GP', last_gp_loss),
2722 | ('SSL', recon_loss),
2723 | ('CL', contrastive_loss),
2724 | ('MAL', matching_aware_loss)
2725 | )
2726 |
2727 | losses_str = ' | '.join([f'{loss_name}: {loss:.2f}' for loss_name, loss in losses])
2728 |
2729 | self.print(losses_str)
2730 |
2731 | self.accelerator.wait_for_everyone()
2732 |
2733 | if self.is_main and (is_first_step or divisible_by(steps, self.save_and_sample_every) or (steps <= self.early_save_thres_steps and divisible_by(steps, self.early_save_and_sample_every))):
2734 | self.save_sample(batch_size, dl_iter)
2735 |
2736 | self.steps += 1
2737 |
2738 | self.print(f'complete {steps} training steps')
2739 |
--------------------------------------------------------------------------------