├── .gitignore ├── LICENSE.txt ├── README.md ├── examples ├── cifar10.py └── stl10.py ├── setup.py └── simple_diffusion_model ├── __init__.py ├── diffusion_wrapper.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Charles Foster 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simple-diffusion-model 2 | Pedagogical codebase for a simplified score-based generative model design, with training loop 3 | -------------------------------------------------------------------------------- /examples/cifar10.py: -------------------------------------------------------------------------------- 1 | # This code was adapted from lucidrains existing `x-transformers` repository. 
from simple_diffusion_model import Model
from simple_diffusion_model import DiffusionWrapper

import tqdm
import time
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np

import torch_fidelity
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
EVALUATE = False
EVALUATE_EVERY = 100000
EVALUATE_BATCH_SIZE = 50

# helpers


def cycle(loader):
    """Yield batches from *loader* forever, restarting it whenever it is exhausted."""
    while True:
        for data in loader:
            yield data


def scale(x):
    """Map values from [0, 1] (as produced by ``ToTensor``) to [-1, 1]."""
    return x * 2 - 1


def rescale(x):
    """Map values from [-1, 1] back to [0, 1]; the inverse of ``scale``."""
    return (x + 1) / 2


class FidelityWrapper(nn.Module):
    """Adapt a ``DiffusionWrapper`` to the generator-module interface of torch-fidelity.

    Only ``len(z)`` of the incoming latent batch is used: the diffusion sampler
    draws its own starting noise. Output is a uint8 tensor with values in [0, 255].
    """

    def __init__(self, generator):
        super().__init__()
        self.generator = generator

    def forward(self, z):
        out = self.generator.generate(len(z))
        return rescale(out).mul(255).round().clamp(0, 255).to(torch.uint8)


def train():
    """Train a ``DiffusionWrapper`` on CIFAR-10, logging losses and samples to wandb.

    Each outer iteration accumulates gradients over GRADIENT_ACCUMULATE_EVERY
    micro-batches of BATCH_SIZE images, clips the gradient norm to 0.5 and takes
    one Adam step. Validation, sampling and (optionally) FID evaluation run
    periodically according to the module-level constants.
    """
    wandb.init(project="simple-diffusion-model")

    model = DiffusionWrapper(Model(), input_shape=(3, 32, 32))
    model.cuda()

    train_dataset = CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
    val_dataset = CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)
    # Shuffle the training set; without it every epoch replays the same fixed order.
    train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True))
    val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

    # optimizer
    # Named `optimizer` so it does not shadow the imported `torch.optim as optim` module.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # training

    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        start_time = time.time()
        model.train()

        train_loss = 0.0
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            batch, _ = next(train_loader)
            loss = model(scale(batch))
            # Average (rather than sum) the micro-batch gradients, so the step matches
            # a single batch of BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY samples.
            (loss / GRADIENT_ACCUMULATE_EVERY).backward()
            train_loss += loss.item() / GRADIENT_ACCUMULATE_EVERY

        print(f'training loss: {train_loss}')
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()
        # Measured after the optimizer step so step_time covers the whole update.
        end_time = time.time()

        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                batch, _ = next(val_loader)
                loss = model(scale(batch))
                print(f'validation loss: {loss.item()}')
                val_loss = loss.item()

        if i % GENERATE_EVERY == 0:
            model.eval()
            samples = model.generate(1)
            image_array = rescale(samples)
            images = wandb.Image(image_array, caption="Generated")
            wandb.log({"examples": images}, commit=False)

        wandb.log({
            'iter': i,
            'step_time': end_time - start_time,
            'train_loss': train_loss,
            # Refreshed only every VALIDATE_EVERY iterations; the last value is re-logged.
            'val_loss': val_loss,
        })

        if EVALUATE:
            if (i % EVALUATE_EVERY == 0 and i != 0) or i == NUM_BATCHES - 1:
                model.eval()
                with torch.no_grad():
                    wrapped_inner = FidelityWrapper(model)
                    wrapped = torch_fidelity.GenerativeModelModuleWrapper(wrapped_inner,
                                                                          1, 'normal', 0)
                    metrics = torch_fidelity.calculate_metrics(input1=wrapped,
                                                               input1_model_num_samples=10000,
                                                               input2='cifar10-train',
                                                               batch_size=EVALUATE_BATCH_SIZE,
                                                               fid=True,
                                                               verbose=True)
                    wandb.log({"fid": metrics['frechet_inception_distance']})

    wandb.finish()


