├── .gitignore
├── LICENSE
├── README.md
├── setup.py
└── simple_diffusion
    ├── __init__.py
    └── modeling_uvit.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Suraj Patil

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [WiP] simple-diffusion

An implementation of [simple diffusion](https://arxiv.org/abs/2301.11093): End-to-end diffusion for high resolution images in PyTorch (and maybe JAX).

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os

import setuptools

here = os.path.abspath(os.path.dirname(__file__))

with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

setuptools.setup(
    name="simple-diffusion",
    packages=setuptools.find_packages(),
    version="0.0.1",
    license="MIT",
    description="simple diffusion: End-to-end diffusion for high resolution images in PyTorch and JAX",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Suraj Patil",
    author_email="surajp815@gmail.com",
    url="https://github.com/patil-suraj/simple-diffusion",
    keywords=[
        "artificial intelligence",
        "deep learning",
        "transformers",
        "attention mechanism",
        "text-to-image",
    ],
    install_requires=[
        "accelerate",
        "diffusers",
        "pillow",
        "sentencepiece",
        "torch>=1.6",
        "torchvision",
        "transformers",
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

--------------------------------------------------------------------------------
/simple_diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patil-suraj/simple-diffusion/ec676aba7d618e01f373587dcfe2d6cd62818dbf/simple_diffusion/__init__.py

--------------------------------------------------------------------------------
/simple_diffusion/modeling_uvit.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, hidden_states):
        return hidden_states * torch.tanh(F.softplus(hidden_states))

class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: number of channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied after interpolation.
        use_conv_transpose: a bool determining if a transposed convolution is used
            instead of nearest-neighbour interpolation.
        out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose

        if use_conv_transpose:
            self.conv = nn.ConvTranspose2d(self.channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        if self.use_conv:
            hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: number of channels in the inputs and outputs.
        use_conv: a bool determining if a strided convolution is used instead of average pooling.
        out_channels: number of output channels; defaults to `channels`.
        padding: padding for the strided convolution; with `padding=0` the input is
            asymmetrically zero-padded on the right/bottom instead.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2

        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states

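# Usage sketch (illustrative, not part of the public API): Upsample2D doubles and
# Downsample2D halves the spatial resolution; the channel count is kept unless
# `out_channels` is given.
#
#   up = Upsample2D(64, use_conv=True)
#   up(torch.randn(1, 64, 16, 16)).shape    # torch.Size([1, 64, 32, 32])
#
#   down = Downsample2D(64, use_conv=True)
#   down(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 64, 16, 16])
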
class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        # this block always normalizes before the convolutions (pre-norm)
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm: {self.time_embedding_norm}")

            self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

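# Usage sketch (illustrative): when the channel count changes, a 1x1 projection
# shortcut is added automatically.
#
#   block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#   x, temb = torch.randn(2, 64, 32, 32), torch.randn(2, 512)
#   block(x, temb).shape  # torch.Size([2, 128, 32, 32])
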
class DownBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding
            )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        output_states = ()

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

            output_states += (hidden_states,)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class UpBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        self.upsample = None
        if add_upsample:
            self.upsample = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states, upsample_size)

        return hidden_states


class MLP(nn.Module):
    def __init__(self, hidden_size, embedding_dim, transformer_dropout=0.0):
        super().__init__()
        self.transformer_dropout = transformer_dropout
        self.norm = nn.LayerNorm(hidden_size)
        self.dense1 = nn.Linear(hidden_size, hidden_size * 4)
        self.scale = nn.Linear(embedding_dim, hidden_size * 4)
        self.shift = nn.Linear(embedding_dim, hidden_size * 4)
        self.out = nn.Linear(hidden_size * 4, hidden_size)

    def forward(self, x, emb):
        # pre-norm feed-forward with adaptive scale/shift conditioning on `emb`
        x = self.norm(x)
        mlp_h = self.dense1(x)
        mlp_h = F.silu(mlp_h)
        scale = self.scale(emb)
        shift = self.shift(emb)
        mlp_h = mlp_h * (1 + scale) + shift
        if self.transformer_dropout > 0.0:
            mlp_h = F.dropout(mlp_h, p=self.transformer_dropout, training=self.training)
        out = self.out(mlp_h)
        return out

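# Usage sketch (illustrative): the conditioning embedding is broadcast across
# the token dimension.
#
#   mlp = MLP(hidden_size=256, embedding_dim=512)
#   tokens = torch.randn(2, 16, 256)  # (batch, seq, hidden)
#   emb = torch.randn(2, 1, 512)
#   mlp(tokens, emb).shape            # torch.Size([2, 16, 256])
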

class UViT(nn.Module):
    # TODO: assemble the full U-ViT backbone from the blocks above (repo is WiP).
    pass

--------------------------------------------------------------------------------
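The encoder/decoder blocks above are self-contained, so they can be smoke-tested before the `UViT` assembly lands. A minimal sketch (shapes and channel counts below are illustrative assumptions, not values prescribed by the repo):

```python
import torch

from simple_diffusion.modeling_uvit import DownBlock2D, UpBlock2D

batch, temb_dim = 2, 512
temb = torch.randn(batch, temb_dim)

# Encoder side: two resnets followed by a strided-conv downsample.
down = DownBlock2D(in_channels=64, out_channels=128, temb_channels=temb_dim, num_layers=2)
h, skip_states = down(torch.randn(batch, 64, 32, 32), temb)
print(h.shape)  # torch.Size([2, 128, 16, 16])

# Decoder side: skip connections are consumed from the end of the tuple,
# one per resnet, before the final upsample.
up = UpBlock2D(in_channels=64, prev_output_channel=128, out_channels=64, temb_channels=temb_dim, num_layers=2)
skips = (torch.randn(batch, 64, 16, 16), torch.randn(batch, 64, 16, 16))
print(up(h, skips, temb).shape)  # torch.Size([2, 64, 32, 32])
```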