├── gigagan_pytorch ├── version.py ├── __init__.py ├── optimizer.py ├── distributed.py ├── attend.py ├── data.py ├── open_clip.py ├── unet_upsampler.py └── gigagan_pytorch.py ├── gigagan-sample.png ├── gigagan-architecture.png ├── pyproject.toml ├── .pre-commit-config.yaml ├── LICENSE ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── .gitignore └── README.md /gigagan_pytorch/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.20' 2 | -------------------------------------------------------------------------------- /gigagan-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/gigagan-pytorch/main/gigagan-sample.png -------------------------------------------------------------------------------- /gigagan-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/gigagan-pytorch/main/gigagan-architecture.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 1000 3 | ignore-init-module-imports = true 4 | exclude = ["setup.py"] -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/astral-sh/ruff-pre-commit 4 | rev: v0.0.278 5 | hooks: 6 | - id: ruff 7 | args: [ --fix, --exit-non-zero-on-fix] 8 | -------------------------------------------------------------------------------- /gigagan_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from gigagan_pytorch.gigagan_pytorch import ( 2 | GigaGAN, 3 | Generator, 4 | Discriminator, 5 | VisionAidedDiscriminator, 6 | AdaptiveConv2DMod, 7 | StyleNetwork, 8 | TextEncoder 9 | ) 10 | 11 | from gigagan_pytorch.unet_upsampler import UnetUpsampler 12 | 13 | from gigagan_pytorch.data import ( 14 | ImageDataset, 15 | TextImageDataset, 16 | MockTextImageDataset 17 | ) 18 | 19 | __all__ = [ 20 | GigaGAN, 21 | Generator, 22 | Discriminator, 23 | VisionAidedDiscriminator, 24 | AdaptiveConv2DMod, 25 | StyleNetwork, 26 | UnetUpsampler, 27 | TextEncoder, 28 | ImageDataset, 29 | TextImageDataset, 30 | MockTextImageDataset 31 | ] 32 | -------------------------------------------------------------------------------- /gigagan_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW, Adam 2 | 3 | def separate_weight_decayable_params(params): 4 | wd_params, no_wd_params = [], [] 5 | for param in params: 6 | param_list = no_wd_params if param.ndim < 2 else wd_params 7 | param_list.append(param) 8 | return wd_params, no_wd_params 9 | 10 | def get_optimizer( 11 | params, 12 | lr = 1e-4, 13 | wd = 1e-2, 14 | betas = (0.9, 0.99), 15 | eps = 1e-8, 16 | filter_by_requires_grad = True, 17 | group_wd_params = True, 18 | **kwargs 19 | ): 20 | if filter_by_requires_grad: 21 | params = list(filter(lambda t: t.requires_grad, params)) 22 | 23 | if group_wd_params and wd > 0: 24 | wd_params, no_wd_params = separate_weight_decayable_params(params) 25 | 26 | params = [ 27 | {'params': wd_params}, 28 | {'params': no_wd_params, 'weight_decay': 0}, 29 | ] 30 | 31 | if wd == 0: 32 | 
return Adam(params, lr = lr, betas = betas, eps = eps) 33 | 34 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | exec(open('gigagan_pytorch/version.py').read()) 4 | 5 | setup( 6 | name = 'gigagan-pytorch', 7 | packages = find_packages(exclude=[]), 8 | version = __version__, 9 | license='MIT', 10 | description = 'GigaGAN - Pytorch', 11 | author = 'Phil Wang', 12 | author_email = 'lucidrains@gmail.com', 13 | long_description_content_type = 'text/markdown', 14 | url = 'https://github.com/lucidrains/ETSformer-pytorch', 15 | keywords = [ 16 | 'artificial intelligence', 17 | 'deep learning', 18 | 'generative adversarial networks' 19 | ], 20 | install_requires=[ 21 | 'accelerate', 22 | 'beartype', 23 | 'einops>=0.6', 24 | 'ema-pytorch', 25 | 'kornia', 26 | 'numerize', 27 | 'open-clip-torch>=2.0.0,<3.0.0', 28 | 'pillow', 29 | 'torch>=1.6', 30 | 'torchvision', 31 | 'tqdm' 32 | ], 33 | classifiers=[ 34 | 'Development Status :: 4 - Beta', 35 | 'Intended Audience :: Developers', 36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 37 | 'License :: OSI Approved :: MIT License', 38 | 'Programming Language :: Python :: 3.6', 39 | ], 40 | ) 41 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 
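# note: the publish step below relies on a PYPI_API_TOKEN secret being configured in the repository settings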
10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /gigagan_pytorch/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Function 4 | import torch.distributed as dist 5 | 6 | from einops import rearrange 7 | 8 | # helpers 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def pad_dim_to(t, length, dim = 0): 14 | pad_length = length - t.shape[dim] 15 | zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) 16 | return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length)) 17 | 18 | # distributed helpers 19 | 20 | def all_gather_variable_dim(t, dim = 0, sizes = None): 21 | device, world_size = t.device, dist.get_world_size() 22 | 23 | if not exists(sizes): 24 | size = torch.tensor(t.shape[dim], device = device, dtype = torch.long) 25 | sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)] 26 | dist.all_gather(sizes, size) 27 | sizes = torch.stack(sizes) 28 | 29 | max_size = sizes.amax().item() 30 | padded_t = pad_dim_to(t, max_size, dim = dim) 31 | 32 | gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)] 33 | dist.all_gather(gathered_tensors, padded_t) 34 | 35 | gathered_tensor = torch.cat(gathered_tensors, dim = dim) 36 | seq = torch.arange(max_size, device = device) 37 | 38 | mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1') 39 | mask = rearrange(mask, 'i j -> (i j)') 40 | seq = torch.arange(mask.shape[-1], device = device) 41 | indices = seq[mask] 42 | 43 | gathered_tensor = gathered_tensor.index_select(dim, indices) 44 | 45 | return gathered_tensor, sizes 46 | 47 | class AllGather(Function): 48 | @staticmethod 49 | def forward(ctx, x, dim, sizes): 50 | is_dist = dist.is_initialized() and dist.get_world_size() > 1 51 | ctx.is_dist = is_dist 52 | 53 | if not is_dist: 54 | return x, None 55 | 56 | x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes) 57 | ctx.batch_sizes = batch_sizes.tolist() 58 | ctx.dim = dim 59 | return x, batch_sizes 60 | 61 | @staticmethod 62 | def backward(ctx, grads, _): 63 | if not ctx.is_dist: 64 | return grads, None, None 65 | 66 | batch_sizes, rank = ctx.batch_sizes, dist.get_rank() 67 | grads_by_rank = grads.split(batch_sizes, dim = ctx.dim) 68 | return grads_by_rank[rank], None, None 69 | 70 | all_gather = AllGather.apply 71 | -------------------------------------------------------------------------------- /gigagan_pytorch/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | 10 | # constants 11 | 12 | AttentionConfig = 
namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 13 | 14 | # helpers 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def once(fn): 20 | called = False 21 | @wraps(fn) 22 | def inner(x): 23 | nonlocal called 24 | if called: 25 | return 26 | called = True 27 | return fn(x) 28 | return inner 29 | 30 | print_once = once(print) 31 | 32 | # main class 33 | 34 | class Attend(nn.Module): 35 | def __init__( 36 | self, 37 | dropout = 0., 38 | flash = False 39 | ): 40 | super().__init__() 41 | self.dropout = dropout 42 | self.attn_dropout = nn.Dropout(dropout) 43 | 44 | self.flash = flash 45 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 46 | 47 | # determine efficient attention configs for cuda and cpu 48 | 49 | self.cpu_config = AttentionConfig(True, True, True) 50 | self.cuda_config = None 51 | 52 | if not torch.cuda.is_available() or not flash: 53 | return 54 | 55 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 56 | 57 | if device_properties.major == 8 and device_properties.minor == 0: 58 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 59 | self.cuda_config = AttentionConfig(True, False, False) 60 | else: 61 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 62 | self.cuda_config = AttentionConfig(False, True, True) 63 | 64 | def flash_attn(self, q, k, v): 65 | is_cuda = q.is_cuda 66 | 67 | q, k, v = map(lambda t: t.contiguous(), (q, k, v)) 68 | 69 | # Check if there is a compatible device for flash attention 70 | 71 | config = self.cuda_config if is_cuda else self.cpu_config 72 | 73 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 74 | 75 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 76 | out = F.scaled_dot_product_attention( 77 | q, k, v, 78 | dropout_p = self.dropout if self.training else 0. 
79 | ) 80 | 81 | return out 82 | 83 | def forward(self, q, k, v): 84 | """ 85 | einstein notation 86 | b - batch 87 | h - heads 88 | n, i, j - sequence length (base sequence length, source, target) 89 | d - feature dimension 90 | """ 91 | 92 | if self.flash: 93 | return self.flash_attn(q, k, v) 94 | 95 | scale = q.shape[-1] ** -0.5 96 | 97 | # similarity 98 | 99 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale 100 | 101 | # attention 102 | 103 | attn = sim.softmax(dim = -1) 104 | attn = self.attn_dropout(attn) 105 | 106 | # aggregate values 107 | 108 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 109 | 110 | return out 111 | -------------------------------------------------------------------------------- /gigagan_pytorch/data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | from PIL import Image 9 | from torchvision import transforms as T 10 | 11 | from beartype.door import is_bearable 12 | from beartype.typing import Tuple 13 | 14 | # helper functions 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def convert_image_to_fn(img_type, image): 20 | if image.mode == img_type: 21 | return image 22 | 23 | return image.convert(img_type) 24 | 25 | # custom collation function 26 | # so dataset can return a str and it will collate into List[str] 27 | 28 | def collate_tensors_or_str(data): 29 | is_one_data = not isinstance(data[0], tuple) 30 | 31 | if is_one_data: 32 | data = torch.stack(data) 33 | return (data,) 34 | 35 | outputs = [] 36 | for datum in zip(*data): 37 | if is_bearable(datum, Tuple[str, ...]): 38 | output = list(datum) 39 | else: 40 | output = torch.stack(datum) 41 | 42 | outputs.append(output) 43 | 44 | return tuple(outputs) 45 | 46 | # dataset classes 47 | 48 | class ImageDataset(Dataset): 49 | def __init__( 50 | self, 51 | folder, 52 | image_size, 53 | exts = ['jpg', 'jpeg', 'png', 'tiff'], 54 | augment_horizontal_flip = False, 55 | convert_image_to = None 56 | ): 57 | super().__init__() 58 | self.folder = folder 59 | self.image_size = image_size 60 | 61 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 62 | 63 | assert len(self.paths) > 0, 'your folder contains no images' 64 | assert len(self.paths) > 100, 'you need at least 100 images, 10k for research paper, millions for miraculous results (try Laion-5B)' 65 | 66 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity() 67 | 68 | self.transform = T.Compose([ 69 | T.Lambda(maybe_convert_fn), 70 | T.Resize(image_size), 71 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), 72 | T.CenterCrop(image_size), 73 | T.ToTensor() 74 | ]) 75 | 76 | def get_dataloader(self, *args, **kwargs): 77 | return DataLoader(self, *args, shuffle = True, drop_last = True, **kwargs) 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | 82 | def __getitem__(self, index): 83 | path = self.paths[index] 84 | img = Image.open(path) 85 | return self.transform(img) 86 | 87 | class TextImageDataset(Dataset): 88 | def __init__(self): 89 | raise NotImplementedError 90 | 91 | def get_dataloader(self, *args, **kwargs): 92 | return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs) 93 | 94 | class MockTextImageDataset(TextImageDataset): 95 | def __init__( 96 | self, 97 | image_size, 98 | 
length = int(1e5), 99 | channels = 3 100 | ): 101 | self.image_size = image_size 102 | self.channels = channels 103 | self.length = length 104 | 105 | def get_dataloader(self, *args, **kwargs): 106 | return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs) 107 | 108 | def __len__(self): 109 | return self.length 110 | 111 | def __getitem__(self, index): 112 | mock_image = torch.randn(self.channels, self.image_size, self.image_size) 113 | return mock_image, 'mock text' 114 | -------------------------------------------------------------------------------- /gigagan_pytorch/open_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | import open_clip 5 | 6 | from einops import rearrange 7 | 8 | from beartype import beartype 9 | from beartype.typing import List, Optional 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def l2norm(t): 15 | return F.normalize(t, dim = -1) 16 | 17 | class OpenClipAdapter(nn.Module): 18 | @beartype 19 | def __init__( 20 | self, 21 | name = 'ViT-B/32', 22 | pretrained = 'laion400m_e32', 23 | tokenizer_name = 'ViT-B-32-quickgelu', 24 | eos_id = 49407 25 | ): 26 | super().__init__() 27 | 28 | clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained) 29 | tokenizer = open_clip.get_tokenizer(tokenizer_name) 30 | 31 | self.clip = clip 32 | self.tokenizer = tokenizer 33 | self.eos_id = eos_id 34 | 35 | # hook for getting final text representation 36 | 37 | text_attention_final = self.find_layer('ln_final') 38 | self._dim_latent = text_attention_final.weight.shape[0] 39 | self.text_handle = text_attention_final.register_forward_hook(self._text_hook) 40 | 41 | # hook for getting final image representation 42 | # this is for vision-aided gan loss 43 | 44 | self._dim_image_latent = self.find_layer('visual.ln_post').weight.shape[0] 45 | 46 | num_visual_layers = len(clip.visual.transformer.resblocks) 47 | self.image_handles = [] 48 | 49 | for visual_layer in range(num_visual_layers): 50 | image_attention_final = self.find_layer(f'visual.transformer.resblocks.{visual_layer}') 51 | 52 | handle = image_attention_final.register_forward_hook(self._image_hook) 53 | self.image_handles.append(handle) 54 | 55 | # normalize fn 56 | 57 | self.clip_normalize = preprocess.transforms[-1] 58 | self.cleared = False 59 | 60 | @property 61 | def device(self): 62 | return next(self.parameters()).device 63 | 64 | def find_layer(self, layer): 65 | modules = dict([*self.clip.named_modules()]) 66 | return modules.get(layer, None) 67 | 68 | def clear(self): 69 | if self.cleared: 70 | return 71 | 72 | self.text_handle() 73 | self.image_handle() 74 | 75 | def _text_hook(self, _, inputs, outputs): 76 | self.text_encodings = outputs 77 | 78 | def _image_hook(self, _, inputs, outputs): 79 | if not hasattr(self, 'image_encodings'): 80 | self.image_encodings = [] 81 | 82 | self.image_encodings.append(outputs) 83 | 84 | @property 85 | def dim_latent(self): 86 | return self._dim_latent 87 | 88 | @property 89 | def image_size(self): 90 | image_size = self.clip.visual.image_size 91 | if isinstance(image_size, tuple): 92 | return max(image_size) 93 | return image_size 94 | 95 | @property 96 | def image_channels(self): 97 | return 3 98 | 99 | @property 100 | def max_text_len(self): 101 | return self.clip.positional_embedding.shape[0] 102 | 103 | @beartype 104 | def embed_texts( 105 | self, 106 | texts: List[str] 107 | ): 108 
| ids = self.tokenizer(texts) 109 | ids = ids.to(self.device) 110 | ids = ids[..., :self.max_text_len] 111 | 112 | is_eos_id = (ids == self.eos_id) 113 | text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0 114 | text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True) 115 | text_mask = text_mask & (ids != 0) 116 | assert not self.cleared 117 | 118 | text_embed = self.clip.encode_text(ids) 119 | text_encodings = self.text_encodings 120 | text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.) 121 | del self.text_encodings 122 | return l2norm(text_embed.float()), text_encodings.float() 123 | 124 | def embed_images(self, images): 125 | if images.shape[-1] != self.image_size: 126 | images = F.interpolate(images, self.image_size) 127 | 128 | assert not self.cleared 129 | images = self.clip_normalize(images) 130 | image_embeds = self.clip.encode_image(images) 131 | 132 | image_encodings = rearrange(self.image_encodings, 'l n b d -> l b n d') 133 | del self.image_encodings 134 | 135 | return l2norm(image_embeds.float()), image_encodings.float() 136 | 137 | @beartype 138 | def contrastive_loss( 139 | self, 140 | images, 141 | texts: Optional[List[str]] = None, 142 | text_embeds: Optional[torch.Tensor] = None 143 | ): 144 | assert exists(texts) ^ exists(text_embeds) 145 | 146 | if not exists(text_embeds): 147 | text_embeds, _ = self.embed_texts(texts) 148 | 149 | image_embeds, _ = self.embed_images(images) 150 | 151 | n = text_embeds.shape[0] 152 | 153 | temperature = self.clip.logit_scale.exp() 154 | sim = einsum('i d, j d -> i j', text_embeds, image_embeds) * temperature 155 | 156 | labels = torch.arange(n, device = sim.device) 157 | 158 | return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ## GigaGAN - Pytorch 6 | 7 | Implementation of GigaGAN (project page), new SOTA GAN out of Adobe. 8 | 9 | I will also add a few findings from lightweight gan, for faster convergence (skip layer excitation) and better stability (reconstruction auxiliary loss in discriminator) 10 | 11 | It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper. 12 | 13 | Please join us on Discord if you are interested in helping out with the replication alongside the LAION community 14 | 15 | ## Appreciation 16 | 17 | - StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence. 18 | 19 | - 🤗 Huggingface for their accelerate library 20 | 21 | - All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models 22 | 23 | - Xavier for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built! 24 | 25 | - @CerebralSeed for pull requesting the initial sampling code for both the generator and upsampler! 26 | 27 | - Keerth for the code review and pointing out some discrepancies with the paper!
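## OpenCLIP adapter (usage sketch)

A minimal sketch of how the `OpenClipAdapter` defined in `gigagan_pytorch/open_clip.py` above can be exercised on its own. This is not one of the official examples: it assumes the default pretrained `open-clip-torch` weights can be downloaded, and the captions and random tensor standing in for images are made up for illustration.

```python
import torch
from gigagan_pytorch.open_clip import OpenClipAdapter

clip = OpenClipAdapter()  # defaults to ViT-B/32 pretrained on laion400m_e32, per open_clip.py above

texts = ['a photo of a dog', 'a photo of a cat']
images = torch.rand(2, 3, clip.image_size, clip.image_size)  # stand-in for real images in [0, 1]

# global text embedding plus per-token encodings (zeroed out past the EOS token)
text_embeds, text_encodings = clip.embed_texts(texts)

# image embedding plus the per-layer encodings used for the vision-aided discriminator
image_embeds, image_encodings = clip.embed_images(images)

# symmetric CLIP-style contrastive loss over the matched image / text batch
loss = clip.contrastive_loss(images = images, texts = texts)
```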
28 | 29 | ## Install 30 | 31 | ```bash 32 | $ pip install gigagan-pytorch 33 | ``` 34 | 35 | ## Usage 36 | 37 | Simple unconditional GAN, for starters 38 | 39 | ```python 40 | import torch 41 | 42 | from gigagan_pytorch import ( 43 | GigaGAN, 44 | ImageDataset 45 | ) 46 | 47 | gan = GigaGAN( 48 | generator = dict( 49 | dim_capacity = 8, 50 | style_network = dict( 51 | dim = 64, 52 | depth = 4 53 | ), 54 | image_size = 256, 55 | dim_max = 512, 56 | num_skip_layers_excite = 4, 57 | unconditional = True 58 | ), 59 | discriminator = dict( 60 | dim_capacity = 16, 61 | dim_max = 512, 62 | image_size = 256, 63 | num_skip_layers_excite = 4, 64 | unconditional = True 65 | ), 66 | amp = True 67 | ).cuda() 68 | 69 | # dataset 70 | 71 | dataset = ImageDataset( 72 | folder = '/path/to/your/data', 73 | image_size = 256 74 | ) 75 | 76 | dataloader = dataset.get_dataloader(batch_size = 1) 77 | 78 | # you must then set the dataloader for the GAN before training 79 | 80 | gan.set_dataloader(dataloader) 81 | 82 | # training the discriminator and generator alternating 83 | # for 100 steps in this example, batch size 1, gradient accumulated 8 times 84 | 85 | gan( 86 | steps = 100, 87 | grad_accum_every = 8 88 | ) 89 | 90 | # after much training 91 | 92 | images = gan.generate(batch_size = 4) # (4, 3, 256, 256) 93 | ``` 94 | 95 | For unconditional Unet Upsampler 96 | 97 | ```python 98 | import torch 99 | from gigagan_pytorch import ( 100 | GigaGAN, 101 | ImageDataset 102 | ) 103 | 104 | gan = GigaGAN( 105 | train_upsampler = True, # set this to True 106 | generator = dict( 107 | style_network = dict( 108 | dim = 64, 109 | depth = 4 110 | ), 111 | dim = 32, 112 | image_size = 256, 113 | input_image_size = 64, 114 | unconditional = True 115 | ), 116 | discriminator = dict( 117 | dim_capacity = 16, 118 | dim_max = 512, 119 | image_size = 256, 120 | num_skip_layers_excite = 4, 121 | multiscale_input_resolutions = (128,), 122 | unconditional = True 123 | ), 124 | amp = True 125 | ).cuda() 126 | 127 | dataset = ImageDataset( 128 | folder = '/path/to/your/data', 129 | image_size = 256 130 | ) 131 | 132 | dataloader = dataset.get_dataloader(batch_size = 1) 133 | 134 | gan.set_dataloader(dataloader) 135 | 136 | # training the discriminator and generator alternating 137 | # for 100 steps in this example, batch size 1, gradient accumulated 8 times 138 | 139 | gan( 140 | steps = 100, 141 | grad_accum_every = 8 142 | ) 143 | 144 | # after much training 145 | 146 | lowres = torch.randn(1, 3, 64, 64).cuda() 147 | 148 | images = gan.generate(lowres) # (1, 3, 256, 256) 149 | ``` 150 | 151 | ## Losses 152 | 153 | * `G` - Generator 154 | * `MSG` - Multiscale Generator 155 | * `D` - Discriminator 156 | * `MSD` - Multiscale Discriminator 157 | * `GP` - Gradient Penalty 158 | * `SSL` - Auxiliary Reconstruction in Discriminator (from Lightweight GAN) 159 | * `VD` - Vision-aided Discriminator 160 | * `VG` - Vision-aided Generator 161 | * `CL` - Generator Constrastive Loss 162 | * `MAL` - Matching Aware Loss 163 | 164 | A healthy run would have `G`, `MSG`, `D`, `MSD` with values hovering between `0` to `10`, and usually staying pretty constant. If at any time after 1k training steps these values persist at triple digits, that would mean something is wrong. It is ok for generator and discriminator values to occasionally dip negative, but it should swing back up to the range above. 165 | 166 | `GP` and `SSL` should be pushed towards `0`. 
`GP` can occasionally spike; I like to imagine it as the networks undergoing some epiphany 167 | 168 | ## Multi-GPU Training 169 | 170 | The `GigaGAN` class is now equipped with 🤗 Accelerator. You can easily do multi-gpu training in two steps using their `accelerate` CLI 171 | 172 | At the project root directory, where the training script is, run 173 | 174 | ```python 175 | $ accelerate config 176 | ``` 177 | 178 | Then, in the same directory 179 | 180 | ```python 181 | $ accelerate launch train.py 182 | ``` 183 | 184 | ## Todo 185 | 186 | - [x] make sure it can be trained unconditionally 187 | - [x] read the relevant papers and knock out all 3 auxiliary losses 188 | - [x] matching aware loss 189 | - [x] clip loss 190 | - [x] vision-aided discriminator loss 191 | - [x] add reconstruction losses on arbitrary stages in the discriminator (lightweight gan) 192 | - [x] figure out how the random projections are used from projected-gan 193 | - [x] vision aided discriminator needs to extract N layers from the vision model in CLIP 194 | - [x] figure out whether to discard CLS token and reshape into image dimensions for convolution, or stick with attention and condition with adaptive layernorm - also turn off vision aided gan in unconditional case 195 | - [x] unet upsampler 196 | - [x] add adaptive conv 197 | - [x] modify latter stage of unet to also output rgb residuals, and pass the rgb into discriminator. make discriminator agnostic to rgb being passed in 198 | - [x] do pixel shuffle upsamples for unet 199 | - [x] get a code review for the multi-scale inputs and outputs, as the paper was a bit vague 200 | - [x] add upsampling network architecture 201 | - [x] make unconditional work for both base generator and upsampler 202 | - [x] make text conditioned training work for both base and upsampler 203 | - [x] make recon more efficient by random sampling patches 204 | - [x] make sure generator and discriminator can also accept pre-encoded CLIP text encodings 205 | - [x] do a review of the auxiliary losses 206 | - [x] add contrastive loss for generator 207 | - [x] add vision aided loss 208 | - [x] add gradient penalty for vision aided discr - make optional 209 | - [x] add matching awareness loss - figure out if rotating text conditions by one is good enough for mismatching (without drawing an additional batch from dataloader) 210 | - [x] make sure gradient accumulation works with matching aware loss 211 | - [x] matching awareness loss runs and is stable 212 | - [x] vision aided trains 213 | - [x] add some differentiable augmentations, proven technique from the old GAN days 214 | - [x] remove any magic being done with automatic rgbs processing, and have it explicitly passed in - offer functions on the discriminator that can process real images into the right multi-scales 215 | - [x] add horizontal flip for starters 216 | 217 | - [ ] move all modulation projections into the adaptive conv2d class 218 | - [ ] add accelerate 219 | - [x] works single machine 220 | - [x] works for mixed precision (make sure gradient penalty is scaled correctly), take care of manual scaler saving and reloading, borrow from imagen-pytorch 221 | - [x] make sure it works multi-GPU for one machine 222 | - [ ] have someone else try multiple machines 223 | 224 | - [ ] clip should be optional for all modules, and managed by `GigaGAN`, with text -> text embeds processed once 225 | - [ ] add ability to select a random subset from multiscale dimension, for efficiency 226 | 227 | - [ ] port over CLI from lightweight|stylegan2-pytorch 228 
| - [ ] hook up laion dataset for text-image 229 | 230 | ## Citations 231 | 232 | ```bibtex 233 | @misc{https://doi.org/10.48550/arxiv.2303.05511, 234 | url = {https://arxiv.org/abs/2303.05511}, 235 | author = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung}, 236 | title = {Scaling up GANs for Text-to-Image Synthesis}, 237 | publisher = {arXiv}, 238 | year = {2023}, 239 | copyright = {arXiv.org perpetual, non-exclusive license} 240 | } 241 | ``` 242 | 243 | ```bibtex 244 | @article{Liu2021TowardsFA, 245 | title = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis}, 246 | author = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal}, 247 | journal = {ArXiv}, 248 | year = {2021}, 249 | volume = {abs/2101.04775} 250 | } 251 | ``` 252 | 253 | ```bibtex 254 | @inproceedings{dao2022flashattention, 255 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, 256 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 257 | booktitle = {Advances in Neural Information Processing Systems}, 258 | year = {2022} 259 | } 260 | ``` 261 | 262 | ```bibtex 263 | @inproceedings{Karras2020ada, 264 | title = {Training Generative Adversarial Networks with Limited Data}, 265 | author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila}, 266 | booktitle = {Proc. NeurIPS}, 267 | year = {2020} 268 | } 269 | ``` 270 | 271 | ```bibtex 272 | @article{Xu2024VideoGigaGANTD, 273 | title = {VideoGigaGAN: Towards Detail-rich Video Super-Resolution}, 274 | author = {Yiran Xu and Taesung Park and Richard Zhang and Yang Zhou and Eli Shechtman and Feng Liu and Jia-Bin Huang and Difan Liu}, 275 | journal = {ArXiv}, 276 | year = {2024}, 277 | volume = {abs/2404.12388}, 278 | url ={https://api.semanticscholar.org/CorpusID:269214195} 279 | } 280 | ``` 281 | -------------------------------------------------------------------------------- /gigagan_pytorch/unet_upsampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from math import log2 4 | from functools import partial 5 | from itertools import islice 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torch.nn import Module, ModuleList 11 | 12 | from einops import rearrange, repeat, pack, unpack 13 | from einops.layers.torch import Rearrange 14 | 15 | from gigagan_pytorch.attend import Attend 16 | from gigagan_pytorch.gigagan_pytorch import ( 17 | BaseGenerator, 18 | StyleNetwork, 19 | AdaptiveConv2DMod, 20 | AdaptiveConv1DMod, 21 | TextEncoder, 22 | CrossAttentionBlock, 23 | Upsample, 24 | PixelShuffleUpsample, 25 | Blur 26 | ) 27 | 28 | from kornia.filters import filter3d, filter2d 29 | 30 | from beartype import beartype 31 | from beartype.typing import List, Dict, Iterable, Literal 32 | 33 | # helpers functions 34 | 35 | def exists(x): 36 | return x is not None 37 | 38 | def default(val, d): 39 | if exists(val): 40 | return val 41 | return d() if callable(d) else d 42 | 43 | def pack_one(t, pattern): 44 | return pack([t], pattern) 45 | 46 | def unpack_one(t, ps, pattern): 47 | return unpack(t, ps, pattern)[0] 48 | 49 | def cast_tuple(t, length = 1): 50 | if isinstance(t, tuple): 51 | return t 52 | return ((t,) * length) 53 | 54 | def identity(t, *args, **kwargs): 55 | return t 56 | 57 | def 
is_power_of_two(n): 58 | return log2(n).is_integer() 59 | 60 | def null_iterator(): 61 | while True: 62 | yield None 63 | 64 | def fold_space_into_batch(x): 65 | x = rearrange(x, 'b c t h w -> b h w c t') 66 | x, ps = pack_one(x, '* c t') 67 | 68 | def split_space_from_batch(out): 69 | out = unpack_one(x, ps, '* c t') 70 | out = rearrange(out, 'b h w c t -> b c t h w') 71 | return out 72 | 73 | return x, split_space_from_batch 74 | 75 | # small helper modules 76 | 77 | def interpolate_1d(x, length, mode = 'bilinear'): 78 | x = rearrange(x, 'b c t -> b c t 1') 79 | x = F.interpolate(x, (length, 1), mode = mode) 80 | return rearrange(x, 'b c t 1 -> b c t') 81 | 82 | class Downsample(Module): 83 | def __init__( 84 | self, 85 | dim, 86 | dim_out = None, 87 | skip_downsample = False, 88 | has_temporal_layers = False 89 | ): 90 | super().__init__() 91 | dim_out = default(dim_out, dim) 92 | 93 | self.skip_downsample = skip_downsample 94 | 95 | self.conv2d = nn.Conv2d(dim, dim_out, 3, padding = 1) 96 | 97 | self.has_temporal_layers = has_temporal_layers 98 | 99 | if has_temporal_layers: 100 | self.conv1d = nn.Conv1d(dim_out, dim_out, 3, padding = 1) 101 | 102 | nn.init.dirac_(self.conv1d.weight) 103 | nn.init.zeros_(self.conv1d.bias) 104 | 105 | self.register_buffer('filter', torch.Tensor([1., 2., 1.])) 106 | 107 | def forward(self, x): 108 | batch = x.shape[0] 109 | is_input_video = x.ndim == 5 110 | 111 | assert not (is_input_video and not self.has_temporal_layers) 112 | 113 | if is_input_video: 114 | x = rearrange(x, 'b c t h w -> (b t) c h w') 115 | 116 | x = self.conv2d(x) 117 | 118 | if is_input_video: 119 | x = rearrange(x, '(b t) c h w -> b h w c t', b = batch) 120 | x, ps = pack_one(x, '* c t') 121 | 122 | x = self.conv1d(x) 123 | 124 | x = unpack_one(x, ps, '* c t') 125 | x = rearrange(x, 'b h w c t -> b c t h w') 126 | 127 | # if not downsampling, early return 128 | 129 | if self.skip_downsample: 130 | return x, x[:, 0:0] 131 | 132 | # save before blur to subtract out for high frequency fmap skip connection 133 | 134 | before_blur_input = x 135 | 136 | # blur 2d or 3d, depending 137 | 138 | f = self.filter 139 | N = None 140 | 141 | if is_input_video: 142 | f = f[N, N, :] * f[N, :, N] * f[:, N, N] 143 | filter_fn = filter3d 144 | maxpool_fn = F.max_pool3d 145 | else: 146 | f = f[N, :] * f[:, N] 147 | filter_fn = filter2d 148 | maxpool_fn = F.max_pool2d 149 | 150 | blurred = filter_fn(x, f[N, ...], normalized = True) 151 | 152 | # get high frequency fmap 153 | 154 | high_freq_fmap = before_blur_input - blurred 155 | 156 | # max pool 2d or 3d, depending 157 | 158 | x = maxpool_fn(x, kernel_size = 2) 159 | 160 | return x, high_freq_fmap 161 | 162 | class TemporalBlur(Module): 163 | def __init__(self): 164 | super().__init__() 165 | f = torch.Tensor([1, 2, 1]) 166 | self.register_buffer('f', f) 167 | 168 | def forward(self, x): 169 | f = repeat(self.f, 't -> 1 t h w', h = 3, w = 3) 170 | return filter3d(x, f, normalized = True) 171 | 172 | class TemporalUpsample(Module): 173 | def __init__( 174 | self, 175 | dim, 176 | dim_out = None 177 | ): 178 | super().__init__() 179 | self.blur = TemporalBlur() 180 | 181 | def forward(self, x): 182 | assert x.ndim == 5 183 | time = x.shape[2] 184 | 185 | x = rearrange(x, 'b c t h w -> b h w c t') 186 | x, ps = pack_one(x, '* c t') 187 | 188 | x = interpolate_1d(x, time * 2, mode = 'bilinear') 189 | 190 | x = unpack_one(x, ps, '* c t') 191 | x = rearrange(x, 'b h w c t -> b c t h w') 192 | x = self.blur(x) 193 | return x 194 | 195 | class 
PixelShuffleTemporalUpsample(Module): 196 | def __init__(self, dim, dim_out = None): 197 | super().__init__() 198 | dim_out = default(dim_out, dim) 199 | 200 | conv = nn.Conv3d(dim, dim_out * 2, 1) 201 | 202 | self.net = nn.Sequential( 203 | conv, 204 | nn.SiLU(), 205 | Rearrange('b (c p) t h w -> b c (t p) h w', p = 2) 206 | ) 207 | 208 | self.init_conv_(conv) 209 | 210 | def init_conv_(self, conv): 211 | o, i, t, h, w = conv.weight.shape 212 | conv_weight = torch.empty(o // 2, i, t, h, w) 213 | nn.init.kaiming_uniform_(conv_weight) 214 | conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...') 215 | 216 | conv.weight.data.copy_(conv_weight) 217 | nn.init.zeros_(conv.bias.data) 218 | 219 | def forward(self, x): 220 | return self.net(x) 221 | 222 | # norm 223 | 224 | class RMSNorm(Module): 225 | def __init__(self, dim): 226 | super().__init__() 227 | self.scale = dim ** 0.5 228 | self.gamma = nn.Parameter(torch.ones(dim)) 229 | 230 | def forward(self, x): 231 | spatial_dims = ((1,) * (x.ndim - 2)) 232 | gamma = self.gamma.reshape(-1, *spatial_dims) 233 | 234 | return F.normalize(x, dim = 1) * gamma * self.scale 235 | 236 | # building block modules 237 | 238 | class Block(Module): 239 | @beartype 240 | def __init__( 241 | self, 242 | dim, 243 | dim_out, 244 | num_conv_kernels = 0, 245 | conv_type: Literal['1d', '2d'] = '2d', 246 | ): 247 | super().__init__() 248 | 249 | adaptive_conv_klass = AdaptiveConv2DMod if conv_type == '2d' else AdaptiveConv1DMod 250 | 251 | self.proj = adaptive_conv_klass(dim, dim_out, kernel = 3, num_conv_kernels = num_conv_kernels) 252 | self.norm = RMSNorm(dim_out) 253 | self.act = nn.SiLU() 254 | 255 | def forward( 256 | self, 257 | x, 258 | conv_mods_iter: Iterable | None = None 259 | ): 260 | conv_mods_iter = default(conv_mods_iter, null_iterator()) 261 | 262 | x = self.proj( 263 | x, 264 | mod = next(conv_mods_iter), 265 | kernel_mod = next(conv_mods_iter) 266 | ) 267 | 268 | x = self.norm(x) 269 | x = self.act(x) 270 | return x 271 | 272 | class ResnetBlock(Module): 273 | @beartype 274 | def __init__( 275 | self, 276 | dim, 277 | dim_out, 278 | *, 279 | num_conv_kernels = 0, 280 | conv_type: Literal['1d', '2d'] = '2d', 281 | style_dims: List[int] = [] 282 | ): 283 | super().__init__() 284 | 285 | mod_dims = [ 286 | dim, 287 | num_conv_kernels, 288 | dim_out, 289 | num_conv_kernels 290 | ] 291 | 292 | style_dims.extend(mod_dims) 293 | 294 | self.num_mods = len(mod_dims) 295 | 296 | self.block1 = Block(dim, dim_out, num_conv_kernels = num_conv_kernels, conv_type = conv_type) 297 | self.block2 = Block(dim_out, dim_out, num_conv_kernels = num_conv_kernels, conv_type = conv_type) 298 | 299 | conv_klass = nn.Conv2d if conv_type == '2d' else nn.Conv1d 300 | self.res_conv = conv_klass(dim, dim_out, 1) if dim != dim_out else nn.Identity() 301 | 302 | def forward( 303 | self, 304 | x, 305 | conv_mods_iter: Iterable | None = None 306 | ): 307 | h = self.block1(x, conv_mods_iter = conv_mods_iter) 308 | h = self.block2(h, conv_mods_iter = conv_mods_iter) 309 | 310 | return h + self.res_conv(x) 311 | 312 | class LinearAttention(Module): 313 | def __init__( 314 | self, 315 | dim, 316 | heads = 4, 317 | dim_head = 32 318 | ): 319 | super().__init__() 320 | self.scale = dim_head ** -0.5 321 | self.heads = heads 322 | hidden_dim = dim_head * heads 323 | 324 | self.norm = RMSNorm(dim) 325 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 326 | 327 | self.to_out = nn.Sequential( 328 | nn.Conv2d(hidden_dim, dim, 1), 329 | RMSNorm(dim) 330 | ) 331 | 332 | def 
forward(self, x): 333 | b, c, h, w = x.shape 334 | 335 | x = self.norm(x) 336 | 337 | qkv = self.to_qkv(x).chunk(3, dim = 1) 338 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) 339 | 340 | q = q.softmax(dim = -2) 341 | k = k.softmax(dim = -1) 342 | 343 | q = q * self.scale 344 | 345 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 346 | 347 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 348 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) 349 | return self.to_out(out) 350 | 351 | class Attention(Module): 352 | def __init__( 353 | self, 354 | dim, 355 | heads = 4, 356 | dim_head = 32, 357 | flash = False 358 | ): 359 | super().__init__() 360 | self.heads = heads 361 | hidden_dim = dim_head * heads 362 | 363 | self.norm = RMSNorm(dim) 364 | self.attend = Attend(flash = flash) 365 | 366 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 367 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 368 | 369 | def forward(self, x): 370 | b, c, h, w = x.shape 371 | 372 | x = self.norm(x) 373 | 374 | qkv = self.to_qkv(x).chunk(3, dim = 1) 375 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv) 376 | 377 | out = self.attend(q, k, v) 378 | 379 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) 380 | return self.to_out(out) 381 | 382 | # feedforward 383 | 384 | def FeedForward(dim, mult = 4): 385 | return nn.Sequential( 386 | RMSNorm(dim), 387 | nn.Conv2d(dim, dim * mult, 1), 388 | nn.GELU(), 389 | nn.Conv2d(dim * mult, dim, 1) 390 | ) 391 | 392 | # transformers 393 | 394 | class Transformer(Module): 395 | def __init__( 396 | self, 397 | dim, 398 | dim_head = 64, 399 | heads = 8, 400 | depth = 1, 401 | flash_attn = True, 402 | ff_mult = 4 403 | ): 404 | super().__init__() 405 | self.layers = ModuleList([]) 406 | 407 | for _ in range(depth): 408 | self.layers.append(ModuleList([ 409 | Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash_attn), 410 | FeedForward(dim = dim, mult = ff_mult) 411 | ])) 412 | 413 | def forward(self, x): 414 | for attn, ff in self.layers: 415 | x = attn(x) + x 416 | x = ff(x) + x 417 | 418 | return x 419 | 420 | class LinearTransformer(Module): 421 | def __init__( 422 | self, 423 | dim, 424 | dim_head = 64, 425 | heads = 8, 426 | depth = 1, 427 | ff_mult = 4 428 | ): 429 | super().__init__() 430 | self.layers = ModuleList([]) 431 | 432 | for _ in range(depth): 433 | self.layers.append(ModuleList([ 434 | LinearAttention(dim = dim, dim_head = dim_head, heads = heads), 435 | FeedForward(dim = dim, mult = ff_mult) 436 | ])) 437 | 438 | def forward(self, x): 439 | for attn, ff in self.layers: 440 | x = attn(x) + x 441 | x = ff(x) + x 442 | 443 | return x 444 | 445 | # model 446 | 447 | class UnetUpsampler(BaseGenerator): 448 | 449 | @beartype 450 | def __init__( 451 | self, 452 | dim, 453 | *, 454 | image_size, 455 | input_image_size, 456 | init_dim = None, 457 | out_dim = None, 458 | text_encoder: TextEncoder | Dict | None = None, 459 | style_network: StyleNetwork | Dict | None = None, 460 | style_network_dim = None, 461 | dim_mults = (1, 2, 4, 8, 16), 462 | channels = 3, 463 | full_attn = (False, False, False, True, True), 464 | cross_attn = (False, False, False, True, True), 465 | flash_attn = True, 466 | self_attn_dim_head = 64, 467 | self_attn_heads = 8, 468 | self_attn_dot_product = True, 469 | self_attn_ff_mult = 4, 470 | attn_depths = (1, 1, 1, 1, 1), 471 | temporal_attn_depths = (1, 1, 1, 1, 
1), 472 | cross_attn_dim_head = 64, 473 | cross_attn_heads = 8, 474 | cross_ff_mult = 4, 475 | has_temporal_layers = False, 476 | mid_attn_depth = 1, 477 | num_conv_kernels = 2, 478 | unconditional = True, 479 | skip_connect_scale = None 480 | ): 481 | super().__init__() 482 | 483 | # able to upsample video 484 | 485 | self.can_upsample_video = has_temporal_layers 486 | 487 | # style network 488 | 489 | if isinstance(text_encoder, dict): 490 | text_encoder = TextEncoder(**text_encoder) 491 | 492 | self.text_encoder = text_encoder 493 | 494 | if isinstance(style_network, dict): 495 | style_network = StyleNetwork(**style_network) 496 | 497 | self.style_network = style_network 498 | 499 | assert exists(style_network) ^ exists(style_network_dim), 'either style_network or style_network_dim must be passed in' 500 | 501 | # validate text conditioning and style network hparams 502 | 503 | self.unconditional = unconditional 504 | assert unconditional ^ exists(text_encoder), 'if unconditional, text encoder should not be given, and vice versa' 505 | assert not (unconditional and exists(style_network) and style_network.dim_text_latent > 0) 506 | assert unconditional or text_encoder.dim == style_network.dim_text_latent, 'the `dim_text_latent` on your StyleNetwork must be equal to the `dim` set for the TextEncoder' 507 | 508 | assert is_power_of_two(image_size) and is_power_of_two(input_image_size), 'both output image size and input image size must be power of 2' 509 | assert input_image_size < image_size, 'input image size must be smaller than the output image size, thus upsampling' 510 | 511 | num_layer_no_downsample = int(log2(image_size) - log2(input_image_size)) 512 | assert num_layer_no_downsample <= len(dim_mults), 'you need more stages in this unet for the level of upsampling' 513 | 514 | self.image_size = image_size 515 | self.input_image_size = input_image_size 516 | 517 | # setup adaptive conv 518 | 519 | style_embed_split_dims = [] 520 | 521 | # determine dimensions 522 | 523 | self.channels = channels 524 | input_channels = channels 525 | 526 | init_dim = default(init_dim, dim) 527 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) 528 | 529 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 530 | 531 | *_, mid_dim = dims 532 | 533 | in_out = list(zip(dims[:-1], dims[1:])) 534 | 535 | block_klass = partial( 536 | ResnetBlock, 537 | num_conv_kernels = num_conv_kernels, 538 | style_dims = style_embed_split_dims 539 | ) 540 | 541 | # attention 542 | 543 | full_attn = cast_tuple(full_attn, length = len(dim_mults)) 544 | assert len(full_attn) == len(dim_mults) 545 | 546 | FullAttention = partial(Transformer, flash_attn = flash_attn) 547 | 548 | cross_attn = cast_tuple(cross_attn, length = len(dim_mults)) 549 | assert unconditional or len(full_attn) == len(dim_mults) 550 | 551 | # skip connection scale 552 | 553 | self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5) 554 | 555 | # layers 556 | 557 | self.downs = ModuleList([]) 558 | self.ups = ModuleList([]) 559 | num_resolutions = len(in_out) 560 | skip_connect_dims = [] 561 | 562 | for ind, ((dim_in, dim_out), layer_full_attn, layer_cross_attn, layer_attn_depth, layer_temporal_attn_depth) in enumerate(zip(in_out, full_attn, cross_attn, attn_depths, temporal_attn_depths)): 563 | 564 | should_not_downsample = ind < num_layer_no_downsample 565 | has_cross_attn = not self.unconditional and layer_cross_attn 566 | 567 | attn_klass = FullAttention if layer_full_attn else LinearTransformer 568 | 569 | 
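# note: two skip connection widths are recorded per downsampling stage, matching the two
# feature maps pushed onto `h` in forward(): the output of the first resnet block (dim_in channels),
# and the pre-downsample feature map concatenated with the high-frequency residual returned by
# Downsample (an extra dim_out channels, or no extra channels when downsampling is skipped)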
skip_connect_dims.append(dim_in) 570 | skip_connect_dims.append(dim_in + (dim_out if not should_not_downsample else 0)) 571 | 572 | temporal_resnet_block = None 573 | temporal_attn = None 574 | 575 | if has_temporal_layers: 576 | temporal_resnet_block = block_klass(dim_in, dim_in, conv_type = '1d') 577 | temporal_attn = FullAttention(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_temporal_attn_depth) 578 | 579 | # all unet downsample stages 580 | 581 | self.downs.append(ModuleList([ 582 | block_klass(dim_in, dim_in), 583 | block_klass(dim_in, dim_in), 584 | CrossAttentionBlock(dim_in, dim_context = text_encoder.dim, dim_head = self_attn_dim_head, heads = self_attn_heads, ff_mult = self_attn_ff_mult) if has_cross_attn else None, 585 | attn_klass(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_attn_depth), 586 | temporal_resnet_block, 587 | temporal_attn, 588 | Downsample(dim_in, dim_out, skip_downsample = should_not_downsample, has_temporal_layers = has_temporal_layers) 589 | ])) 590 | 591 | self.mid_block1 = block_klass(mid_dim, mid_dim) 592 | self.mid_attn = FullAttention(mid_dim, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = mid_attn_depth) 593 | self.mid_block2 = block_klass(mid_dim, mid_dim) 594 | self.mid_to_rgb = nn.Conv2d(mid_dim, channels, 1) 595 | 596 | for ind, ((dim_in, dim_out), layer_cross_attn, layer_full_attn, layer_attn_depth, layer_temporal_attn_depth) in enumerate(zip(reversed(in_out), reversed(full_attn), reversed(cross_attn), reversed(attn_depths), reversed(temporal_attn_depths))): 597 | 598 | attn_klass = FullAttention if layer_full_attn else LinearTransformer 599 | has_cross_attn = not self.unconditional and layer_cross_attn 600 | 601 | temporal_upsample = None 602 | temporal_upsample_rgb = None 603 | temporal_resnet_block = None 604 | temporal_attn = None 605 | 606 | if has_temporal_layers: 607 | temporal_upsample = PixelShuffleTemporalUpsample(dim_in, dim_in) 608 | temporal_upsample_rgb = TemporalUpsample(dim_in, dim_in) 609 | 610 | temporal_resnet_block = block_klass(dim_in, dim_in, conv_type = '1d') 611 | temporal_attn = FullAttention(dim_in, dim_head = self_attn_dim_head, heads = self_attn_heads, depth = layer_temporal_attn_depth) 612 | 613 | self.ups.append(ModuleList([ 614 | PixelShuffleUpsample(dim_out, dim_in), 615 | Upsample(), 616 | temporal_upsample, 617 | temporal_upsample_rgb, 618 | nn.Conv2d(dim_in, channels, 1), 619 | block_klass(dim_in + skip_connect_dims.pop(), dim_in), 620 | block_klass(dim_in + skip_connect_dims.pop(), dim_in), 621 | CrossAttentionBlock(dim_in, dim_context = text_encoder.dim, dim_head = self_attn_dim_head, heads = self_attn_heads, ff_mult = cross_ff_mult) if has_cross_attn else None, 622 | attn_klass(dim_in, dim_head = cross_attn_dim_head, heads = self_attn_heads, depth = layer_attn_depth), 623 | temporal_resnet_block, 624 | temporal_attn 625 | ])) 626 | 627 | self.out_dim = default(out_dim, channels) 628 | 629 | self.final_res_block = block_klass(dim, dim) 630 | 631 | self.final_to_rgb = nn.Conv2d(dim, channels, 1) 632 | 633 | # determine the projection of the style embedding to convolutional modulation weights (+ adaptive kernel selection weights) for all layers 634 | 635 | self.style_to_conv_modulations = nn.Linear(style_network.dim, sum(style_embed_split_dims)) 636 | self.style_embed_split_dims = style_embed_split_dims 637 | 638 | @property 639 | def allowable_rgb_resolutions(self): 640 | input_res_base = int(log2(self.input_image_size)) 641 | 
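# e.g. an input_image_size of 64 upsampled to an image_size of 256 gives bases 6 and 8,
# so the allowable rgb resolutions come out to [64, 128]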
output_res_base = int(log2(self.image_size)) 642 | allowed_rgb_res_base = list(range(input_res_base, output_res_base)) 643 | return [*map(lambda p: 2 ** p, allowed_rgb_res_base)] 644 | 645 | @property 646 | def device(self): 647 | return next(self.parameters()).device 648 | 649 | @property 650 | def total_params(self): 651 | return sum([p.numel() for p in self.parameters()]) 652 | 653 | def resize_to_same_dimensions(self, x, size): 654 | mode = 'trilinear' if x.ndim == 5 else 'bilinear' 655 | return F.interpolate(x, tuple(size), mode = mode) 656 | 657 | def forward( 658 | self, 659 | lowres_image_or_video, 660 | styles = None, 661 | noise = None, 662 | texts: List[str] | None = None, 663 | global_text_tokens = None, 664 | fine_text_tokens = None, 665 | text_mask = None, 666 | return_all_rgbs = False, 667 | replace_rgb_with_input_lowres_image = True # discriminator should also receive the low resolution image the upsampler sees 668 | ): 669 | x = lowres_image_or_video 670 | shape = x.shape 671 | batch_size = shape[0] 672 | 673 | assert shape[-2:] == ((self.input_image_size,) * 2) 674 | 675 | # take care of text encodings 676 | # which requires global text tokens to adaptively select the kernels from the main contribution in the paper 677 | # and fine text tokens to attend to using cross attention 678 | 679 | if not self.unconditional: 680 | if exists(texts): 681 | assert exists(self.text_encoder) 682 | global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(texts) 683 | else: 684 | assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))]) 685 | else: 686 | assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))]) 687 | 688 | # styles 689 | 690 | if not exists(styles): 691 | assert exists(self.style_network) 692 | 693 | noise = default(noise, torch.randn((batch_size, self.style_network.dim), device = self.device)) 694 | styles = self.style_network(noise, global_text_tokens) 695 | 696 | # project styles to conv modulations 697 | 698 | conv_mods = self.style_to_conv_modulations(styles) 699 | conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1) 700 | conv_mods = iter(conv_mods) 701 | 702 | # first detect whether input is image or video and handle accordingly 703 | 704 | input_is_video = lowres_image_or_video.ndim == 5 705 | assert not (not self.can_upsample_video and input_is_video), 'this network cannot upsample video unless you set `has_temporal_layers = True`' 706 | 707 | fold_time_into_batch = identity 708 | split_time_from_batch = identity 709 | 710 | if input_is_video: 711 | fold_time_into_batch = lambda t: rearrange(t, 'b c t h w -> (b t) c h w') 712 | split_time_from_batch = lambda t: rearrange(t, '(b t) c h w -> b c t h w', b = batch_size) 713 | 714 | x = fold_time_into_batch(x) 715 | 716 | # set lowres_images for final rgb output 717 | 718 | lowres_images = x 719 | 720 | # initial conv 721 | 722 | x = self.init_conv(x) 723 | 724 | h = [] 725 | 726 | # downsample stages 727 | 728 | for ( 729 | block1, 730 | block2, 731 | cross_attn, 732 | attn, 733 | temporal_block, 734 | temporal_attn, 735 | downsample, 736 | ) in self.downs: 737 | 738 | x = block1(x, conv_mods_iter = conv_mods) 739 | h.append(x) 740 | 741 | x = block2(x, conv_mods_iter = conv_mods) 742 | 743 | x = attn(x) 744 | 745 | if exists(cross_attn): 746 | x = cross_attn(x, context = fine_text_tokens, mask = text_mask) 747 | 748 | if input_is_video: 749 | x = split_time_from_batch(x) 750 | x, split_space_back = fold_space_into_batch(x) 751 | 752 | x = 
temporal_block(x, conv_mods_iter = conv_mods) 753 | 754 | x = rearrange(x, 'b c t -> b c t 1') 755 | x = temporal_attn(x) 756 | x = rearrange(x, 'b c t 1 -> b c t') 757 | 758 | x = split_space_back(x) 759 | x = fold_time_into_batch(x) 760 | 761 | elif self.can_upsample_video: 762 | conv_mods = islice(conv_mods, temporal_block.num_mods, None) 763 | 764 | skip_connect = x 765 | 766 | # downsample with hf shuttle 767 | 768 | x = split_time_from_batch(x) 769 | 770 | x, hf_fmap = downsample(x) 771 | 772 | x = fold_time_into_batch(x) 773 | hf_fmap = fold_time_into_batch(hf_fmap) 774 | 775 | # add high freq fmap to skip connection as proposed in videogigagan 776 | 777 | skip_connect = torch.cat((skip_connect, hf_fmap), dim = 1) 778 | 779 | h.append(skip_connect) 780 | 781 | x = self.mid_block1(x, conv_mods_iter = conv_mods) 782 | x = self.mid_attn(x) 783 | x = self.mid_block2(x, conv_mods_iter = conv_mods) 784 | 785 | # rgbs 786 | 787 | rgbs = [] 788 | 789 | init_rgb_shape = list(x.shape) 790 | init_rgb_shape[1] = self.channels 791 | 792 | rgb = self.mid_to_rgb(x) 793 | rgbs.append(rgb) 794 | 795 | # upsample stages 796 | 797 | for ( 798 | upsample, 799 | upsample_rgb, 800 | temporal_upsample, 801 | temporal_upsample_rgb, 802 | to_rgb, 803 | block1, 804 | block2, 805 | cross_attn, 806 | attn, 807 | temporal_block, 808 | temporal_attn, 809 | ) in self.ups: 810 | 811 | x = upsample(x) 812 | rgb = upsample_rgb(rgb) 813 | 814 | if input_is_video: 815 | x = split_time_from_batch(x) 816 | rgb = split_time_from_batch(rgb) 817 | 818 | x = temporal_upsample(x) 819 | rgb = temporal_upsample_rgb(rgb) 820 | 821 | x = fold_time_into_batch(x) 822 | rgb = fold_time_into_batch(rgb) 823 | 824 | res1 = h.pop() * self.skip_connect_scale 825 | res2 = h.pop() * self.skip_connect_scale 826 | 827 | # handle skip connections not being the same shape 828 | 829 | if x.shape[0] != res1.shape[0] or x.shape[2:] != res1.shape[2:]: 830 | x = split_time_from_batch(x) 831 | res1 = split_time_from_batch(res1) 832 | res2 = split_time_from_batch(res2) 833 | 834 | res1 = self.resize_to_same_dimensions(res1, x.shape[2:]) 835 | res2 = self.resize_to_same_dimensions(res2, x.shape[2:]) 836 | 837 | x = fold_time_into_batch(x) 838 | res1 = fold_time_into_batch(res1) 839 | res2 = fold_time_into_batch(res2) 840 | 841 | # concat skip connections 842 | 843 | x = torch.cat((x, res1), dim = 1) 844 | x = block1(x, conv_mods_iter = conv_mods) 845 | 846 | x = torch.cat((x, res2), dim = 1) 847 | x = block2(x, conv_mods_iter = conv_mods) 848 | 849 | if exists(cross_attn): 850 | x = cross_attn(x, context = fine_text_tokens, mask = text_mask) 851 | 852 | x = attn(x) 853 | 854 | if input_is_video: 855 | x = split_time_from_batch(x) 856 | x, split_space_back = fold_space_into_batch(x) 857 | 858 | x = temporal_block(x, conv_mods_iter = conv_mods) 859 | 860 | x = rearrange(x, 'b c t -> b c t 1') 861 | x = temporal_attn(x) 862 | x = rearrange(x, 'b c t 1 -> b c t') 863 | 864 | x = split_space_back(x) 865 | x = fold_time_into_batch(x) 866 | 867 | elif self.can_upsample_video: 868 | conv_mods = islice(conv_mods, temporal_block.num_mods, None) 869 | 870 | rgb = rgb + to_rgb(x) 871 | rgbs.append(rgb) 872 | 873 | x = self.final_res_block(x, conv_mods_iter = conv_mods) 874 | 875 | assert len([*conv_mods]) == 0 876 | 877 | rgb = rgb + self.final_to_rgb(x) 878 | 879 | # handle video input 880 | 881 | if input_is_video: 882 | rgb = split_time_from_batch(rgb) 883 | 884 | if not return_all_rgbs: 885 | return rgb 886 | 887 | # only keep those rgbs whose feature map is 
greater than the input image to be upsampled 888 | 889 | rgbs = list(filter(lambda t: t.shape[-1] > shape[-1], rgbs)) 890 | 891 | # and return the original input image as the smallest rgb 892 | 893 | rgbs = [lowres_images, *rgbs] 894 | 895 | if input_is_video: 896 | rgbs = [*map(split_time_from_batch, rgbs)] 897 | 898 | return rgb, rgbs 899 | -------------------------------------------------------------------------------- /gigagan_pytorch/gigagan_pytorch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import namedtuple 4 | from pathlib import Path 5 | from math import log2, sqrt 6 | from random import random 7 | from functools import partial 8 | 9 | from torchvision import utils 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, einsum, Tensor 14 | from torch.autograd import grad as torch_grad 15 | from torch.utils.data import DataLoader 16 | from torch.cuda.amp import GradScaler 17 | 18 | from beartype import beartype 19 | from beartype.typing import List, Tuple, Dict, Iterable 20 | 21 | from einops import rearrange, pack, unpack, repeat, reduce 22 | from einops.layers.torch import Rearrange, Reduce 23 | 24 | from kornia.filters import filter2d 25 | 26 | from ema_pytorch import EMA 27 | 28 | from gigagan_pytorch.version import __version__ 29 | from gigagan_pytorch.open_clip import OpenClipAdapter 30 | from gigagan_pytorch.optimizer import get_optimizer 31 | from gigagan_pytorch.distributed import all_gather 32 | 33 | from tqdm import tqdm 34 | 35 | from numerize import numerize 36 | 37 | from accelerate import Accelerator, DistributedType 38 | from accelerate.utils import DistributedDataParallelKwargs 39 | 40 | # helpers 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | @beartype 46 | def is_empty(arr: Iterable): 47 | return len(arr) == 0 48 | 49 | def default(*vals): 50 | for val in vals: 51 | if exists(val): 52 | return val 53 | return None 54 | 55 | def cast_tuple(t, length = 1): 56 | return t if isinstance(t, tuple) else ((t,) * length) 57 | 58 | def is_power_of_two(n): 59 | return log2(n).is_integer() 60 | 61 | def safe_unshift(arr): 62 | if len(arr) == 0: 63 | return None 64 | return arr.pop(0) 65 | 66 | def divisible_by(numer, denom): 67 | return (numer % denom) == 0 68 | 69 | def group_by_num_consecutive(arr, num): 70 | out = [] 71 | for ind, el in enumerate(arr): 72 | if ind > 0 and divisible_by(ind, num): 73 | yield out 74 | out = [] 75 | 76 | out.append(el) 77 | 78 | if len(out) > 0: 79 | yield out 80 | 81 | def is_unique(arr): 82 | return len(set(arr)) == len(arr) 83 | 84 | def cycle(dl): 85 | while True: 86 | for data in dl: 87 | yield data 88 | 89 | def num_to_groups(num, divisor): 90 | groups, remainder = divmod(num, divisor) 91 | arr = [divisor] * groups 92 | if remainder > 0: 93 | arr.append(remainder) 94 | return arr 95 | 96 | def mkdir_if_not_exists(path): 97 | path.mkdir(exist_ok = True, parents = True) 98 | 99 | @beartype 100 | def set_requires_grad_( 101 | m: nn.Module, 102 | requires_grad: bool 103 | ): 104 | for p in m.parameters(): 105 | p.requires_grad = requires_grad 106 | 107 | # activation functions 108 | 109 | def leaky_relu(neg_slope = 0.2): 110 | return nn.LeakyReLU(neg_slope) 111 | 112 | def conv2d_3x3(dim_in, dim_out): 113 | return nn.Conv2d(dim_in, dim_out, 3, padding = 1) 114 | 115 | # tensor helpers 116 | 117 | def log(t, eps = 1e-20): 118 | return t.clamp(min = eps).log() 119 | 120 | def 
gradient_penalty( 121 | images, 122 | outputs, 123 | grad_output_weights = None, 124 | weight = 10, 125 | scaler: GradScaler | None = None, 126 | eps = 1e-4 127 | ): 128 | if not isinstance(outputs, (list, tuple)): 129 | outputs = [outputs] 130 | 131 | if exists(scaler): 132 | outputs = [*map(scaler.scale, outputs)] 133 | 134 | if not exists(grad_output_weights): 135 | grad_output_weights = (1,) * len(outputs) 136 | 137 | maybe_scaled_gradients, *_ = torch_grad( 138 | outputs = outputs, 139 | inputs = images, 140 | grad_outputs = [(torch.ones_like(output) * weight) for output, weight in zip(outputs, grad_output_weights)], 141 | create_graph = True, 142 | retain_graph = True, 143 | only_inputs = True 144 | ) 145 | 146 | gradients = maybe_scaled_gradients 147 | 148 | if exists(scaler): 149 | scale = scaler.get_scale() 150 | inv_scale = 1. / max(scale, eps) 151 | gradients = maybe_scaled_gradients * inv_scale 152 | 153 | gradients = rearrange(gradients, 'b ... -> b (...)') 154 | return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean() 155 | 156 | # hinge gan losses 157 | 158 | def generator_hinge_loss(fake): 159 | return fake.mean() 160 | 161 | def discriminator_hinge_loss(real, fake): 162 | return (F.relu(1 + real) + F.relu(1 - fake)).mean() 163 | 164 | # auxiliary losses 165 | 166 | def aux_matching_loss(real, fake): 167 | """ 168 | making logits negative, as in this framework, discriminator is 0 for real, high value for fake. GANs can have this arbitrarily swapped, as it only matters if the generator and discriminator are opposites 169 | """ 170 | return (log(1 + (-real).exp()) + log(1 + (-fake).exp())).mean() 171 | 172 | @beartype 173 | def aux_clip_loss( 174 | clip: OpenClipAdapter, 175 | images: Tensor, 176 | texts: List[str] | None = None, 177 | text_embeds: Tensor | None = None 178 | ): 179 | assert exists(texts) ^ exists(text_embeds) 180 | 181 | images, batch_sizes = all_gather(images, 0, None) 182 | 183 | if exists(texts): 184 | text_embeds, _ = clip.embed_texts(texts) 185 | text_embeds, _ = all_gather(text_embeds, 0, batch_sizes) 186 | 187 | return clip.contrastive_loss(images = images, text_embeds = text_embeds) 188 | 189 | # differentiable augmentation - Karras et al. stylegan-ada 190 | # start with horizontal flip 191 | 192 | class DiffAugment(nn.Module): 193 | def __init__( 194 | self, 195 | *, 196 | prob, 197 | horizontal_flip, 198 | horizontal_flip_prob = 0.5 199 | ): 200 | super().__init__() 201 | self.prob = prob 202 | assert 0 <= prob <= 1. 
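        # illustrative note (example values below are assumptions, not repo defaults): the forward
        # pass augments a batch with probability `prob`, then applies a horizontal flip with
        # probability `horizontal_flip_prob`, so e.g. DiffAugment(prob = 0.5, horizontal_flip = True)
        # would flip roughly 0.5 * 0.5 = 25% of batches, applying the identical flip to the images
        # and to every multiscale rgb so discriminator inputs stay consistent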
203 | 204 | self.horizontal_flip = horizontal_flip 205 | self.horizontal_flip_prob = horizontal_flip_prob 206 | 207 | def forward( 208 | self, 209 | images, 210 | rgbs: List[Tensor] 211 | ): 212 | if random() >= self.prob: 213 | return images, rgbs 214 | 215 | if random() < self.horizontal_flip_prob: 216 | images = torch.flip(images, (-1,)) 217 | rgbs = [torch.flip(rgb, (-1,)) for rgb in rgbs] 218 | 219 | return images, rgbs 220 | 221 | # rmsnorm (newer papers show mean-centering in layernorm not necessary) 222 | 223 | class ChannelRMSNorm(nn.Module): 224 | def __init__(self, dim): 225 | super().__init__() 226 | self.scale = dim ** 0.5 227 | self.gamma = nn.Parameter(torch.ones(dim, 1, 1)) 228 | 229 | def forward(self, x): 230 | normed = F.normalize(x, dim = 1) 231 | return normed * self.scale * self.gamma 232 | 233 | class RMSNorm(nn.Module): 234 | def __init__(self, dim): 235 | super().__init__() 236 | self.scale = dim ** 0.5 237 | self.gamma = nn.Parameter(torch.ones(dim)) 238 | 239 | def forward(self, x): 240 | normed = F.normalize(x, dim = -1) 241 | return normed * self.scale * self.gamma 242 | 243 | # down and upsample 244 | 245 | class Blur(nn.Module): 246 | def __init__(self): 247 | super().__init__() 248 | f = torch.Tensor([1, 2, 1]) 249 | self.register_buffer('f', f) 250 | 251 | def forward(self, x): 252 | f = self.f 253 | f = f[None, None, :] * f[None, :, None] 254 | return filter2d(x, f, normalized = True) 255 | 256 | def Upsample(*args): 257 | return nn.Sequential( 258 | nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), 259 | Blur() 260 | ) 261 | 262 | class PixelShuffleUpsample(nn.Module): 263 | def __init__(self, dim, dim_out = None): 264 | super().__init__() 265 | dim_out = default(dim_out, dim) 266 | conv = nn.Conv2d(dim, dim_out * 4, 1) 267 | 268 | self.net = nn.Sequential( 269 | conv, 270 | nn.SiLU(), 271 | nn.PixelShuffle(2) 272 | ) 273 | 274 | self.init_conv_(conv) 275 | 276 | def init_conv_(self, conv): 277 | o, i, h, w = conv.weight.shape 278 | conv_weight = torch.empty(o // 4, i, h, w) 279 | nn.init.kaiming_uniform_(conv_weight) 280 | conv_weight = repeat(conv_weight, 'o ... 
-> (o 4) ...') 281 | 282 | conv.weight.data.copy_(conv_weight) 283 | nn.init.zeros_(conv.bias.data) 284 | 285 | def forward(self, x): 286 | return self.net(x) 287 | 288 | def Downsample(dim): 289 | return nn.Sequential( 290 | Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2), 291 | nn.Conv2d(dim * 4, dim, 1) 292 | ) 293 | 294 | # skip layer excitation 295 | 296 | def SqueezeExcite(dim, dim_out, reduction = 4, dim_min = 32): 297 | dim_hidden = max(dim_out // reduction, dim_min) 298 | 299 | return nn.Sequential( 300 | Reduce('b c h w -> b c', 'mean'), 301 | nn.Linear(dim, dim_hidden), 302 | nn.SiLU(), 303 | nn.Linear(dim_hidden, dim_out), 304 | nn.Sigmoid(), 305 | Rearrange('b c -> b c 1 1') 306 | ) 307 | 308 | # adaptive conv 309 | # the main novelty of the paper - they propose to learn a softmax weighted sum of N convolutional kernels, depending on the text embedding 310 | 311 | def get_same_padding(size, kernel, dilation, stride): 312 | return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 313 | 314 | class AdaptiveConv2DMod(nn.Module): 315 | def __init__( 316 | self, 317 | dim, 318 | dim_out, 319 | kernel, 320 | *, 321 | demod = True, 322 | stride = 1, 323 | dilation = 1, 324 | eps = 1e-8, 325 | num_conv_kernels = 1 # set this to be greater than 1 for adaptive 326 | ): 327 | super().__init__() 328 | self.eps = eps 329 | 330 | self.dim_out = dim_out 331 | 332 | self.kernel = kernel 333 | self.stride = stride 334 | self.dilation = dilation 335 | self.adaptive = num_conv_kernels > 1 336 | 337 | self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))) 338 | 339 | self.demod = demod 340 | 341 | nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu') 342 | 343 | def forward( 344 | self, 345 | fmap, 346 | mod: Tensor, 347 | kernel_mod: Tensor | None = None 348 | ): 349 | """ 350 | notation 351 | 352 | b - batch 353 | n - convs 354 | o - output 355 | i - input 356 | k - kernel 357 | """ 358 | 359 | b, h = fmap.shape[0], fmap.shape[-2] 360 | 361 | # account for feature map that has been expanded by the scale in the first dimension 362 | # due to multiscale inputs and outputs 363 | 364 | if mod.shape[0] != b: 365 | mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0]) 366 | 367 | if exists(kernel_mod): 368 | kernel_mod_has_el = kernel_mod.numel() > 0 369 | 370 | assert self.adaptive or not kernel_mod_has_el 371 | 372 | if kernel_mod_has_el and kernel_mod.shape[0] != b: 373 | kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0]) 374 | 375 | # prepare weights for modulation 376 | 377 | weights = self.weights 378 | 379 | if self.adaptive: 380 | weights = repeat(weights, '... -> b ...', b = b) 381 | 382 | # determine an adaptive weight and 'select' the kernel to use with softmax 383 | 384 | assert exists(kernel_mod) and kernel_mod.numel() > 0 385 | 386 | kernel_attn = kernel_mod.softmax(dim = -1) 387 | kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1') 388 | 389 | weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum') 390 | 391 | # do the modulation, demodulation, as done in stylegan2 392 | 393 | mod = rearrange(mod, 'b i -> b 1 i 1 1') 394 | 395 | weights = weights * (mod + 1) 396 | 397 | if self.demod: 398 | inv_norm = reduce(weights ** 2, 'b o i k1 k2 -> b o 1 1 1', 'sum').clamp(min = self.eps).rsqrt() 399 | weights = weights * inv_norm 400 | 401 | fmap = rearrange(fmap, 'b c h w -> 1 (b c) h w') 402 | 403 | weights = rearrange(weights, 'b o ... 
-> (b o) ...') 404 | 405 | padding = get_same_padding(h, self.kernel, self.dilation, self.stride) 406 | fmap = F.conv2d(fmap, weights, padding = padding, groups = b) 407 | 408 | return rearrange(fmap, '1 (b o) ... -> b o ...', b = b) 409 | 410 | class AdaptiveConv1DMod(nn.Module): 411 | """ 1d version of adaptive conv, for time dimension in videogigagan """ 412 | 413 | def __init__( 414 | self, 415 | dim, 416 | dim_out, 417 | kernel, 418 | *, 419 | demod = True, 420 | stride = 1, 421 | dilation = 1, 422 | eps = 1e-8, 423 | num_conv_kernels = 1 # set this to be greater than 1 for adaptive 424 | ): 425 | super().__init__() 426 | self.eps = eps 427 | 428 | self.dim_out = dim_out 429 | 430 | self.kernel = kernel 431 | self.stride = stride 432 | self.dilation = dilation 433 | self.adaptive = num_conv_kernels > 1 434 | 435 | self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel))) 436 | 437 | self.demod = demod 438 | 439 | nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu') 440 | 441 | def forward( 442 | self, 443 | fmap, 444 | mod: Tensor, 445 | kernel_mod: Tensor | None = None 446 | ): 447 | """ 448 | notation 449 | 450 | b - batch 451 | n - convs 452 | o - output 453 | i - input 454 | k - kernel 455 | """ 456 | 457 | b, t = fmap.shape[0], fmap.shape[-1] 458 | 459 | # account for feature map that has been expanded by the scale in the first dimension 460 | # due to multiscale inputs and outputs 461 | 462 | if mod.shape[0] != b: 463 | mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0]) 464 | 465 | if exists(kernel_mod): 466 | kernel_mod_has_el = kernel_mod.numel() > 0 467 | 468 | assert self.adaptive or not kernel_mod_has_el 469 | 470 | if kernel_mod_has_el and kernel_mod.shape[0] != b: 471 | kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0]) 472 | 473 | # prepare weights for modulation 474 | 475 | weights = self.weights 476 | 477 | if self.adaptive: 478 | weights = repeat(weights, '... -> b ...', b = b) 479 | 480 | # determine an adaptive weight and 'select' the kernel to use with softmax 481 | 482 | assert exists(kernel_mod) and kernel_mod.numel() > 0 483 | 484 | kernel_attn = kernel_mod.softmax(dim = -1) 485 | kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1') 486 | 487 | weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum') 488 | 489 | # do the modulation, demodulation, as done in stylegan2 490 | 491 | mod = rearrange(mod, 'b i -> b 1 i 1') 492 | 493 | weights = weights * (mod + 1) 494 | 495 | if self.demod: 496 | inv_norm = reduce(weights ** 2, 'b o i k -> b o 1 1', 'sum').clamp(min = self.eps).rsqrt() 497 | weights = weights * inv_norm 498 | 499 | fmap = rearrange(fmap, 'b c t -> 1 (b c) t') 500 | 501 | weights = rearrange(weights, 'b o ... -> (b o) ...') 502 | 503 | padding = get_same_padding(t, self.kernel, self.dilation, self.stride) 504 | fmap = F.conv1d(fmap, weights, padding = padding, groups = b) 505 | 506 | return rearrange(fmap, '1 (b o) ... 
-> b o ...', b = b) 507 | 508 | # attention 509 | # they use an attention with a better Lipchitz constant - l2 distance similarity instead of dot product - also shared query / key space - shown in vitgan to be more stable 510 | # not sure what they did about token attention to self, so masking out, as done in some other papers using shared query / key space 511 | 512 | class SelfAttention(nn.Module): 513 | def __init__( 514 | self, 515 | dim, 516 | dim_head = 64, 517 | heads = 8, 518 | dot_product = False 519 | ): 520 | super().__init__() 521 | self.heads = heads 522 | self.scale = dim_head ** -0.5 523 | dim_inner = dim_head * heads 524 | 525 | self.dot_product = dot_product 526 | 527 | self.norm = ChannelRMSNorm(dim) 528 | 529 | self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False) 530 | self.to_k = nn.Conv2d(dim, dim_inner, 1, bias = False) if dot_product else None 531 | self.to_v = nn.Conv2d(dim, dim_inner, 1, bias = False) 532 | 533 | self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head)) 534 | 535 | self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False) 536 | 537 | def forward(self, fmap): 538 | """ 539 | einstein notation 540 | 541 | b - batch 542 | h - heads 543 | x - height 544 | y - width 545 | d - dimension 546 | i - source seq (attend from) 547 | j - target seq (attend to) 548 | """ 549 | batch = fmap.shape[0] 550 | 551 | fmap = self.norm(fmap) 552 | 553 | x, y = fmap.shape[-2:] 554 | 555 | h = self.heads 556 | 557 | q, v = self.to_q(fmap), self.to_v(fmap) 558 | 559 | k = self.to_k(fmap) if exists(self.to_k) else q 560 | 561 | q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = self.heads), (q, k, v)) 562 | 563 | # add a null key / value, so network can choose to pay attention to nothing 564 | 565 | nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv) 566 | 567 | k = torch.cat((nk, k), dim = -2) 568 | v = torch.cat((nv, v), dim = -2) 569 | 570 | # l2 distance or dot product 571 | 572 | if self.dot_product: 573 | sim = einsum('b i d, b j d -> b i j', q, k) 574 | else: 575 | # using pytorch cdist leads to nans in lightweight gan training framework, at least 576 | q_squared = (q * q).sum(dim = -1) 577 | k_squared = (k * k).sum(dim = -1) 578 | l2dist_squared = rearrange(q_squared, 'b i -> b i 1') + rearrange(k_squared, 'b j -> b 1 j') - 2 * einsum('b i d, b j d -> b i j', q, k) # hope i'm mathing right 579 | sim = -l2dist_squared 580 | 581 | # scale 582 | 583 | sim = sim * self.scale 584 | 585 | # attention 586 | 587 | attn = sim.softmax(dim = -1) 588 | 589 | out = einsum('b i j, b j d -> b i d', attn, v) 590 | 591 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h) 592 | 593 | return self.to_out(out) 594 | 595 | class CrossAttention(nn.Module): 596 | def __init__( 597 | self, 598 | dim, 599 | dim_context, 600 | dim_head = 64, 601 | heads = 8 602 | ): 603 | super().__init__() 604 | self.heads = heads 605 | self.scale = dim_head ** -0.5 606 | dim_inner = dim_head * heads 607 | kv_input_dim = default(dim_context, dim) 608 | 609 | self.norm = ChannelRMSNorm(dim) 610 | self.norm_context = RMSNorm(kv_input_dim) 611 | 612 | self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False) 613 | self.to_kv = nn.Linear(kv_input_dim, dim_inner * 2, bias = False) 614 | self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False) 615 | 616 | def forward(self, fmap, context, mask = None): 617 | """ 618 | einstein notation 619 | 620 | b - batch 621 | h - heads 622 | x - height 623 | y - width 624 | d - dimension 625 | i - source 
seq (attend from) 626 | j - target seq (attend to) 627 | """ 628 | 629 | fmap = self.norm(fmap) 630 | context = self.norm_context(context) 631 | 632 | x, y = fmap.shape[-2:] 633 | 634 | h = self.heads 635 | 636 | q, k, v = (self.to_q(fmap), *self.to_kv(context).chunk(2, dim = -1)) 637 | 638 | k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (k, v)) 639 | 640 | q = rearrange(q, 'b (h d) x y -> (b h) (x y) d', h = self.heads) 641 | 642 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 643 | 644 | if exists(mask): 645 | mask = repeat(mask, 'b j -> (b h) 1 j', h = self.heads) 646 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 647 | 648 | attn = sim.softmax(dim = -1) 649 | 650 | out = einsum('b i j, b j d -> b i d', attn, v) 651 | 652 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h) 653 | 654 | return self.to_out(out) 655 | 656 | # classic transformer attention, stick with l2 distance 657 | 658 | class TextAttention(nn.Module): 659 | def __init__( 660 | self, 661 | dim, 662 | dim_head = 64, 663 | heads = 8 664 | ): 665 | super().__init__() 666 | self.heads = heads 667 | self.scale = dim_head ** -0.5 668 | dim_inner = dim_head * heads 669 | 670 | self.norm = RMSNorm(dim) 671 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False) 672 | 673 | self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head)) 674 | 675 | self.to_out = nn.Linear(dim_inner, dim, bias = False) 676 | 677 | def forward(self, encodings, mask = None): 678 | """ 679 | einstein notation 680 | 681 | b - batch 682 | h - heads 683 | x - height 684 | y - width 685 | d - dimension 686 | i - source seq (attend from) 687 | j - target seq (attend to) 688 | """ 689 | batch = encodings.shape[0] 690 | 691 | encodings = self.norm(encodings) 692 | 693 | h = self.heads 694 | 695 | q, k, v = self.to_qkv(encodings).chunk(3, dim = -1) 696 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v)) 697 | 698 | # add a null key / value, so network can choose to pay attention to nothing 699 | 700 | nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv) 701 | 702 | k = torch.cat((nk, k), dim = -2) 703 | v = torch.cat((nv, v), dim = -2) 704 | 705 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 706 | 707 | # key padding mask 708 | 709 | if exists(mask): 710 | mask = F.pad(mask, (1, 0), value = True) 711 | mask = repeat(mask, 'b n -> (b h) 1 n', h = h) 712 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 713 | 714 | # attention 715 | 716 | attn = sim.softmax(dim = -1) 717 | out = einsum('b i j, b j d -> b i d', attn, v) 718 | 719 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 720 | 721 | return self.to_out(out) 722 | 723 | # feedforward 724 | 725 | def FeedForward( 726 | dim, 727 | mult = 4, 728 | channel_first = False 729 | ): 730 | dim_hidden = int(dim * mult) 731 | norm_klass = ChannelRMSNorm if channel_first else RMSNorm 732 | proj = partial(nn.Conv2d, kernel_size = 1) if channel_first else nn.Linear 733 | 734 | return nn.Sequential( 735 | norm_klass(dim), 736 | proj(dim, dim_hidden), 737 | nn.GELU(), 738 | proj(dim_hidden, dim) 739 | ) 740 | 741 | # different types of transformer blocks or transformers (multiple blocks) 742 | 743 | class SelfAttentionBlock(nn.Module): 744 | def __init__( 745 | self, 746 | dim, 747 | dim_head = 64, 748 | heads = 8, 749 | ff_mult = 4, 750 | dot_product = False 751 | ): 752 | super().__init__() 753 | self.attn = SelfAttention(dim = dim, dim_head = dim_head, 
heads = heads, dot_product = dot_product) 754 | self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True) 755 | 756 | def forward(self, x): 757 | x = self.attn(x) + x 758 | x = self.ff(x) + x 759 | return x 760 | 761 | class CrossAttentionBlock(nn.Module): 762 | def __init__( 763 | self, 764 | dim, 765 | dim_context, 766 | dim_head = 64, 767 | heads = 8, 768 | ff_mult = 4 769 | ): 770 | super().__init__() 771 | self.attn = CrossAttention(dim = dim, dim_context = dim_context, dim_head = dim_head, heads = heads) 772 | self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True) 773 | 774 | def forward(self, x, context, mask = None): 775 | x = self.attn(x, context = context, mask = mask) + x 776 | x = self.ff(x) + x 777 | return x 778 | 779 | class Transformer(nn.Module): 780 | def __init__( 781 | self, 782 | dim, 783 | depth, 784 | dim_head = 64, 785 | heads = 8, 786 | ff_mult = 4 787 | ): 788 | super().__init__() 789 | self.layers = nn.ModuleList([]) 790 | for _ in range(depth): 791 | self.layers.append(nn.ModuleList([ 792 | TextAttention(dim = dim, dim_head = dim_head, heads = heads), 793 | FeedForward(dim = dim, mult = ff_mult) 794 | ])) 795 | 796 | self.norm = RMSNorm(dim) 797 | 798 | def forward(self, x, mask = None): 799 | for attn, ff in self.layers: 800 | x = attn(x, mask = mask) + x 801 | x = ff(x) + x 802 | 803 | return self.norm(x) 804 | 805 | # text encoder 806 | 807 | class TextEncoder(nn.Module): 808 | @beartype 809 | def __init__( 810 | self, 811 | *, 812 | dim, 813 | depth, 814 | clip: OpenClipAdapter | None = None, 815 | dim_head = 64, 816 | heads = 8, 817 | ): 818 | super().__init__() 819 | self.dim = dim 820 | 821 | if not exists(clip): 822 | clip = OpenClipAdapter() 823 | 824 | self.clip = clip 825 | set_requires_grad_(clip, False) 826 | 827 | self.learned_global_token = nn.Parameter(torch.randn(dim)) 828 | 829 | self.project_in = nn.Linear(clip.dim_latent, dim) if clip.dim_latent != dim else nn.Identity() 830 | 831 | self.transformer = Transformer( 832 | dim = dim, 833 | depth = depth, 834 | dim_head = dim_head, 835 | heads = heads 836 | ) 837 | 838 | @beartype 839 | def forward( 840 | self, 841 | texts: List[str] | None = None, 842 | text_encodings: Tensor | None = None 843 | ): 844 | assert exists(texts) ^ exists(text_encodings) 845 | 846 | if not exists(text_encodings): 847 | with torch.no_grad(): 848 | self.clip.eval() 849 | _, text_encodings = self.clip.embed_texts(texts) 850 | 851 | mask = (text_encodings != 0.).any(dim = -1) 852 | 853 | text_encodings = self.project_in(text_encodings) 854 | 855 | mask_with_global = F.pad(mask, (1, 0), value = True) 856 | 857 | batch = text_encodings.shape[0] 858 | global_tokens = repeat(self.learned_global_token, 'd -> b d', b = batch) 859 | 860 | text_encodings, ps = pack([global_tokens, text_encodings], 'b * d') 861 | 862 | text_encodings = self.transformer(text_encodings, mask = mask_with_global) 863 | 864 | global_tokens, text_encodings = unpack(text_encodings, ps, 'b * d') 865 | 866 | return global_tokens, text_encodings, mask 867 | 868 | # style mapping network 869 | 870 | class EqualLinear(nn.Module): 871 | def __init__( 872 | self, 873 | dim, 874 | dim_out, 875 | lr_mul = 1, 876 | bias = True 877 | ): 878 | super().__init__() 879 | self.weight = nn.Parameter(torch.randn(dim_out, dim)) 880 | if bias: 881 | self.bias = nn.Parameter(torch.zeros(dim_out)) 882 | 883 | self.lr_mul = lr_mul 884 | 885 | def forward(self, input): 886 | return F.linear(input, self.weight * self.lr_mul, bias=self.bias * 
self.lr_mul) 887 | 888 | class StyleNetwork(nn.Module): 889 | def __init__( 890 | self, 891 | dim, 892 | depth, 893 | lr_mul = 0.1, 894 | dim_text_latent = 0 895 | ): 896 | super().__init__() 897 | self.dim = dim 898 | self.dim_text_latent = dim_text_latent 899 | 900 | layers = [] 901 | for i in range(depth): 902 | is_first = i == 0 903 | dim_in = (dim + dim_text_latent) if is_first else dim 904 | 905 | layers.extend([EqualLinear(dim_in, dim, lr_mul), leaky_relu()]) 906 | 907 | self.net = nn.Sequential(*layers) 908 | 909 | def forward( 910 | self, 911 | x, 912 | text_latent = None 913 | ): 914 | x = F.normalize(x, dim = 1) 915 | 916 | if self.dim_text_latent > 0: 917 | assert exists(text_latent) 918 | x = torch.cat((x, text_latent), dim = -1) 919 | 920 | return self.net(x) 921 | 922 | # noise 923 | 924 | class Noise(nn.Module): 925 | def __init__(self, dim): 926 | super().__init__() 927 | self.weight = nn.Parameter(torch.zeros(dim, 1, 1)) 928 | 929 | def forward( 930 | self, 931 | x, 932 | noise = None 933 | ): 934 | b, _, h, w, device = *x.shape, x.device 935 | 936 | if not exists(noise): 937 | noise = torch.randn(b, 1, h, w, device = device) 938 | 939 | return x + self.weight * noise 940 | 941 | # generator 942 | 943 | class BaseGenerator(nn.Module): 944 | pass 945 | 946 | class Generator(BaseGenerator): 947 | @beartype 948 | def __init__( 949 | self, 950 | *, 951 | image_size, 952 | dim_capacity = 16, 953 | dim_max = 2048, 954 | channels = 3, 955 | style_network: StyleNetwork | Dict | None = None, 956 | style_network_dim = None, 957 | text_encoder: TextEncoder | Dict | None = None, 958 | dim_latent = 512, 959 | self_attn_resolutions: Tuple[int, ...] = (32, 16), 960 | self_attn_dim_head = 64, 961 | self_attn_heads = 8, 962 | self_attn_dot_product = True, 963 | self_attn_ff_mult = 4, 964 | cross_attn_resolutions: Tuple[int, ...] 
= (32, 16), 965 | cross_attn_dim_head = 64, 966 | cross_attn_heads = 8, 967 | cross_attn_ff_mult = 4, 968 | num_conv_kernels = 2, # the number of adaptive conv kernels 969 | num_skip_layers_excite = 0, 970 | unconditional = False, 971 | pixel_shuffle_upsample = False 972 | ): 973 | super().__init__() 974 | self.channels = channels 975 | 976 | if isinstance(style_network, dict): 977 | style_network = StyleNetwork(**style_network) 978 | 979 | self.style_network = style_network 980 | 981 | assert exists(style_network) ^ exists(style_network_dim), 'style_network_dim must be given to the generator if StyleNetwork not passed in as style_network' 982 | 983 | if not exists(style_network_dim): 984 | style_network_dim = style_network.dim 985 | 986 | self.style_network_dim = style_network_dim 987 | 988 | if isinstance(text_encoder, dict): 989 | text_encoder = TextEncoder(**text_encoder) 990 | 991 | self.text_encoder = text_encoder 992 | 993 | self.unconditional = unconditional 994 | 995 | assert not (unconditional and exists(text_encoder)) 996 | assert not (unconditional and exists(style_network) and style_network.dim_text_latent > 0) 997 | assert unconditional or (exists(text_encoder) and text_encoder.dim == style_network.dim_text_latent), 'the `dim_text_latent` on your StyleNetwork must be equal to the `dim` set for the TextEncoder' 998 | 999 | assert is_power_of_two(image_size) 1000 | num_layers = int(log2(image_size) - 1) 1001 | self.num_layers = num_layers 1002 | 1003 | # generator requires convolutions conditioned by the style vector 1004 | # and also has N convolutional kernels adaptively selected (one of the only novelties of the paper) 1005 | 1006 | is_adaptive = num_conv_kernels > 1 1007 | dim_kernel_mod = num_conv_kernels if is_adaptive else 0 1008 | 1009 | style_embed_split_dims = [] 1010 | 1011 | adaptive_conv = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels) 1012 | 1013 | # initial 4x4 block and conv 1014 | 1015 | self.init_block = nn.Parameter(torch.randn(dim_latent, 4, 4)) 1016 | self.init_conv = adaptive_conv(dim_latent, dim_latent) 1017 | 1018 | style_embed_split_dims.extend([ 1019 | dim_latent, 1020 | dim_kernel_mod 1021 | ]) 1022 | 1023 | # main network 1024 | 1025 | num_layers = int(log2(image_size) - 1) 1026 | self.num_layers = num_layers 1027 | 1028 | resolutions = image_size / ((2 ** torch.arange(num_layers).flip(0))) 1029 | resolutions = resolutions.long().tolist() 1030 | 1031 | dim_layers = (2 ** (torch.arange(num_layers) + 1)) * dim_capacity 1032 | dim_layers.clamp_(max = dim_max) 1033 | 1034 | dim_layers = torch.flip(dim_layers, (0,)) 1035 | dim_layers = F.pad(dim_layers, (1, 0), value = dim_latent) 1036 | 1037 | dim_layers = dim_layers.tolist() 1038 | 1039 | dim_pairs = list(zip(dim_layers[:-1], dim_layers[1:])) 1040 | 1041 | self.num_skip_layers_excite = num_skip_layers_excite 1042 | 1043 | self.layers = nn.ModuleList([]) 1044 | 1045 | # go through layers and construct all parameters 1046 | 1047 | for ind, ((dim_in, dim_out), resolution) in enumerate(zip(dim_pairs, resolutions)): 1048 | is_last = (ind + 1) == len(dim_pairs) 1049 | is_first = ind == 0 1050 | 1051 | should_upsample = not is_first 1052 | should_upsample_rgb = not is_last 1053 | should_skip_layer_excite = num_skip_layers_excite > 0 and (ind + num_skip_layers_excite) < len(dim_pairs) 1054 | 1055 | has_self_attn = resolution in self_attn_resolutions 1056 | has_cross_attn = resolution in cross_attn_resolutions and not unconditional 1057 | 1058 | skip_squeeze_excite = None 1059 | if 
should_skip_layer_excite: 1060 | dim_skip_in, _ = dim_pairs[ind + num_skip_layers_excite] 1061 | skip_squeeze_excite = SqueezeExcite(dim_in, dim_skip_in) 1062 | 1063 | resnet_block = nn.ModuleList([ 1064 | adaptive_conv(dim_in, dim_out), 1065 | Noise(dim_out), 1066 | leaky_relu(), 1067 | adaptive_conv(dim_out, dim_out), 1068 | Noise(dim_out), 1069 | leaky_relu() 1070 | ]) 1071 | 1072 | to_rgb = AdaptiveConv2DMod(dim_out, channels, 1, num_conv_kernels = 1, demod = False) 1073 | 1074 | self_attn = cross_attn = rgb_upsample = upsample = None 1075 | 1076 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample 1077 | 1078 | upsample = upsample_klass(dim_in) if should_upsample else None 1079 | rgb_upsample = upsample_klass(channels) if should_upsample_rgb else None 1080 | 1081 | if has_self_attn: 1082 | self_attn = SelfAttentionBlock( 1083 | dim_out, 1084 | dim_head = self_attn_dim_head, 1085 | heads = self_attn_heads, 1086 | ff_mult = self_attn_ff_mult, 1087 | dot_product = self_attn_dot_product 1088 | ) 1089 | 1090 | if has_cross_attn: 1091 | cross_attn = CrossAttentionBlock( 1092 | dim_out, 1093 | dim_context = text_encoder.dim, 1094 | dim_head = cross_attn_dim_head, 1095 | heads = cross_attn_heads, 1096 | ff_mult = cross_attn_ff_mult, 1097 | ) 1098 | 1099 | style_embed_split_dims.extend([ 1100 | dim_in, # for first conv in resnet block 1101 | dim_kernel_mod, # first conv kernel selection 1102 | dim_out, # second conv in resnet block 1103 | dim_kernel_mod, # second conv kernel selection 1104 | dim_out, # to RGB conv 1105 | 0, # RGB conv kernel selection 1106 | ]) 1107 | 1108 | self.layers.append(nn.ModuleList([ 1109 | skip_squeeze_excite, 1110 | resnet_block, 1111 | to_rgb, 1112 | self_attn, 1113 | cross_attn, 1114 | upsample, 1115 | rgb_upsample 1116 | ])) 1117 | 1118 | # determine the projection of the style embedding to convolutional modulation weights (+ adaptive kernel selection weights) for all layers 1119 | 1120 | self.style_to_conv_modulations = nn.Linear(style_network_dim, sum(style_embed_split_dims)) 1121 | self.style_embed_split_dims = style_embed_split_dims 1122 | 1123 | self.apply(self.init_) 1124 | nn.init.normal_(self.init_block, std = 0.02) 1125 | 1126 | def init_(self, m): 1127 | if type(m) in {nn.Conv2d, nn.Linear}: 1128 | nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu') 1129 | 1130 | @property 1131 | def total_params(self): 1132 | return sum([p.numel() for p in self.parameters() if p.requires_grad]) 1133 | 1134 | @property 1135 | def device(self): 1136 | return next(self.parameters()).device 1137 | 1138 | @beartype 1139 | def forward( 1140 | self, 1141 | styles = None, 1142 | noise = None, 1143 | texts: List[str] | None = None, 1144 | text_encodings: Tensor | None = None, 1145 | global_text_tokens = None, 1146 | fine_text_tokens = None, 1147 | text_mask = None, 1148 | batch_size = 1, 1149 | return_all_rgbs = False 1150 | ): 1151 | # take care of text encodings 1152 | # which requires global text tokens to adaptively select the kernels from the main contribution in the paper 1153 | # and fine text tokens to attend to using cross attention 1154 | 1155 | if not self.unconditional: 1156 | if exists(texts) or exists(text_encodings): 1157 | assert exists(texts) ^ exists(text_encodings), 'either raw texts as List[str] or text_encodings (from clip) as Tensor is passed in, but not both' 1158 | assert exists(self.text_encoder) 1159 | 1160 | if exists(texts): 1161 | text_encoder_kwargs = dict(texts = texts) 1162 | elif 
exists(text_encodings): 1163 | text_encoder_kwargs = dict(text_encodings = text_encodings) 1164 | 1165 | global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(**text_encoder_kwargs) 1166 | else: 1167 | assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))]), 'raw text or text embeddings were not passed in for conditional training' 1168 | else: 1169 | assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))]) 1170 | 1171 | # determine styles 1172 | 1173 | if not exists(styles): 1174 | assert exists(self.style_network) 1175 | 1176 | if not exists(noise): 1177 | noise = torch.randn((batch_size, self.style_network_dim), device = self.device) 1178 | 1179 | styles = self.style_network(noise, global_text_tokens) 1180 | 1181 | # project styles to conv modulations 1182 | 1183 | conv_mods = self.style_to_conv_modulations(styles) 1184 | conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1) 1185 | conv_mods = iter(conv_mods) 1186 | 1187 | # prepare initial block 1188 | 1189 | batch_size = styles.shape[0] 1190 | 1191 | x = repeat(self.init_block, 'c h w -> b c h w', b = batch_size) 1192 | x = self.init_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods)) 1193 | 1194 | rgb = torch.zeros((batch_size, self.channels, 4, 4), device = self.device, dtype = x.dtype) 1195 | 1196 | # skip layer squeeze excitations 1197 | 1198 | excitations = [None] * self.num_skip_layers_excite 1199 | 1200 | # all the rgb's of each layer of the generator is to be saved for multi-resolution input discrimination 1201 | 1202 | rgbs = [] 1203 | 1204 | # main network 1205 | 1206 | for squeeze_excite, (resnet_conv1, noise1, act1, resnet_conv2, noise2, act2), to_rgb_conv, self_attn, cross_attn, upsample, upsample_rgb in self.layers: 1207 | 1208 | if exists(upsample): 1209 | x = upsample(x) 1210 | 1211 | if exists(squeeze_excite): 1212 | skip_excite = squeeze_excite(x) 1213 | excitations.append(skip_excite) 1214 | 1215 | excite = safe_unshift(excitations) 1216 | if exists(excite): 1217 | x = x * excite 1218 | 1219 | x = resnet_conv1(x, mod = next(conv_mods), kernel_mod = next(conv_mods)) 1220 | x = noise1(x) 1221 | x = act1(x) 1222 | 1223 | x = resnet_conv2(x, mod = next(conv_mods), kernel_mod = next(conv_mods)) 1224 | x = noise2(x) 1225 | x = act2(x) 1226 | 1227 | if exists(self_attn): 1228 | x = self_attn(x) 1229 | 1230 | if exists(cross_attn): 1231 | x = cross_attn(x, context = fine_text_tokens, mask = text_mask) 1232 | 1233 | layer_rgb = to_rgb_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods)) 1234 | 1235 | rgb = rgb + layer_rgb 1236 | 1237 | rgbs.append(rgb) 1238 | 1239 | if exists(upsample_rgb): 1240 | rgb = upsample_rgb(rgb) 1241 | 1242 | # sanity check 1243 | 1244 | assert is_empty([*conv_mods]), 'convolutions were incorrectly modulated' 1245 | 1246 | if return_all_rgbs: 1247 | return rgb, rgbs 1248 | 1249 | return rgb 1250 | 1251 | # discriminator 1252 | 1253 | @beartype 1254 | class SimpleDecoder(nn.Module): 1255 | def __init__( 1256 | self, 1257 | dim, 1258 | *, 1259 | dims: Tuple[int, ...], 1260 | patch_dim: int = 1, 1261 | frac_patches: float = 1., 1262 | dropout: float = 0.5 1263 | ): 1264 | super().__init__() 1265 | assert 0 < frac_patches <= 1. 
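        # worked example (parameter values here are assumptions for illustration only): with
        # patch_dim = 2 and frac_patches = 0.25, forward below cuts the feature map and the target
        # image into patch_dim ** 2 = 4 aligned patches and reconstructs only
        # max(int(0.25 * 4), 1) = 1 randomly selected patch per sample, scoring it with mse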
1266 | 1267 | self.patch_dim = patch_dim 1268 | self.frac_patches = frac_patches 1269 | 1270 | self.dropout = nn.Dropout(dropout) 1271 | 1272 | dims = [dim, *dims] 1273 | 1274 | layers = [conv2d_3x3(dim, dim)] 1275 | 1276 | for dim_in, dim_out in zip(dims[:-1], dims[1:]): 1277 | layers.append(nn.Sequential( 1278 | Upsample(dim_in), 1279 | conv2d_3x3(dim_in, dim_out), 1280 | leaky_relu() 1281 | )) 1282 | 1283 | self.net = nn.Sequential(*layers) 1284 | 1285 | @property 1286 | def device(self): 1287 | return next(self.parameters()).device 1288 | 1289 | def forward( 1290 | self, 1291 | fmap, 1292 | orig_image 1293 | ): 1294 | fmap = self.dropout(fmap) 1295 | 1296 | if self.frac_patches < 1.: 1297 | batch, patch_dim = fmap.shape[0], self.patch_dim 1298 | fmap_size, img_size = fmap.shape[-1], orig_image.shape[-1] 1299 | 1300 | assert divisible_by(fmap_size, patch_dim), f'feature map dimensions are {fmap_size}, but the patch dim was designated to be {patch_dim}' 1301 | assert divisible_by(img_size, patch_dim), f'image size is {img_size} but the patch dim was specified to be {patch_dim}' 1302 | 1303 | fmap, orig_image = map(lambda t: rearrange(t, 'b c (p1 h) (p2 w) -> b (p1 p2) c h w', p1 = patch_dim, p2 = patch_dim), (fmap, orig_image)) 1304 | 1305 | total_patches = patch_dim ** 2 1306 | num_patches_recon = max(int(self.frac_patches * total_patches), 1) 1307 | 1308 | batch_arange = torch.arange(batch, device = self.device)[..., None] 1309 | batch_randperm = torch.randn((batch, total_patches)).sort(dim = -1).indices 1310 | patch_indices = batch_randperm[..., :num_patches_recon] 1311 | 1312 | fmap, orig_image = map(lambda t: t[batch_arange, patch_indices], (fmap, orig_image)) 1313 | fmap, orig_image = map(lambda t: rearrange(t, 'b p ... -> (b p) ...'), (fmap, orig_image)) 1314 | 1315 | recon = self.net(fmap) 1316 | return F.mse_loss(recon, orig_image) 1317 | 1318 | class RandomFixedProjection(nn.Module): 1319 | def __init__( 1320 | self, 1321 | dim, 1322 | dim_out, 1323 | channel_first = True 1324 | ): 1325 | super().__init__() 1326 | weights = torch.randn(dim, dim_out) 1327 | nn.init.kaiming_normal_(weights, mode = 'fan_out', nonlinearity = 'linear') 1328 | 1329 | self.channel_first = channel_first 1330 | self.register_buffer('fixed_weights', weights) 1331 | 1332 | def forward(self, x): 1333 | if not self.channel_first: 1334 | return x @ self.fixed_weights 1335 | 1336 | return einsum('b c ..., c d -> b d ...', x, self.fixed_weights) 1337 | 1338 | class VisionAidedDiscriminator(nn.Module): 1339 | """ the vision-aided gan loss """ 1340 | 1341 | @beartype 1342 | def __init__( 1343 | self, 1344 | *, 1345 | depth = 2, 1346 | dim_head = 64, 1347 | heads = 8, 1348 | clip: OpenClipAdapter | None = None, 1349 | layer_indices = (-1, -2, -3), 1350 | conv_dim = None, 1351 | text_dim = None, 1352 | unconditional = False, 1353 | num_conv_kernels = 2 1354 | ): 1355 | super().__init__() 1356 | 1357 | if not exists(clip): 1358 | clip = OpenClipAdapter() 1359 | 1360 | self.clip = clip 1361 | dim = clip._dim_image_latent 1362 | 1363 | self.unconditional = unconditional 1364 | text_dim = default(text_dim, dim) 1365 | conv_dim = default(conv_dim, dim) 1366 | 1367 | self.layer_discriminators = nn.ModuleList([]) 1368 | self.layer_indices = layer_indices 1369 | 1370 | conv_klass = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels) if not unconditional else conv2d_3x3 1371 | 1372 | for _ in layer_indices: 1373 | self.layer_discriminators.append(nn.ModuleList([ 1374 | RandomFixedProjection(dim, 
conv_dim), 1375 | conv_klass(conv_dim, conv_dim), 1376 | nn.Linear(text_dim, conv_dim) if not unconditional else None, 1377 | nn.Linear(text_dim, num_conv_kernels) if not unconditional else None, 1378 | nn.Sequential( 1379 | conv2d_3x3(conv_dim, 1), 1380 | Rearrange('b 1 ... -> b ...') 1381 | ) 1382 | ])) 1383 | 1384 | def parameters(self): 1385 | return self.layer_discriminators.parameters() 1386 | 1387 | @property 1388 | def total_params(self): 1389 | return sum([p.numel() for p in self.parameters()]) 1390 | 1391 | @beartype 1392 | def forward( 1393 | self, 1394 | images, 1395 | texts: List[str] | None = None, 1396 | text_embeds: Tensor | None = None, 1397 | return_clip_encodings = False 1398 | ): 1399 | 1400 | assert self.unconditional or (exists(text_embeds) ^ exists(texts)) 1401 | 1402 | with torch.no_grad(): 1403 | if not self.unconditional and exists(texts): 1404 | self.clip.eval() 1405 | text_embeds, _ = self.clip.embed_texts(texts) # call clip to get the pooled text embedding used for conditioning below 1406 | 1407 | _, image_encodings = self.clip.embed_images(images) 1408 | 1409 | logits = [] 1410 | 1411 | for layer_index, (rand_proj, conv, to_conv_mod, to_conv_kernel_mod, to_logits) in zip(self.layer_indices, self.layer_discriminators): 1412 | image_encoding = image_encodings[layer_index] 1413 | 1414 | cls_token, rest_tokens = image_encoding[:, :1], image_encoding[:, 1:] 1415 | height_width = int(sqrt(rest_tokens.shape[-2])) # assume square 1416 | 1417 | img_fmap = rearrange(rest_tokens, 'b (h w) d -> b d h w', h = height_width) 1418 | 1419 | img_fmap = img_fmap + rearrange(cls_token, 'b 1 d -> b d 1 1 ') # pool the cls token into the rest of the tokens 1420 | 1421 | img_fmap = rand_proj(img_fmap) 1422 | 1423 | if self.unconditional: 1424 | img_fmap = conv(img_fmap) 1425 | else: 1426 | assert exists(text_embeds) 1427 | 1428 | img_fmap = conv( 1429 | img_fmap, 1430 | mod = to_conv_mod(text_embeds), 1431 | kernel_mod = to_conv_kernel_mod(text_embeds) 1432 | ) 1433 | 1434 | layer_logits = to_logits(img_fmap) 1435 | 1436 | logits.append(layer_logits) 1437 | 1438 | if not return_clip_encodings: 1439 | return logits 1440 | 1441 | return logits, image_encodings 1442 | 1443 | class Predictor(nn.Module): 1444 | def __init__( 1445 | self, 1446 | dim, 1447 | depth = 4, 1448 | num_conv_kernels = 2, 1449 | unconditional = False 1450 | ): 1451 | super().__init__() 1452 | self.unconditional = unconditional 1453 | self.residual_fn = nn.Conv2d(dim, dim, 1) 1454 | self.residual_scale = 2 ** -0.5 1455 | 1456 | self.layers = nn.ModuleList([]) 1457 | 1458 | klass = nn.Conv2d if unconditional else partial(AdaptiveConv2DMod, num_conv_kernels = num_conv_kernels) 1459 | klass_kwargs = dict(padding = 1) if unconditional else dict() 1460 | 1461 | for ind in range(depth): 1462 | self.layers.append(nn.ModuleList([ 1463 | klass(dim, dim, 3, **klass_kwargs), 1464 | leaky_relu(), 1465 | klass(dim, dim, 3, **klass_kwargs), 1466 | leaky_relu() 1467 | ])) 1468 | 1469 | self.to_logits = nn.Conv2d(dim, 1, 1) 1470 | 1471 | def forward( 1472 | self, 1473 | x, 1474 | mod = None, 1475 | kernel_mod = None 1476 | ): 1477 | residual = self.residual_fn(x) 1478 | 1479 | kwargs = dict() 1480 | 1481 | if not self.unconditional: 1482 | kwargs = dict(mod = mod, kernel_mod = kernel_mod) 1483 | 1484 | for conv1, activation, conv2, activation in self.layers: 1485 | 1486 | inner_residual = x 1487 | 1488 | x = conv1(x, **kwargs) 1489 | x = activation(x) 1490 | x = conv2(x, **kwargs) 1491 | x = activation(x) 1492 | 1493 | x = x + inner_residual 1494 | x = x * self.residual_scale 1495 | 1496 | x = x + residual 1497 
| return self.to_logits(x) 1498 | 1499 | class Discriminator(nn.Module): 1500 | @beartype 1501 | def __init__( 1502 | self, 1503 | *, 1504 | dim_capacity = 16, 1505 | image_size, 1506 | dim_max = 2048, 1507 | channels = 3, 1508 | attn_resolutions: Tuple[int, ...] = (32, 16), 1509 | attn_dim_head = 64, 1510 | attn_heads = 8, 1511 | self_attn_dot_product = False, 1512 | ff_mult = 4, 1513 | text_encoder: TextEncoder | Dict | None = None, 1514 | text_dim = None, 1515 | filter_input_resolutions: bool = True, 1516 | multiscale_input_resolutions: Tuple[int, ...] = (64, 32, 16, 8), 1517 | multiscale_output_skip_stages: int = 1, 1518 | aux_recon_resolutions: Tuple[int, ...] = (8,), 1519 | aux_recon_patch_dims: Tuple[int, ...] = (2,), 1520 | aux_recon_frac_patches: Tuple[float, ...] = (0.25,), 1521 | aux_recon_fmap_dropout: float = 0.5, 1522 | resize_mode = 'bilinear', 1523 | num_conv_kernels = 2, 1524 | num_skip_layers_excite = 0, 1525 | unconditional = False, 1526 | predictor_depth = 2 1527 | ): 1528 | super().__init__() 1529 | self.unconditional = unconditional 1530 | assert not (unconditional and exists(text_encoder)) 1531 | 1532 | assert is_power_of_two(image_size) 1533 | assert all([*map(is_power_of_two, attn_resolutions)]) 1534 | 1535 | if filter_input_resolutions: 1536 | multiscale_input_resolutions = [*filter(lambda t: t < image_size, multiscale_input_resolutions)] 1537 | 1538 | assert is_unique(multiscale_input_resolutions) 1539 | assert all([*map(is_power_of_two, multiscale_input_resolutions)]) 1540 | assert all([*map(lambda t: t < image_size, multiscale_input_resolutions)]) 1541 | 1542 | self.multiscale_input_resolutions = multiscale_input_resolutions 1543 | 1544 | assert multiscale_output_skip_stages > 0 1545 | multiscale_output_resolutions = [resolution // (2 ** multiscale_output_skip_stages) for resolution in multiscale_input_resolutions] 1546 | 1547 | assert all([*map(lambda t: t >= 4, multiscale_output_resolutions)]) 1548 | 1549 | assert all([*map(lambda t: t < image_size, multiscale_output_resolutions)]) 1550 | 1551 | if len(multiscale_input_resolutions) > 0 and len(multiscale_output_resolutions) > 0: 1552 | assert max(multiscale_input_resolutions) > max(multiscale_output_resolutions) 1553 | assert min(multiscale_input_resolutions) > min(multiscale_output_resolutions) 1554 | 1555 | self.multiscale_output_resolutions = multiscale_output_resolutions 1556 | 1557 | assert all([*map(is_power_of_two, aux_recon_resolutions)]) 1558 | assert len(aux_recon_resolutions) == len(aux_recon_patch_dims) == len(aux_recon_frac_patches) 1559 | 1560 | self.aux_recon_resolutions_to_patches = {resolution: (patch_dim, frac_patches) for resolution, patch_dim, frac_patches in zip(aux_recon_resolutions, aux_recon_patch_dims, aux_recon_frac_patches)} 1561 | 1562 | self.resize_mode = resize_mode 1563 | 1564 | num_layers = int(log2(image_size) - 1) 1565 | self.num_layers = num_layers 1566 | self.image_size = image_size 1567 | 1568 | resolutions = image_size / ((2 ** torch.arange(num_layers))) 1569 | resolutions = resolutions.long().tolist() 1570 | 1571 | dim_layers = (2 ** (torch.arange(num_layers) + 1)) * dim_capacity 1572 | dim_layers = F.pad(dim_layers, (1, 0), value = channels) 1573 | dim_layers.clamp_(max = dim_max) 1574 | 1575 | dim_layers = dim_layers.tolist() 1576 | dim_last = dim_layers[-1] 1577 | dim_pairs = list(zip(dim_layers[:-1], dim_layers[1:])) 1578 | 1579 | self.num_skip_layers_excite = num_skip_layers_excite 1580 | 1581 | self.residual_scale = 2 ** -0.5 1582 | self.layers = nn.ModuleList([]) 
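        # shape sketch (assuming image_size = 256, dim_capacity = 16, dim_max = 2048, channels = 3
        # purely for illustration): num_layers = 7, the resolutions computed above are
        # [256, 128, 64, 32, 16, 8, 4], and dim_layers pads to [3, 32, 64, 128, 256, 512, 1024, 2048],
        # so the loop below builds dim_pairs (3, 32), (32, 64), ..., (1024, 2048),
        # each paired with its resolution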
1583 | 1584 | upsample_dims = [] 1585 | predictor_dims = [] 1586 | dim_kernel_attn = (num_conv_kernels if num_conv_kernels > 1 else 0) 1587 | 1588 | for ind, ((dim_in, dim_out), resolution) in enumerate(zip(dim_pairs, resolutions)): 1589 | is_first = ind == 0 1590 | is_last = (ind + 1) == len(dim_pairs) 1591 | should_downsample = not is_last 1592 | should_skip_layer_excite = not is_first and num_skip_layers_excite > 0 and (ind + num_skip_layers_excite) < len(dim_pairs) 1593 | 1594 | has_attn = resolution in attn_resolutions 1595 | has_multiscale_output = resolution in multiscale_output_resolutions 1596 | 1597 | has_aux_recon_decoder = resolution in aux_recon_resolutions 1598 | upsample_dims.insert(0, dim_in) 1599 | 1600 | skip_squeeze_excite = None 1601 | if should_skip_layer_excite: 1602 | dim_skip_in, _ = dim_pairs[ind + num_skip_layers_excite] 1603 | skip_squeeze_excite = SqueezeExcite(dim_in, dim_skip_in) 1604 | 1605 | # multi-scale rgb input to feature dimension 1606 | 1607 | from_rgb = nn.Conv2d(channels, dim_in, 7, padding = 3) 1608 | 1609 | # residual convolution 1610 | 1611 | residual_conv = nn.Conv2d(dim_in, dim_out, 1, stride = (2 if should_downsample else 1)) 1612 | 1613 | # main resnet block 1614 | 1615 | resnet_block = nn.Sequential( 1616 | conv2d_3x3(dim_in, dim_out), 1617 | leaky_relu(), 1618 | conv2d_3x3(dim_out, dim_out), 1619 | leaky_relu() 1620 | ) 1621 | 1622 | # multi-scale output 1623 | 1624 | multiscale_output_predictor = None 1625 | 1626 | if has_multiscale_output: 1627 | multiscale_output_predictor = Predictor(dim_out, num_conv_kernels = num_conv_kernels, depth = 2, unconditional = unconditional) 1628 | predictor_dims.extend([dim_out, dim_kernel_attn]) 1629 | 1630 | aux_recon_decoder = None 1631 | 1632 | if has_aux_recon_decoder: 1633 | patch_dim, frac_patches = self.aux_recon_resolutions_to_patches[resolution] 1634 | 1635 | aux_recon_decoder = SimpleDecoder( 1636 | dim_out, 1637 | dims = tuple(upsample_dims), 1638 | patch_dim = patch_dim, 1639 | frac_patches = frac_patches, 1640 | dropout = aux_recon_fmap_dropout 1641 | ) 1642 | 1643 | self.layers.append(nn.ModuleList([ 1644 | skip_squeeze_excite, 1645 | from_rgb, 1646 | resnet_block, 1647 | residual_conv, 1648 | SelfAttentionBlock(dim_out, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult, dot_product = self_attn_dot_product) if has_attn else None, 1649 | multiscale_output_predictor, 1650 | aux_recon_decoder, 1651 | Downsample(dim_out) if should_downsample else None, 1652 | ])) 1653 | 1654 | self.to_logits = nn.Sequential( 1655 | conv2d_3x3(dim_last, dim_last), 1656 | Rearrange('b c h w -> b (c h w)'), 1657 | nn.Linear(dim_last * (4 ** 2), 1), 1658 | Rearrange('b 1 -> b') 1659 | ) 1660 | 1661 | # take care of text conditioning in the multiscale predictor branches 1662 | 1663 | assert unconditional or (exists(text_dim) ^ exists(text_encoder)) 1664 | 1665 | if not unconditional: 1666 | if isinstance(text_encoder, dict): 1667 | text_encoder = TextEncoder(**text_encoder) 1668 | 1669 | self.text_dim = default(text_dim, text_encoder.dim) 1670 | 1671 | self.predictor_dims = predictor_dims 1672 | self.text_to_conv_conditioning = nn.Linear(self.text_dim, sum(predictor_dims)) if exists(self.text_dim) else None 1673 | 1674 | self.text_encoder = text_encoder 1675 | 1676 | self.apply(self.init_) 1677 | 1678 | def init_(self, m): 1679 | if type(m) in {nn.Conv2d, nn.Linear}: 1680 | nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu') 1681 | 1682 | def resize_image_to(self, 
images, resolution): 1683 | return F.interpolate(images, resolution, mode = self.resize_mode) 1684 | 1685 | def real_images_to_rgbs(self, images): 1686 | return [self.resize_image_to(images, resolution) for resolution in self.multiscale_input_resolutions] 1687 | 1688 | @property 1689 | def total_params(self): 1690 | return sum([p.numel() for p in self.parameters()]) 1691 | 1692 | @property 1693 | def device(self): 1694 | return next(self.parameters()).device 1695 | 1696 | @beartype 1697 | def forward( 1698 | self, 1699 | images, 1700 | rgbs: List[Tensor], # multi-resolution inputs (rgbs) from the generator 1701 | texts: List[str] | None = None, 1702 | text_encodings: Tensor | None = None, 1703 | text_embeds = None, 1704 | real_images = None, # if this were passed in, the network will automatically append the real to the presumably generated images passed in as the first argument, and generate all intermediate resolutions through resizing and concat appropriately 1705 | return_multiscale_outputs = True, # can force it not to return multi-scale logits 1706 | calc_aux_loss = True 1707 | ): 1708 | if not self.unconditional: 1709 | assert (exists(texts) ^ exists(text_encodings)) ^ exists(text_embeds), 'either texts as List[str] is passed in, or clip text_encodings as Tensor' 1710 | 1711 | if exists(texts): 1712 | assert exists(self.text_encoder) 1713 | text_embeds, *_ = self.text_encoder(texts = texts) 1714 | 1715 | elif exists(text_encodings): 1716 | assert exists(self.text_encoder) 1717 | text_embeds, *_ = self.text_encoder(text_encodings = text_encodings) 1718 | 1719 | assert exists(text_embeds), 'raw text or text embeddings were not passed into discriminator for conditional training' 1720 | 1721 | conv_mods = self.text_to_conv_conditioning(text_embeds).split(self.predictor_dims, dim = -1) 1722 | conv_mods = iter(conv_mods) 1723 | 1724 | else: 1725 | assert not any([*map(exists, (texts, text_embeds))]) 1726 | 1727 | x = images 1728 | 1729 | image_size = (self.image_size, self.image_size) 1730 | 1731 | assert x.shape[-2:] == image_size 1732 | 1733 | batch = x.shape[0] 1734 | 1735 | # index the rgbs by resolution 1736 | 1737 | rgbs_index = {t.shape[-1]: t for t in rgbs} if exists(rgbs) else {} 1738 | 1739 | # assert that the necessary resolutions are there 1740 | 1741 | assert is_empty(set(self.multiscale_input_resolutions) - set(rgbs_index.keys())), f'rgbs of necessary resolution {self.multiscale_input_resolutions} were not passed in' 1742 | 1743 | # hold multiscale outputs 1744 | 1745 | multiscale_outputs = [] 1746 | 1747 | # hold auxiliary recon losses 1748 | 1749 | aux_recon_losses = [] 1750 | 1751 | # excitations 1752 | 1753 | excitations = [None] * (self.num_skip_layers_excite + 1) # +1 since first image in pixel space is not excited 1754 | 1755 | for squeeze_excite, from_rgb, block, residual_fn, attn, predictor, recon_decoder, downsample in self.layers: 1756 | resolution = x.shape[-1] 1757 | 1758 | if exists(squeeze_excite): 1759 | skip_excite = squeeze_excite(x) 1760 | excitations.append(skip_excite) 1761 | 1762 | excite = safe_unshift(excitations) 1763 | 1764 | if exists(excite): 1765 | excite = repeat(excite, 'b ... 
-> (s b) ...', s = x.shape[0] // excite.shape[0]) 1766 | x = x * excite 1767 | 1768 | batch_prev_stage = x.shape[0] 1769 | has_multiscale_input = resolution in self.multiscale_input_resolutions 1770 | 1771 | if has_multiscale_input: 1772 | rgb = rgbs_index.get(resolution, None) 1773 | 1774 | # multi-scale input features 1775 | 1776 | multi_scale_input_feats = from_rgb(rgb) 1777 | 1778 | # expand multi-scale input features, as could include extra scales from previous stage 1779 | 1780 | multi_scale_input_feats = repeat(multi_scale_input_feats, 'b ... -> (s b) ...', s = x.shape[0] // rgb.shape[0]) 1781 | 1782 | # add the multi-scale input features to the current hidden state from main stem 1783 | 1784 | x = x + multi_scale_input_feats 1785 | 1786 | # and also concat for scale invariance 1787 | 1788 | x = torch.cat((x, multi_scale_input_feats), dim = 0) 1789 | 1790 | residual = residual_fn(x) 1791 | x = block(x) 1792 | 1793 | if exists(attn): 1794 | x = attn(x) 1795 | 1796 | if exists(predictor): 1797 | pred_kwargs = dict() 1798 | if not self.unconditional: 1799 | pred_kwargs = dict(mod = next(conv_mods), kernel_mod = next(conv_mods)) 1800 | 1801 | if return_multiscale_outputs: 1802 | predictor_input = x[:batch_prev_stage] 1803 | multiscale_outputs.append(predictor(predictor_input, **pred_kwargs)) 1804 | 1805 | if exists(downsample): 1806 | x = downsample(x) 1807 | 1808 | x = x + residual 1809 | x = x * self.residual_scale 1810 | 1811 | if exists(recon_decoder) and calc_aux_loss: 1812 | 1813 | recon_output = x[:batch_prev_stage] 1814 | recon_output = rearrange(x, '(s b) ... -> s b ...', b = batch) 1815 | 1816 | aux_recon_target = images 1817 | 1818 | # only use the input real images for aux recon 1819 | 1820 | recon_output = recon_output[0] 1821 | 1822 | # only reconstruct a fraction of images across batch and scale 1823 | # for efficiency 1824 | 1825 | aux_recon_loss = recon_decoder(recon_output, aux_recon_target) 1826 | aux_recon_losses.append(aux_recon_loss) 1827 | 1828 | # sanity check 1829 | 1830 | assert self.unconditional or is_empty([*conv_mods]), 'convolutions were incorrectly modulated' 1831 | 1832 | # to logits 1833 | 1834 | logits = self.to_logits(x) 1835 | logits = rearrange(logits, '(s b) ... 
-> s b ...', b = batch) 1836 | 1837 | return logits, multiscale_outputs, aux_recon_losses 1838 | 1839 | # gan 1840 | 1841 | TrainDiscrLosses = namedtuple('TrainDiscrLosses', [ 1842 | 'divergence', 1843 | 'multiscale_divergence', 1844 | 'vision_aided_divergence', 1845 | 'total_matching_aware_loss', 1846 | 'gradient_penalty', 1847 | 'aux_reconstruction' 1848 | ]) 1849 | 1850 | TrainGenLosses = namedtuple('TrainGenLosses', [ 1851 | 'divergence', 1852 | 'multiscale_divergence', 1853 | 'total_vd_divergence', 1854 | 'contrastive_loss' 1855 | ]) 1856 | 1857 | class GigaGAN(nn.Module): 1858 | @beartype 1859 | def __init__( 1860 | self, 1861 | *, 1862 | generator: BaseGenerator | Dict, 1863 | discriminator: Discriminator | Dict, 1864 | vision_aided_discriminator: VisionAidedDiscriminator | Dict | None = None, 1865 | diff_augment: DiffAugment | Dict | None = None, 1866 | learning_rate = 2e-4, 1867 | betas = (0.5, 0.9), 1868 | weight_decay = 0., 1869 | discr_aux_recon_loss_weight = 1., 1870 | multiscale_divergence_loss_weight = 0.1, 1871 | vision_aided_divergence_loss_weight = 0.5, 1872 | generator_contrastive_loss_weight = 0.1, 1873 | matching_awareness_loss_weight = 0.1, 1874 | calc_multiscale_loss_every = 1, 1875 | apply_gradient_penalty_every = 4, 1876 | resize_image_mode = 'bilinear', 1877 | train_upsampler = False, 1878 | log_steps_every = 20, 1879 | create_ema_generator_at_init = True, 1880 | save_and_sample_every = 1000, 1881 | early_save_thres_steps = 2500, 1882 | early_save_and_sample_every = 100, 1883 | num_samples = 25, 1884 | model_folder = './gigagan-models', 1885 | results_folder = './gigagan-results', 1886 | sample_upsampler_dl: DataLoader | None = None, 1887 | accelerator: Accelerator | None = None, 1888 | accelerate_kwargs: dict = {}, 1889 | find_unused_parameters = True, 1890 | amp = False, 1891 | mixed_precision_type = 'fp16' 1892 | ): 1893 | super().__init__() 1894 | 1895 | # create accelerator 1896 | 1897 | if accelerator: 1898 | self.accelerator = accelerator 1899 | assert is_empty(accelerate_kwargs) 1900 | else: 1901 | kwargs = DistributedDataParallelKwargs(find_unused_parameters = find_unused_parameters) 1902 | 1903 | self.accelerator = Accelerator( 1904 | kwargs_handlers = [kwargs], 1905 | mixed_precision = mixed_precision_type if amp else 'no', 1906 | **accelerate_kwargs 1907 | ) 1908 | 1909 | # whether to train upsampler or not 1910 | 1911 | self.train_upsampler = train_upsampler 1912 | 1913 | if train_upsampler: 1914 | from gigagan_pytorch.unet_upsampler import UnetUpsampler 1915 | generator_klass = UnetUpsampler 1916 | else: 1917 | generator_klass = Generator 1918 | 1919 | # gradient penalty and auxiliary recon loss 1920 | 1921 | self.apply_gradient_penalty_every = apply_gradient_penalty_every 1922 | self.calc_multiscale_loss_every = calc_multiscale_loss_every 1923 | 1924 | if isinstance(generator, dict): 1925 | generator = generator_klass(**generator) 1926 | 1927 | if isinstance(discriminator, dict): 1928 | discriminator = Discriminator(**discriminator) 1929 | 1930 | if exists(vision_aided_discriminator) and isinstance(vision_aided_discriminator, dict): 1931 | vision_aided_discriminator = VisionAidedDiscriminator(**vision_aided_discriminator) 1932 | 1933 | assert isinstance(generator, generator_klass) 1934 | 1935 | # diff augment 1936 | 1937 | if isinstance(diff_augment, dict): 1938 | diff_augment = DiffAugment(**diff_augment) 1939 | 1940 | self.diff_augment = diff_augment 1941 | 1942 | # use _base to designate unwrapped models 1943 | 1944 | self.G = generator 1945 | 
self.D = discriminator 1946 | self.VD = vision_aided_discriminator 1947 | 1948 | # validate multiscale input resolutions 1949 | 1950 | if train_upsampler: 1951 | assert is_empty(set(discriminator.multiscale_input_resolutions) - set(generator.allowable_rgb_resolutions)), f'only multiscale input resolutions of {generator.allowable_rgb_resolutions} are allowed based on the unet input and output image size. simply do Discriminator(multiscale_input_resolutions = unet.allowable_rgb_resolutions) to resolve this error' 1952 | 1953 | # ema 1954 | 1955 | self.has_ema_generator = False 1956 | 1957 | if self.is_main and create_ema_generator_at_init: 1958 | self.create_ema_generator() 1959 | 1960 | # print number of parameters 1961 | 1962 | self.print('\n') 1963 | 1964 | self.print(f'Generator: {numerize.numerize(generator.total_params)}') 1965 | self.print(f'Discriminator: {numerize.numerize(discriminator.total_params)}') 1966 | 1967 | if exists(self.VD): 1968 | self.print(f'Vision Discriminator: {numerize.numerize(vision_aided_discriminator.total_params)}') 1969 | 1970 | self.print('\n') 1971 | 1972 | # text encoder 1973 | 1974 | assert generator.unconditional == discriminator.unconditional 1975 | assert not exists(vision_aided_discriminator) or vision_aided_discriminator.unconditional == generator.unconditional 1976 | 1977 | self.unconditional = generator.unconditional 1978 | 1979 | # optimizers 1980 | 1981 | self.G_opt = get_optimizer(self.G.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay) 1982 | self.D_opt = get_optimizer(self.D.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay) 1983 | 1984 | # prepare for distributed 1985 | 1986 | self.G, self.D, self.G_opt, self.D_opt = self.accelerator.prepare(self.G, self.D, self.G_opt, self.D_opt) 1987 | 1988 | # vision aided discriminator optimizer 1989 | 1990 | if exists(self.VD): 1991 | self.VD_opt = get_optimizer(self.VD.parameters(), lr = learning_rate, betas = betas, weight_decay = weight_decay) 1992 | self.VD_opt = self.accelerator.prepare(self.VD_opt) 1993 | 1994 | # loss related 1995 | 1996 | self.discr_aux_recon_loss_weight = discr_aux_recon_loss_weight 1997 | self.multiscale_divergence_loss_weight = multiscale_divergence_loss_weight 1998 | self.vision_aided_divergence_loss_weight = vision_aided_divergence_loss_weight 1999 | self.generator_contrastive_loss_weight = generator_contrastive_loss_weight 2000 | self.matching_awareness_loss_weight = matching_awareness_loss_weight 2001 | 2002 | # resize image mode 2003 | 2004 | self.resize_image_mode = resize_image_mode 2005 | 2006 | # steps 2007 | 2008 | self.log_steps_every = log_steps_every 2009 | 2010 | self.register_buffer('steps', torch.ones(1, dtype = torch.long)) 2011 | 2012 | # save and sample 2013 | 2014 | self.save_and_sample_every = save_and_sample_every 2015 | self.early_save_thres_steps = early_save_thres_steps 2016 | self.early_save_and_sample_every = early_save_and_sample_every 2017 | 2018 | self.num_samples = num_samples 2019 | 2020 | self.train_dl = None 2021 | 2022 | self.sample_upsampler_dl_iter = None 2023 | if exists(sample_upsampler_dl): 2024 | self.sample_upsampler_dl_iter = cycle(sample_upsampler_dl) 2025 | 2026 | self.results_folder = Path(results_folder) 2027 | self.model_folder = Path(model_folder) 2028 | 2029 | mkdir_if_not_exists(self.results_folder) 2030 | mkdir_if_not_exists(self.model_folder) 2031 | 2032 | def save(self, path, overwrite = True): 2033 | path = Path(path) 2034 | mkdir_if_not_exists(path.parents[0])
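# the checkpoint written below is a single dict holding the unwrapped G / D weights, their optimizer (and AMP grad scaler) states, the current step count and the package version, so that load() can resume training from exactly this point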
2035 | 2036 | assert overwrite or not path.exists() 2037 | 2038 | pkg = dict( 2039 | G = self.unwrapped_G.state_dict(), 2040 | D = self.unwrapped_D.state_dict(), 2041 | G_opt = self.G_opt.state_dict(), 2042 | D_opt = self.D_opt.state_dict(), 2043 | steps = self.steps.item(), 2044 | version = __version__ 2045 | ) 2046 | 2047 | if exists(self.G_opt.scaler): 2048 | pkg['G_scaler'] = self.G_opt.scaler.state_dict() 2049 | 2050 | if exists(self.D_opt.scaler): 2051 | pkg['D_scaler'] = self.D_opt.scaler.state_dict() 2052 | 2053 | if exists(self.VD): 2054 | pkg['VD'] = self.unwrapped_VD.state_dict() 2055 | pkg['VD_opt'] = self.VD_opt.state_dict() 2056 | 2057 | if exists(self.VD_opt.scaler): 2058 | pkg['VD_scaler'] = self.VD_opt.scaler.state_dict() 2059 | 2060 | if self.has_ema_generator: 2061 | pkg['G_ema'] = self.G_ema.state_dict() 2062 | 2063 | torch.save(pkg, str(path)) 2064 | 2065 | def load(self, path, strict = False): 2066 | path = Path(path) 2067 | assert path.exists() 2068 | 2069 | pkg = torch.load(str(path)) 2070 | 2071 | if 'version' in pkg and pkg['version'] != __version__: 2072 | print(f"trying to load from version {pkg['version']}") 2073 | 2074 | self.unwrapped_G.load_state_dict(pkg['G'], strict = strict) 2075 | self.unwrapped_D.load_state_dict(pkg['D'], strict = strict) 2076 | 2077 | if exists(self.VD): 2078 | self.unwrapped_VD.load_state_dict(pkg['VD'], strict = strict) 2079 | 2080 | if self.has_ema_generator: 2081 | self.G_ema.load_state_dict(pkg['G_ema']) 2082 | 2083 | if 'steps' in pkg: 2084 | self.steps.copy_(torch.tensor([pkg['steps']])) 2085 | 2086 | if 'G_opt' not in pkg or 'D_opt' not in pkg: 2087 | return 2088 | 2089 | try: 2090 | self.G_opt.load_state_dict(pkg['G_opt']) 2091 | self.D_opt.load_state_dict(pkg['D_opt']) 2092 | 2093 | if exists(self.VD): 2094 | self.VD_opt.load_state_dict(pkg['VD_opt']) 2095 | 2096 | if 'G_scaler' in pkg and exists(self.G_opt.scaler): 2097 | self.G_opt.scaler.load_state_dict(pkg['G_scaler']) 2098 | 2099 | if 'D_scaler' in pkg and exists(self.D_opt.scaler): 2100 | self.D_opt.scaler.load_state_dict(pkg['D_scaler']) 2101 | 2102 | if 'VD_scaler' in pkg and exists(self.VD_opt.scaler): 2103 | self.VD_opt.scaler.load_state_dict(pkg['VD_scaler']) 2104 | 2105 | except Exception as e: 2106 | self.print(f'unable to load optimizers {e} - optimizer states will be reset') 2107 | pass 2108 | 2109 | # accelerate related 2110 | 2111 | @property 2112 | def device(self): 2113 | return self.accelerator.device 2114 | 2115 | @property 2116 | def unwrapped_G(self): 2117 | return self.accelerator.unwrap_model(self.G) 2118 | 2119 | @property 2120 | def unwrapped_D(self): 2121 | return self.accelerator.unwrap_model(self.D) 2122 | 2123 | @property 2124 | def unwrapped_VD(self): 2125 | return self.accelerator.unwrap_model(self.VD) 2126 | 2127 | @property 2128 | def need_vision_aided_discriminator(self): 2129 | return exists(self.VD) and self.vision_aided_divergence_loss_weight > 0. 2130 | 2131 | @property 2132 | def need_contrastive_loss(self): 2133 | return self.generator_contrastive_loss_weight > 0.
and not self.unconditional 2134 | 2135 | def print(self, msg): 2136 | self.accelerator.print(msg) 2137 | 2138 | @property 2139 | def is_distributed(self): 2140 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 2141 | 2142 | @property 2143 | def is_main(self): 2144 | return self.accelerator.is_main_process 2145 | 2146 | @property 2147 | def is_local_main(self): 2148 | return self.accelerator.is_local_main_process 2149 | 2150 | def resize_image_to(self, images, resolution): 2151 | return F.interpolate(images, resolution, mode = self.resize_image_mode) 2152 | 2153 | @beartype 2154 | def set_dataloader(self, dl: DataLoader): 2155 | assert not exists(self.train_dl), 'training dataloader has already been set' 2156 | 2157 | self.train_dl = dl 2158 | self.train_dl_batch_size = dl.batch_size 2159 | 2160 | self.train_dl = self.accelerator.prepare(self.train_dl) 2161 | 2162 | # generate function 2163 | 2164 | @torch.inference_mode() 2165 | def generate(self, *args, **kwargs): 2166 | model = self.G_ema if self.has_ema_generator else self.G 2167 | model.eval() 2168 | return model(*args, **kwargs) 2169 | 2170 | # create EMA generator 2171 | 2172 | def create_ema_generator( 2173 | self, 2174 | update_every = 10, 2175 | update_after_step = 100, 2176 | decay = 0.995 2177 | ): 2178 | if not self.is_main: 2179 | return 2180 | 2181 | assert not self.has_ema_generator, 'EMA generator has already been created' 2182 | 2183 | self.G_ema = EMA(self.unwrapped_G, update_every = update_every, update_after_step = update_after_step, beta = decay) 2184 | self.has_ema_generator = True 2185 | 2186 | def generate_kwargs(self, dl_iter, batch_size): 2187 | # what to pass into the generator 2188 | # depends on whether training upsampler or not 2189 | 2190 | maybe_text_kwargs = dict() 2191 | if self.train_upsampler or not self.unconditional: 2192 | assert exists(dl_iter) 2193 | 2194 | if self.unconditional: 2195 | real_images = next(dl_iter) 2196 | else: 2197 | result = next(dl_iter) 2198 | assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])' 2199 | real_images, texts = result 2200 | 2201 | maybe_text_kwargs['texts'] = texts[:batch_size] 2202 | 2203 | real_images = real_images.to(self.device) 2204 | 2205 | # if training upsample generator, need to downsample real images 2206 | 2207 | if self.train_upsampler: 2208 | size = self.unwrapped_G.input_image_size 2209 | lowres_real_images = F.interpolate(real_images, (size, size)) 2210 | 2211 | G_kwargs = dict(lowres_image = lowres_real_images) 2212 | else: 2213 | assert exists(batch_size) 2214 | 2215 | G_kwargs = dict(batch_size = batch_size) 2216 | 2217 | # create noise 2218 | 2219 | noise = torch.randn(batch_size, self.unwrapped_G.style_network.dim, device = self.device) 2220 | 2221 | G_kwargs.update(noise = noise) 2222 | 2223 | return G_kwargs, maybe_text_kwargs 2224 | 2225 | @beartype 2226 | def train_discriminator_step( 2227 | self, 2228 | dl_iter: Iterable, 2229 | grad_accum_every = 1, 2230 | apply_gradient_penalty = False, 2231 | calc_multiscale_loss = True 2232 | ): 2233 | total_divergence = 0. 2234 | total_vision_aided_divergence = 0. 2235 | 2236 | total_gp_loss = 0. 2237 | total_aux_loss = 0. 2238 | 2239 | total_multiscale_divergence = 0. if calc_multiscale_loss else None 2240 | 2241 | has_matching_awareness = not self.unconditional and self.matching_awareness_loss_weight > 0. 2242 | 2243 | total_matching_aware_loss = 0. 
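# the lists below accumulate the fake images / rgbs, real images and texts seen across the grad_accum_every micro-batches, so the matching awareness loss can later be computed over the same data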
2244 | 2245 | all_texts = [] 2246 | all_fake_images = [] 2247 | all_fake_rgbs = [] 2248 | all_real_images = [] 2249 | 2250 | self.G.train() 2251 | 2252 | self.D.train() 2253 | self.D_opt.zero_grad() 2254 | 2255 | if self.need_vision_aided_discriminator: 2256 | self.VD.train() 2257 | self.VD_opt.zero_grad() 2258 | 2259 | for _ in range(grad_accum_every): 2260 | 2261 | if self.unconditional: 2262 | real_images = next(dl_iter) 2263 | else: 2264 | result = next(dl_iter) 2265 | assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])' 2266 | real_images, texts = result 2267 | 2268 | all_real_images.append(real_images) 2269 | all_texts.extend(texts) 2270 | 2271 | # requires grad for real images, for gradient penalty 2272 | 2273 | real_images = real_images.to(self.device) 2274 | real_images.requires_grad_() 2275 | 2276 | real_images_rgbs = self.unwrapped_D.real_images_to_rgbs(real_images) 2277 | 2278 | # diff augment real images 2279 | 2280 | if exists(self.diff_augment): 2281 | real_images, real_images_rgbs = self.diff_augment(real_images, real_images_rgbs) 2282 | 2283 | # batch size 2284 | 2285 | batch_size = real_images.shape[0] 2286 | 2287 | # for discriminator training, fit upsampler and image synthesis logic under same function 2288 | 2289 | G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size) 2290 | 2291 | # generator 2292 | 2293 | with torch.no_grad(), self.accelerator.autocast(): 2294 | images, rgbs = self.G( 2295 | **G_kwargs, 2296 | **maybe_text_kwargs, 2297 | return_all_rgbs = True 2298 | ) 2299 | 2300 | all_fake_images.append(images) 2301 | all_fake_rgbs.append(rgbs) 2302 | 2303 | # diff augment 2304 | 2305 | if exists(self.diff_augment): 2306 | images, rgbs = self.diff_augment(images, rgbs) 2307 | 2308 | # detach output of generator, as training discriminator only 2309 | 2310 | images.detach_() 2311 | 2312 | for rgb in rgbs: 2313 | rgb.detach_() 2314 | 2315 | # main divergence loss 2316 | 2317 | with self.accelerator.autocast(): 2318 | 2319 | fake_logits, fake_multiscale_logits, _ = self.D( 2320 | images, 2321 | rgbs, 2322 | **maybe_text_kwargs, 2323 | return_multiscale_outputs = calc_multiscale_loss, 2324 | calc_aux_loss = False 2325 | ) 2326 | 2327 | real_logits, real_multiscale_logits, aux_recon_losses = self.D( 2328 | real_images, 2329 | real_images_rgbs, 2330 | **maybe_text_kwargs, 2331 | return_multiscale_outputs = calc_multiscale_loss, 2332 | calc_aux_loss = True 2333 | ) 2334 | 2335 | divergence = discriminator_hinge_loss(real_logits, fake_logits) 2336 | total_divergence += (divergence.item() / grad_accum_every) 2337 | 2338 | # handle multi-scale divergence 2339 | 2340 | multiscale_divergence = 0. 2341 | 2342 | if self.multiscale_divergence_loss_weight > 0. and len(fake_multiscale_logits) > 0: 2343 | 2344 | for multiscale_fake, multiscale_real in zip(fake_multiscale_logits, real_multiscale_logits): 2345 | multiscale_loss = discriminator_hinge_loss(multiscale_real, multiscale_fake) 2346 | multiscale_divergence = multiscale_divergence + multiscale_loss 2347 | 2348 | total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every) 2349 | 2350 | # figure out gradient penalty if needed 2351 | 2352 | gp_loss = 0. 
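# the gradient penalty is taken with respect to the real images, using the main real logits plus the multiscale real logits (the latter weighted by the multiscale loss weight); it is only computed on steps where apply_gradient_penalty is set, and the optimizer's grad scaler is passed along, presumably so gradients can be unscaled correctly under mixed precision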
2353 | 2354 | if apply_gradient_penalty: 2355 | gp_loss = gradient_penalty( 2356 | real_images, 2357 | outputs = [real_logits, *real_multiscale_logits], 2358 | grad_output_weights = [1., *(self.multiscale_divergence_loss_weight,) * len(real_multiscale_logits)], 2359 | scaler = self.D_opt.scaler 2360 | ) 2361 | 2362 | if not torch.isnan(gp_loss): 2363 | total_gp_loss += (gp_loss.item() / grad_accum_every) 2364 | 2365 | # handle vision aided discriminator, if needed 2366 | 2367 | vd_loss = 0. 2368 | 2369 | if self.need_vision_aided_discriminator: 2370 | 2371 | fake_vision_aided_logits = self.VD(images, **maybe_text_kwargs) 2372 | real_vision_aided_logits, clip_encodings = self.VD(real_images, return_clip_encodings = True, **maybe_text_kwargs) 2373 | 2374 | for fake_logits, real_logits in zip(fake_vision_aided_logits, real_vision_aided_logits): 2375 | vd_loss = vd_loss + discriminator_hinge_loss(real_logits, fake_logits) 2376 | 2377 | total_vision_aided_divergence += (vd_loss.item() / grad_accum_every) 2378 | 2379 | # handle gradient penalty for vision aided discriminator 2380 | 2381 | if apply_gradient_penalty: 2382 | 2383 | vd_gp_loss = gradient_penalty( 2384 | clip_encodings, 2385 | outputs = real_vision_aided_logits, 2386 | grad_output_weights = [self.vision_aided_divergence_loss_weight] * len(real_vision_aided_logits), 2387 | scaler = self.VD_opt.scaler 2388 | ) 2389 | 2390 | if not torch.isnan(vd_gp_loss): 2391 | gp_loss = gp_loss + vd_gp_loss 2392 | 2393 | total_gp_loss += (vd_gp_loss.item() / grad_accum_every) 2394 | 2395 | # sum up losses 2396 | 2397 | total_loss = divergence + gp_loss 2398 | 2399 | if self.multiscale_divergence_loss_weight > 0.: 2400 | total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight 2401 | 2402 | if self.vision_aided_divergence_loss_weight > 0.: 2403 | total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight 2404 | 2405 | if self.discr_aux_recon_loss_weight > 0.: 2406 | aux_loss = sum(aux_recon_losses) 2407 | 2408 | total_aux_loss += (aux_loss.item() / grad_accum_every) 2409 | 2410 | total_loss = total_loss + aux_loss * self.discr_aux_recon_loss_weight 2411 | 2412 | # backwards 2413 | 2414 | self.accelerator.backward(total_loss / grad_accum_every) 2415 | 2416 | 2417 | # matching awareness loss 2418 | # strategy is to rotate the texts by one and assume the batch is shuffled enough that the rotated captions are mismatched 2419 | 2420 | if has_matching_awareness: 2421 | 2422 | # rotate texts 2423 | 2424 | all_texts = [*all_texts[1:], all_texts[0]] 2425 | all_texts = group_by_num_consecutive(all_texts, batch_size) 2426 | 2427 | zipped_data = zip( 2428 | all_fake_images, 2429 | all_fake_rgbs, 2430 | all_real_images, 2431 | all_texts 2432 | ) 2433 | 2434 | total_loss = 0.
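# each accumulated micro-batch is re-scored below against the rotated (now mismatched) texts; aux_matching_loss should push the discriminator to reject image-text pairs that do not match, in the spirit of GigaGAN's matching-aware loss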
2435 | 2436 | for fake_images, fake_rgbs, real_images, texts in zipped_data: 2437 | 2438 | with self.accelerator.autocast(): 2439 | fake_logits, *_ = self.D( 2440 | fake_images, 2441 | fake_rgbs, 2442 | texts = texts, 2443 | return_multiscale_outputs = False, 2444 | calc_aux_loss = False 2445 | ) 2446 | 2447 | real_images_rgbs = self.unwrapped_D.real_images_to_rgbs(real_images) 2448 | 2449 | real_logits, *_ = self.D( 2450 | real_images, 2451 | real_images_rgbs, 2452 | texts = texts, 2453 | return_multiscale_outputs = False, 2454 | calc_aux_loss = False 2455 | ) 2456 | 2457 | matching_loss = aux_matching_loss(real_logits, fake_logits) 2458 | 2459 | total_matching_aware_loss += (matching_loss.item() / grad_accum_every) 2460 | 2461 | loss = matching_loss * self.matching_awareness_loss_weight 2462 | 2463 | self.accelerator.backward(loss / grad_accum_every) 2464 | 2465 | self.D_opt.step() 2466 | 2467 | if self.need_vision_aided_discriminator: 2468 | self.VD_opt.step() 2469 | 2470 | return TrainDiscrLosses( 2471 | total_divergence, 2472 | total_multiscale_divergence, 2473 | total_vision_aided_divergence, 2474 | total_matching_aware_loss, 2475 | total_gp_loss, 2476 | total_aux_loss 2477 | ) 2478 | 2479 | def train_generator_step( 2480 | self, 2481 | batch_size = None, 2482 | dl_iter: Iterable | None = None, 2483 | grad_accum_every = 1, 2484 | calc_multiscale_loss = True 2485 | ): 2486 | total_divergence = 0. 2487 | total_multiscale_divergence = 0. if calc_multiscale_loss else None 2488 | total_vd_divergence = 0. 2489 | contrastive_loss = 0. 2490 | 2491 | self.G.train() 2492 | self.D.train() 2493 | 2494 | self.D_opt.zero_grad() 2495 | self.G_opt.zero_grad() 2496 | 2497 | all_images = [] 2498 | all_texts = [] 2499 | 2500 | for _ in range(grad_accum_every): 2501 | 2502 | # generator 2503 | 2504 | G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size) 2505 | 2506 | with self.accelerator.autocast(): 2507 | images, rgbs = self.G( 2508 | **G_kwargs, 2509 | **maybe_text_kwargs, 2510 | return_all_rgbs = True 2511 | ) 2512 | 2513 | # diff augment 2514 | 2515 | if exists(self.diff_augment): 2516 | images, rgbs = self.diff_augment(images, rgbs) 2517 | 2518 | # accumulate all images and texts for the optional contrastive loss 2519 | 2520 | if self.need_contrastive_loss: 2521 | all_images.append(images) 2522 | all_texts.extend(maybe_text_kwargs['texts']) 2523 | 2524 | # discriminator 2525 | 2526 | logits, multiscale_logits, _ = self.D( 2527 | images, 2528 | rgbs, 2529 | **maybe_text_kwargs, 2530 | return_multiscale_outputs = calc_multiscale_loss, 2531 | calc_aux_loss = False 2532 | ) 2533 | 2534 | # generator hinge loss from the main discriminator and the multiscale outputs 2535 | 2536 | divergence = generator_hinge_loss(logits) 2537 | 2538 | total_divergence += (divergence.item() / grad_accum_every) 2539 | 2540 | total_loss = divergence 2541 | 2542 | if self.multiscale_divergence_loss_weight > 0. and len(multiscale_logits) > 0: 2543 | multiscale_divergence = 0. 2544 | 2545 | for multiscale_logit in multiscale_logits: 2546 | multiscale_divergence = multiscale_divergence + generator_hinge_loss(multiscale_logit) 2547 | 2548 | total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every) 2549 | 2550 | total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight 2551 | 2552 | # vision aided generator hinge loss 2553 | 2554 | if self.need_vision_aided_discriminator: 2555 | vd_loss = 0.
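# the vision aided discriminator returns one logit tensor per head; each gets its own generator hinge loss, and the summed loss is weighted by vision_aided_divergence_loss_weight before being added to the total generator loss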
2556 | 2557 | logits = self.VD(images, **maybe_text_kwargs) 2558 | 2559 | for logit in logits: 2560 | vd_loss = vd_loss + generator_hinge_loss(logit) 2561 | 2562 | total_vd_divergence += (vd_loss.item() / grad_accum_every) 2563 | 2564 | total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight 2565 | 2566 | self.accelerator.backward(total_loss / grad_accum_every, retain_graph = self.need_contrastive_loss) 2567 | 2568 | # if needs the generator contrastive loss 2569 | # gather up all images and texts and calculate it 2570 | 2571 | if self.need_contrastive_loss: 2572 | all_images = torch.cat(all_images, dim = 0) 2573 | 2574 | contrastive_loss = aux_clip_loss( 2575 | clip = self.G.text_encoder.clip, 2576 | texts = all_texts, 2577 | images = all_images 2578 | ) 2579 | 2580 | self.accelerator.backward(contrastive_loss * self.generator_contrastive_loss_weight) 2581 | 2582 | # generator optimizer step 2583 | 2584 | self.G_opt.step() 2585 | 2586 | # update exponentially moving averaged generator 2587 | 2588 | self.accelerator.wait_for_everyone() 2589 | 2590 | if self.is_main and self.has_ema_generator: 2591 | self.G_ema.update() 2592 | 2593 | return TrainGenLosses( 2594 | total_divergence, 2595 | total_multiscale_divergence, 2596 | total_vd_divergence, 2597 | contrastive_loss 2598 | ) 2599 | 2600 | def sample(self, model, dl_iter, batch_size): 2601 | G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size) 2602 | 2603 | with self.accelerator.autocast(): 2604 | generator_output = model(**G_kwargs, **maybe_text_kwargs) 2605 | 2606 | if not self.train_upsampler: 2607 | return generator_output 2608 | 2609 | output_size = generator_output.shape[-1] 2610 | lowres_image = G_kwargs['lowres_image'] 2611 | lowres_image = F.interpolate(lowres_image, (output_size, output_size)) 2612 | 2613 | return torch.cat([lowres_image, generator_output]) 2614 | 2615 | @torch.inference_mode() 2616 | def save_sample( 2617 | self, 2618 | batch_size, 2619 | dl_iter = None 2620 | ): 2621 | milestone = self.steps.item() // self.save_and_sample_every 2622 | nrow_mult = 2 if self.train_upsampler else 1 2623 | batches = num_to_groups(self.num_samples, batch_size) 2624 | 2625 | if self.train_upsampler: 2626 | dl_iter = default(self.sample_upsampler_dl_iter, dl_iter) 2627 | 2628 | assert exists(dl_iter) 2629 | 2630 | sample_models_and_output_file_name = [(self.unwrapped_G, f'sample-{milestone}.png')] 2631 | 2632 | if self.has_ema_generator: 2633 | sample_models_and_output_file_name.append((self.G_ema, f'ema-sample-{milestone}.png')) 2634 | 2635 | for model, filename in sample_models_and_output_file_name: 2636 | model.eval() 2637 | 2638 | all_images_list = list(map(lambda n: self.sample(model, dl_iter, n), batches)) 2639 | all_images = torch.cat(all_images_list, dim = 0) 2640 | 2641 | all_images.clamp_(0., 1.) 2642 | 2643 | utils.save_image( 2644 | all_images, 2645 | str(self.results_folder / filename), 2646 | nrow = int(sqrt(self.num_samples)) * nrow_mult 2647 | ) 2648 | 2649 | # Possible to do: Include some metric to save if improved, include some sampler dict text entries 2650 | self.save(str(self.model_folder / f'model-{milestone}.ckpt')) 2651 | 2652 | @beartype 2653 | def forward( 2654 | self, 2655 | *, 2656 | steps, 2657 | grad_accum_every = 1 2658 | ): 2659 | assert exists(self.train_dl), 'you need to set the dataloader by running .set_dataloader(dl: Dataloader)' 2660 | 2661 | batch_size = self.train_dl_batch_size 2662 | dl_iter = cycle(self.train_dl) 2663 | 2664 | last_gp_loss = 0. 
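# last_gp_loss / last_multiscale_d_loss / last_multiscale_g_loss cache the most recent values of losses that are only computed every few steps, so the periodic log line always has a recent value to print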
2665 | last_multiscale_d_loss = 0. 2666 | last_multiscale_g_loss = 0. 2667 | 2668 | for _ in tqdm(range(steps), initial = self.steps.item()): 2669 | steps = self.steps.item() 2670 | is_first_step = steps == 1 2671 | 2672 | apply_gradient_penalty = self.apply_gradient_penalty_every > 0 and divisible_by(steps, self.apply_gradient_penalty_every) 2673 | calc_multiscale_loss = self.calc_multiscale_loss_every > 0 and divisible_by(steps, self.calc_multiscale_loss_every) 2674 | 2675 | ( 2676 | d_loss, 2677 | multiscale_d_loss, 2678 | vision_aided_d_loss, 2679 | matching_aware_loss, 2680 | gp_loss, 2681 | recon_loss 2682 | ) = self.train_discriminator_step( 2683 | dl_iter = dl_iter, 2684 | grad_accum_every = grad_accum_every, 2685 | apply_gradient_penalty = apply_gradient_penalty, 2686 | calc_multiscale_loss = calc_multiscale_loss 2687 | ) 2688 | 2689 | self.accelerator.wait_for_everyone() 2690 | 2691 | ( 2692 | g_loss, 2693 | multiscale_g_loss, 2694 | vision_aided_g_loss, 2695 | contrastive_loss 2696 | ) = self.train_generator_step( 2697 | dl_iter = dl_iter, 2698 | batch_size = batch_size, 2699 | grad_accum_every = grad_accum_every, 2700 | calc_multiscale_loss = calc_multiscale_loss 2701 | ) 2702 | 2703 | if exists(gp_loss): 2704 | last_gp_loss = gp_loss 2705 | 2706 | if exists(multiscale_d_loss): 2707 | last_multiscale_d_loss = multiscale_d_loss 2708 | 2709 | if exists(multiscale_g_loss): 2710 | last_multiscale_g_loss = multiscale_g_loss 2711 | 2712 | if is_first_step or divisible_by(steps, self.log_steps_every): 2713 | 2714 | losses = ( 2715 | ('G', g_loss), 2716 | ('MSG', last_multiscale_g_loss), 2717 | ('VG', vision_aided_g_loss), 2718 | ('D', d_loss), 2719 | ('MSD', last_multiscale_d_loss), 2720 | ('VD', vision_aided_d_loss), 2721 | ('GP', last_gp_loss), 2722 | ('SSL', recon_loss), 2723 | ('CL', contrastive_loss), 2724 | ('MAL', matching_aware_loss) 2725 | ) 2726 | 2727 | losses_str = ' | '.join([f'{loss_name}: {loss:.2f}' for loss_name, loss in losses]) 2728 | 2729 | self.print(losses_str) 2730 | 2731 | self.accelerator.wait_for_everyone() 2732 | 2733 | if self.is_main and (is_first_step or divisible_by(steps, self.save_and_sample_every) or (steps <= self.early_save_thres_steps and divisible_by(steps, self.early_save_and_sample_every))): 2734 | self.save_sample(batch_size, dl_iter) 2735 | 2736 | self.steps += 1 2737 | 2738 | self.print(f'complete {steps} training steps') 2739 | --------------------------------------------------------------------------------
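For orientation, here is a minimal, hypothetical sketch of how the trainer defined above is driven end to end. The GigaGAN constructor, set_dataloader, forward and generate calls follow the signatures in this file; the exact keyword arguments accepted by Generator, Discriminator and ImageDataset are defined earlier in the package, so the ones shown here are illustrative assumptions to be checked against those definitions.

import torch
from torch.utils.data import DataLoader
from gigagan_pytorch import GigaGAN, ImageDataset

# generator / discriminator may be passed as plain dicts; GigaGAN.__init__
# turns them into Generator / Discriminator instances (or UnetUpsampler
# when train_upsampler = True)
gan = GigaGAN(
    generator = dict(
        image_size = 256,                          # assumed kwargs - check Generator.__init__
        unconditional = True,
        style_network = dict(dim = 64, depth = 4)
    ),
    discriminator = dict(
        image_size = 256,                          # assumed kwargs - check Discriminator.__init__
        unconditional = True
    ),
    amp = True
)

# dataset and dataloader; ImageDataset's arguments are likewise assumptions
dataset = ImageDataset(folder = '/path/to/images', image_size = 256)
dl = DataLoader(dataset, batch_size = 4, shuffle = True, drop_last = True)

gan.set_dataloader(dl)

# alternate discriminator / generator updates, accumulating gradients
# over 4 micro-batches per update
gan(steps = 100, grad_accum_every = 4)

# sample from the (EMA) generator - assumes the unconditional generator
# can sample its own noise given just a batch size
images = gan.generate(batch_size = 4)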