if __name__ == '__main__':
    train()

# --------------------------------------------------------------------------------
# /examples/stl10.py:
# --------------------------------------------------------------------------------
# This code was adapted from lucidrains existing `x-transformers` repository.
from simple_diffusion_model import Model
from simple_diffusion_model import DiffusionWrapper

import tqdm
import time
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np

from torchvision.datasets import STL10
from torch.utils.data import DataLoader

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500

# helpers


def cycle(loader):
    """Yield batches from *loader* forever, restarting it whenever it is exhausted."""
    while True:
        for data in loader:
            yield data


def scale(x):
    """Map values from [0, 1] (as produced by ``ToTensor``) to [-1, 1]."""
    return x * 2 - 1


def rescale(x):
    """Map values from [-1, 1] back to [0, 1]; the inverse of ``scale``."""
    return (x + 1) / 2


def train():
    """Train a ``DiffusionWrapper`` on STL-10 (96x96), logging losses and samples to wandb.

    Each outer iteration accumulates gradients over GRADIENT_ACCUMULATE_EVERY
    micro-batches of BATCH_SIZE images, clips the gradient norm to 0.5 and takes
    one Adam step. Validation and sampling run periodically.
    """
    wandb.init(project="simple-diffusion-model")

    model = DiffusionWrapper(Model(), input_shape=(3, 96, 96))
    model.cuda()

    train_dataset = STL10(root='./data', split='train', transform=transforms.ToTensor(), download=True)
    val_dataset = STL10(root='./data', split='test', transform=transforms.ToTensor(), download=True)
    # Shuffle the training set; without it every epoch replays the same fixed order.
    train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True))
    val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

    # optimizer
    # Named `optimizer` so it does not shadow the imported `torch.optim as optim` module.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # training

    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        start_time = time.time()
        model.train()

        train_loss = 0.0
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            batch, _ = next(train_loader)
            loss = model(scale(batch))
            # Average (rather than sum) the micro-batch gradients, so the step matches
            # a single batch of BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY samples.
            (loss / GRADIENT_ACCUMULATE_EVERY).backward()
            train_loss += loss.item() / GRADIENT_ACCUMULATE_EVERY

        print(f'training loss: {train_loss}')
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()
        # Measured after the optimizer step so step_time covers the whole update.
        end_time = time.time()

        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                batch, _ = next(val_loader)
                loss = model(scale(batch))
                print(f'validation loss: {loss.item()}')
                val_loss = loss.item()

        if i % GENERATE_EVERY == 0:
            model.eval()
            samples = model.generate(1)
            image_array = rescale(samples)
            images = wandb.Image(image_array, caption="Generated")
            wandb.log({"examples": images}, commit=False)

        wandb.log({
            'iter': i,
            'step_time': end_time - start_time,
            'train_loss': train_loss,
            # Refreshed only every VALIDATE_EVERY iterations; the last value is re-logged.
            'val_loss': val_loss,
        })

    wandb.finish()


if __name__ == '__main__':
    train()

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name = 'simple-diffusion-model',
    packages = find_packages(exclude=['examples']),
    version = '0.0.1',
    license='BSD 3-Clause',
    description = 'Simple Diffusion Model',
    author = 'Charles Foster',
    author_email = 'cfoster0@alumni.stanford.edu',
    url = 'https://github.com/cfoster0/simple-diffusion-model',
    keywords = [
        'artificial intelligence',
    ],
    install_requires=[
        'einops',
        'numpy',
        'torch',
        'torch-fidelity',
        'torchvision',
        'tqdm',  # imported by the example training scripts
        'wandb'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        # Must agree with `license` above and with LICENSE.txt (BSD 3-Clause), not MIT.
        'License :: OSI Approved :: BSD License',
        'Programming Language :: Python :: 3.6',
    ],
)

# --------------------------------------------------------------------------------
# /simple_diffusion_model/__init__.py:
# --------------------------------------------------------------------------------
from simple_diffusion_model.model import Model
from simple_diffusion_model.diffusion_wrapper import DiffusionWrapper

# --------------------------------------------------------------------------------
# /simple_diffusion_model/diffusion_wrapper.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np

from torch import einsum
from torch.nn import Module


def beta_schedule(timesteps):
    """Return a linear beta (noise-variance) schedule of length ``timesteps`` as float32.

    Values increase linearly from 1e-4 to 0.02.
    """
    return np.linspace(1e-4, 0.02, timesteps).astype('float32')


class DiffusionWrapper(Module):
    """Wrap a noise-prediction network with diffusion training loss and ancestral sampling.

    Args:
        net: module called as ``net(x, timestep, *args, **kwargs)``; it is trained to
            predict the Gaussian noise mixed into ``x`` at integer step ``timestep``.
        input_shape: per-sample shape (no batch dimension) used when sampling.
        timesteps: number of diffusion steps.
    """

    def __init__(self, net, input_shape, timesteps=1000):
        super().__init__()
        self.net = net
        self.input_shape = input_shape
        self.timesteps = timesteps
        # Compute the schedule once and derive alpha / alpha_bar from it.
        betas = beta_schedule(timesteps)
        alphas = 1.0 - betas
        self.register_buffer('beta_schedule', torch.from_numpy(betas))
        self.register_buffer('alpha_schedule', torch.from_numpy(alphas))
        # alpha_bar[t] = prod_{s <= t} alpha[s]: remaining signal fraction after step t.
        self.register_buffer('alpha_bar_schedule', torch.from_numpy(np.cumprod(alphas)))

    @torch.no_grad()
    def generate(self, n, *args, **kwargs):
        """Draw ``n`` samples by running the reverse chain from pure Gaussian noise.

        Extra ``args``/``kwargs`` are forwarded to ``net`` at every step. ``net``
        is switched to eval mode while sampling, and its previous train/eval mode
        is restored afterwards, even if sampling raises.

        Returns:
            Tensor of shape ``(n, *input_shape)`` on the buffers' device.
        """
        was_training = self.net.training
        self.net.eval()
        device = self.beta_schedule.device
        # tuple() lets a list input_shape work as well as a tuple.
        shape = (n,) + tuple(self.input_shape)
        try:
            x = torch.randn(shape, device=device)
            for t in reversed(range(self.timesteps)):
                timestep = torch.full((n,), t, device=device, dtype=torch.long)
                alpha = self.alpha_schedule[t]
                alpha_bar = self.alpha_bar_schedule[t]
                predicted_noise = self.net(x, timestep, *args, **kwargs)
                # Mean of the reverse step: (x - (1 - alpha) / sqrt(1 - alpha_bar) * eps) / sqrt(alpha)
                x = (alpha ** -0.5) * (x - ((1.0 - alpha) * (1.0 - alpha_bar) ** -0.5) * predicted_noise)
                if t > 0:
                    # Add noise with variance beta_t on every step except the last.
                    z = torch.randn(shape, device=device)
                    x += (self.beta_schedule[t] ** 0.5) * z
        finally:
            self.net.train(was_training)
        return x

    def forward(self, x, *args, **kwargs):
        """Return the noise-prediction MSE loss for a batch ``x``.

        One step ``t`` is drawn uniformly per sample. ``x`` is noised to
        ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``, and the loss is
        the MSE between ``net``'s prediction and ``noise``. ``x`` is moved to the
        buffers' device first.
        """
        device = self.beta_schedule.device
        x = x.to(device)
        noise = torch.randn(x.shape, device=device)
        timestep = torch.randint(0, self.timesteps, (x.shape[0],), device=device)
        alpha_bar = torch.gather(self.alpha_bar_schedule, 0, timestep)
        # Per-sample scalars broadcast over all non-batch dimensions.
        noised = einsum("b , b ... -> b ...", alpha_bar ** 0.5, x) + einsum("b , b ... -> b ...", (1.0 - alpha_bar) ** 0.5, noise)
        predicted_noise = self.net(noised, timestep, *args, **kwargs)
        loss = F.mse_loss(predicted_noise, noise)
        return loss


# --------------------------------------------------------------------------------
# /simple_diffusion_model/model.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum
from einops import rearrange, reduce, repeat
from typing import Sequence, Tuple, Callable
from torch.nn import Module, ModuleList, Linear, LayerNorm, GroupNorm, Conv2d


class ConditionNHWC(Module):
    """Rotate channel pairs of a channels-last tensor by angles derived from ``condition``.

    The last dimension (size ``out_features``, which must be even) is split into
    two halves. Element k of the first half and element k of the second half are
    rotated as a 2-D pair by angle ``condition * inv_freq[k]``.
    """

    def __init__(self, out_features):
        super().__init__()
        inv_freq = 1. / torch.logspace(-5, 5, out_features // 2)
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, condition):
        # condition: (b,) -> angles (b, d/2); cast so an integer timestep multiplies cleanly.
        freqs = torch.outer(condition.to(self.inv_freq.dtype), self.inv_freq)  # c = d / 2
        # Duplicate the angles so both halves of the channel dimension share them.
        posemb = repeat(freqs, "b c -> b (2 c)")
        first, second = rearrange(x, '... (j c) -> ... j c', j = 2).unbind(dim = -2)
        rotated = torch.cat((-second, first), dim = -1)
        return einsum("b ... d , b ... d -> b ... d", x, posemb.cos()) + einsum("b ... d , b ... d -> b ... d", rotated, posemb.sin())


class ConditionNCHW(Module):
    """Channels-first counterpart of ``ConditionNHWC``: rotate dim-1 channel pairs by ``condition``."""

    def __init__(self, out_features):
        super().__init__()
        inv_freq = 1. / torch.logspace(-5, 5, out_features // 2)
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, condition):
        # condition: (b,) -> angles (b, d/2); cast so an integer timestep multiplies cleanly.
        freqs = torch.outer(condition.to(self.inv_freq.dtype), self.inv_freq)  # c = d / 2
        posemb = repeat(freqs, "b c -> b (2 c)")
        first, second = rearrange(x, 'b (j c) ... -> b j c ...', j = 2).unbind(dim = 1)
        rotated = torch.cat((-second, first), dim = 1)
        return einsum("b d ... , b d ... -> b d ...", x, posemb.cos()) + einsum("b d ... , b d ... -> b d ...", rotated, posemb.sin())


class SelfAttention(Module):
    """Multi-head self-attention over all spatial positions of a ``(b, h, w, d)`` tensor.

    ``d`` must equal ``head_dim * heads``; the output has the same shape as the input.
    """

    def __init__(self, head_dim: int, heads: int):
        super().__init__()
        hidden_dim = head_dim * heads
        self.head_dim = head_dim
        self.heads = heads
        self.hidden_dim = hidden_dim
        self.in_proj = Linear(hidden_dim, hidden_dim * 3)
        self.out_proj = Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        b, h, w, d = x.shape
        x = rearrange(x, "b h w d -> b (h w) d")
        p = self.in_proj(x)
        q, k, v = torch.split(p, [
            self.hidden_dim,
            self.hidden_dim,
            self.hidden_dim,
        ], -1)
        q, k, v = (rearrange(t, "b i (h d) -> b i h d", h=self.heads) for t in (q, k, v))
        # Scaled dot-product attention: softmax(q k^T / sqrt(head_dim)) v
        a = einsum("b i h d, b j h d -> b h i j", q, k) * (self.head_dim ** -0.5)
        a = F.softmax(a, dim=-1)
        o = einsum("b h i j, b j h d -> b i h d", a, v)
        o = rearrange(o, "b i h d -> b i (h d)")
        x = self.out_proj(o)
        x = rearrange(x, "b (h w) d -> b h w d", h=h, w=w)
        return x


class ConditionedSequential(Module):
    """Apply layers in order, passing the same extra args/kwargs (e.g. ``condition``) to each."""

    def __init__(self, *layers):
        super().__init__()
        self.layers = ModuleList(layers)

    def forward(self, x, *args, **kwargs):
        for layer in self.layers:
            x = layer(x, *args, **kwargs)
        return x


class ResidualBlock(Module):
    """1x1 projection to ``out_channels`` followed by two pre-norm residual 3x3 convolutions.

    Each residual branch computes ``conv(gelu(norm(x)))``. When ``out_channels`` is
    even, the normalized activations are first rotated by ``ConditionNCHW`` using
    ``condition``; otherwise the condition is ignored.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.condition = ConditionNCHW(out_channels) if out_channels % 2 == 0 else None
        self.layers = ModuleList([
            Conv2d(in_channels, out_channels, (1, 1)),
            Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1),
            Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1),
        ])
        # NOTE(review): one GroupNorm is shared by both residual branches — confirm intended.
        self.norm = GroupNorm(1, out_channels)

    def forward(self, x, condition):
        project, *residual_convs = self.layers
        x = project(x)
        for conv in residual_convs:
            h = F.gelu(self.norm(x))
            if self.condition is not None:
                h = self.condition(h, condition=condition)
            x = x + conv(h)
        return x


class BottleneckBlock(Module):
    """Four conditioned, pre-norm residual self-attention layers over an NCHW feature map.

    ``channels`` must be divisible by 8 (8 heads of ``channels // 8`` dims) and even.
    """

    def __init__(self, channels):
        super().__init__()
        self.condition = ConditionNHWC(channels)
        self.layers = ModuleList([SelfAttention(channels // 8, 8) for _ in range(4)])
        # NOTE(review): one LayerNorm is shared by all four attention layers — confirm intended.
        self.norm = LayerNorm(channels)

    def forward(self, x, condition):
        x = rearrange(x, "b c h w -> b h w c")
        for layer in self.layers:
            x = x + layer(self.condition(self.norm(x), condition=condition))
        x = rearrange(x, "b h w c -> b c h w")
        return x


class Bicubic(Module):
    """Resize an NCHW tensor by ``scale_factor`` with bicubic interpolation; extra args are ignored."""

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x, *args, **kwargs):
        # align_corners=False is PyTorch's default behaviour; stated explicitly to avoid the warning.
        return F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)


class UNet(Module):
    """Recursive U-Net built from (encoder, decoder) pairs, outermost pair first.

    forward: ``decoder(cat([encoder(x), inner(encoder(x))], dim=1))``, where ``inner``
    is the U-Net of the remaining pairs, or ``bottleneck`` when none remain.
    """

    def __init__(self, encoders_decoders: Sequence[Tuple[Module, Module]], bottleneck: Module):
        super().__init__()
        outer_pair, *inner_remaining = encoders_decoders
        self.encoder, self.decoder = outer_pair
        if inner_remaining:
            self.bottleneck = UNet(inner_remaining, bottleneck)
        else:
            self.bottleneck = bottleneck

    def forward(self, x, condition):
        encoded = self.encoder(x, condition=condition)
        bottlenecked = self.bottleneck(encoded, condition=condition)
        return self.decoder(torch.cat([encoded, bottlenecked], dim=1), condition=condition)


class Model(Module):
    """Four-level conditioned U-Net mapping 3-channel images to 3-channel outputs.

    Levels 2-4 halve the resolution with bicubic interpolation. H and W should
    therefore be divisible by 8 so the skip concatenations line up.
    """

    def __init__(self):
        super().__init__()
        self.net = UNet([
            (ResidualBlock(3, 64), ResidualBlock(64+64, 3)),
            (ConditionedSequential(Bicubic(1/2), ResidualBlock(64, 128)), ConditionedSequential(ResidualBlock(128+128, 64), Bicubic(2))),
            (ConditionedSequential(Bicubic(1/2), ResidualBlock(128, 256)), ConditionedSequential(ResidualBlock(256+256, 128), Bicubic(2))),
            (ConditionedSequential(Bicubic(1/2), ResidualBlock(256, 512)), ConditionedSequential(ResidualBlock(512+512, 256), Bicubic(2))),
            ], BottleneckBlock(512)
        )

    def forward(self, x, condition):
        return self.net(x, condition=condition)

# --------------------------------------------------------------------------------