├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── difformer_pytorch
│   ├── __init__.py
│   └── difformer.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.mypy_cache
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace

  # Formats code with black
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
        args: [
          '--experimental-string-processing'
        ]

  # Sorts imports
  - repo: https://github.com/pycqa/isort
    rev: 5.10.1
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile", "black"]

  # Checks for unused imports, line lengths, etc.
  - repo: https://gitlab.com/pycqa/flake8
    rev: 4.0.0
    hooks:
      - id: flake8
        args: [
          '--per-file-ignores=__init__.py:F401',
          '--max-line-length=88',
          '--ignore=E1,W1,E2,W2,E4,W4,E5,W5' # Handled by black
        ]

  # Checks types
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v0.971'
    hooks:
      - id: mypy
        additional_dependencies: [data-science-types>=0.2, torch>=1.6]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 archinet.ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Difformer - PyTorch (Experimental)

A diffusion-based transformer, in PyTorch.

```bash
pip install difformer-pytorch
```
[![PyPI - Python Version](https://img.shields.io/pypi/v/difformer-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/difformer-pytorch/)


## Usage

### Token based
```python
import torch
from difformer_pytorch import Difformer

num_tokens = 1000

difformer = Difformer(
    num_tokens=num_tokens,
    embedding_dim=512,
    num_layers=6
)

# Input tokens and mask
tokens = torch.randint(0, num_tokens, (1, 1024))
mask = torch.ones_like(tokens).bool()

# Train difformer to demask
loss = difformer(tokens=tokens, mask=mask)
loss.backward()

# Sample an unmasked prediction given a masked start sequence
sampled = difformer.sample(
    tokens=tokens,
    mask=mask,
    num_steps=5
) # [1, 1024]
```

### Embedding based
```python
import torch
from difformer_pytorch import Difformer

difformer = Difformer(
    embedding_dim=512,
    num_layers=6
)

# Input embedding and mask
embedding = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

# Train difformer
loss = difformer(embedding=embedding, mask=mask)
loss.backward()

# Sample a prediction given a masked start embedding
sampled = difformer.sample(
    embedding=embedding,
    mask=mask, # Optional mask to apply on embeddings
    num_steps=5
) # [1, 1024, 512]
```
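
### Custom schedule and sampler

For finer control, `sample` also accepts an explicit noise schedule and sampler, both exported by the package. Below is a minimal sketch with the defaults spelled out; the values shown are the same ones `Difformer.sample` uses when none are given:

```python
import torch
from difformer_pytorch import AEulerSampler, Difformer, RhoSchedule

difformer = Difformer(embedding_dim=512, num_layers=6)

embedding = torch.randn(1, 1024, 512)

# Equivalent to difformer.sample(embedding=embedding, num_steps=5)
sampled = difformer.sample(
    embedding=embedding,
    sigma_schedule=RhoSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    sampler=AEulerSampler(),
    num_steps=5
) # [1, 1024, 512]
```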
--------------------------------------------------------------------------------
/difformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
from .difformer import AEulerSampler, Difformer, LogNormalDistribution, RhoSchedule
--------------------------------------------------------------------------------
/difformer_pytorch/difformer.py:
--------------------------------------------------------------------------------
from inspect import isfunction
from math import pi, sqrt
from typing import Any, Callable, Optional, Tuple, TypeVar, Union

import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from torch import Tensor, einsum, nn
from typing_extensions import TypeGuard

T = TypeVar("T")

"""
Utils
"""


def exists(val: Optional[T]) -> TypeGuard[T]:
    return val is not None


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    if exists(val):
        return val
    return d() if isfunction(d) else d


"""
Diffusion
"""


class Distribution:
    def __call__(self, num_samples: int, device: torch.device):
        raise NotImplementedError()


class LogNormalDistribution(Distribution):
    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        normal = self.mean + self.std * torch.randn((num_samples,), device=device)
        return normal.exp()


class Schedule(nn.Module):
    """Interface used by different schedules"""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        raise NotImplementedError()


class RhoSchedule(Schedule):
    """https://arxiv.org/abs/2206.00364 equation 5"""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        rho_inv = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        sigmas = (
            self.sigma_max**rho_inv
            + (steps / (num_steps - 1))
            * (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
        ) ** self.rho
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas


class Sampler(nn.Module):
    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        raise NotImplementedError()


class AEulerSampler(Sampler):
    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
        sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
        sigma_down = sqrt(sigma_next**2 - sigma_up**2)
        return sigma_up, sigma_down

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        # Sigma steps
        sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (∂x/∂sigma)
        d = (x - fn(x, sigma=sigma)) / sigma
        # Euler method
        x_next = x + d * (sigma_down - sigma)
        # Add randomness
        x_next = x_next + torch.randn_like(x) * sigma_up
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        x = sigmas[0] * noise
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
        return x


class Diffusion(nn.Module):
    """Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,
    ):
        super().__init__()

        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        sigma_data = self.sigma_data
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data**2) / (sigmas_padded**2 + sigma_data**2)
        c_out = (
            sigmas_padded * sigma_data * (sigma_data**2 + sigmas_padded**2) ** -0.5
        )
        c_in = (sigmas_padded**2 + sigma_data**2) ** -0.5
        c_noise = torch.log(sigmas) * 0.25
        return c_skip, c_out, c_in, c_noise
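
    # Note: these are the EDM preconditioning weights from Karras et al.
    # (https://arxiv.org/abs/2206.00364, Table 1), with sigma broadcast to
    # shape (b, 1, 1) so they scale sequence-shaped tensors:
    #   c_skip(σ)  = σ_d² / (σ² + σ_d²)         gates the skip connection
    #   c_out(σ)   = σ·σ_d / sqrt(σ² + σ_d²)    scales the network output
    #   c_in(σ)    = 1 / sqrt(σ² + σ_d²)        normalizes the noisy input
    #   c_noise(σ) = ln(σ) / 4                  conditions the net on noise level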

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        batch, device = x_noisy.shape[0], x_noisy.device

        assert exists(sigmas) ^ exists(sigma), "Either sigmas or sigma must be provided"

        # If a single sigma is provided, use it for all batch items (used for sampling)
        if exists(sigma):
            sigmas = torch.full(size=(batch,), fill_value=sigma).to(device)

        assert exists(sigmas)

        # Predict network output and add skip connection
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred

        # Clamp denoised output to [-1, 1] (static thresholding)
        return x_denoised.clamp(-1.0, 1.0)

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        # Computes weight depending on data distribution
        return (sigmas**2 + self.sigma_data**2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        batch, device = x.shape[0], x.device

        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Compute denoised values
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)

        # Compute weighted loss
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")
        losses = losses * self.loss_weight(sigmas)
        loss = losses.mean()

        return loss

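# Note: loss_weight above is the EDM weighting λ(σ) = (σ² + σ_d²) / (σ·σ_d)²,
# i.e. 1 / c_out(σ)²; weighting the MSE on the denoised output by 1/c_out²
# gives every sampled noise level an effectively equal contribution to the loss.
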
class DiffusionSampler(nn.Module):
    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps

    @torch.no_grad()
    def forward(
        self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
    ) -> Tensor:
        device = noise.device
        num_steps = default(num_steps, self.num_steps)  # type: ignore
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        # Compute sigmas using schedule
        sigmas = self.sigma_schedule(num_steps, device)
        # Append additional kwargs to denoise_fn (used e.g. for conditional model)
        fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs})  # noqa
        # Sample using sampler
        x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        x = x.clamp(-1.0, 1.0)
        return x


"""
Transformer
"""


def attention_mask(
    sim: Tensor,
    mask: Tensor,
) -> Tensor:
    mask = rearrange(mask, "b j -> b 1 1 j")
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim


class LayerNorm(nn.Module):
    def __init__(self, features: int, *, bias: bool = True, eps: float = 1e-5):
        super().__init__()
        self.bias = bias
        self.eps = eps
        self.g = nn.Parameter(torch.ones(features))
        self.b = nn.Parameter(torch.zeros(features)) if bias else None

    def forward(self, x: Tensor) -> Tensor:
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=-1, keepdim=True)
        norm = (x - mean) * (var + self.eps).rsqrt() * self.g
        return norm + self.b if self.bias else norm


class AttentionBase(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        head_features: int = 64,
        num_heads: int = 8,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features**-0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = out_features if exists(out_features) else features

        self.to_out = nn.Sequential(
            nn.Linear(in_features=mid_features, out_features=out_features, bias=False),
            LayerNorm(features=out_features, bias=False),
        )

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        *,
        mask: Optional[Tensor] = None,
        rel_pos: Optional[nn.Module] = None,
    ) -> Tensor:

        # Split heads, scale queries
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        # Compute similarity matrix with bias and mask
        sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
        sim = rel_pos(sim) if exists(rel_pos) else sim
        sim = attention_mask(sim, mask) if exists(mask) else sim

        # Get attention matrix with softmax
        attn = sim.softmax(dim=-1, dtype=torch.float32)

        # Compute values
        out = einsum("... n j, ... j d -> ... n d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)

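# Shape flow through AttentionBase (h = num_heads, d = head_features):
#   q, k, v: (b, n, h·d) -> split into (b, h, n, d)
#   sim:     (b, h, n, m) similarity logits, masked positions set to -inf
#   attn:    softmax over the last (key) axis
#   out:     (b, h, n, d) -> merged back to (b, n, h·d) -> to_out projection
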
class Attention(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        head_features: int = 64,
        num_heads: int = 8,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        mid_features = head_features * num_heads

        self.norm = LayerNorm(features, bias=False)
        self.to_qkv = nn.Linear(
            in_features=features, out_features=mid_features * 3, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            out_features=out_features,
        )

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        x = self.norm(x)
        q, k, v = torch.chunk(self.to_qkv(x), chunks=3, dim=-1)
        x = self.attention(q, k, v, **kwargs)
        return x


def FeedForward(features: int, multiplier: int = 2) -> nn.Module:
    mid_features = int(features * multiplier)
    return nn.Sequential(
        LayerNorm(features, bias=False),
        nn.Linear(in_features=features, out_features=mid_features, bias=False),
        nn.GELU(),
        LayerNorm(mid_features, bias=False),
        nn.Linear(in_features=mid_features, out_features=features, bias=False),
    )


class TransformerBlock(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        head_features: int = 64,
        num_heads: int = 8,
        multiplier: int = 2,
    ):
        super().__init__()

        self.attention = Attention(
            features=features, head_features=head_features, num_heads=num_heads
        )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        x = self.attention(x, **kwargs) + x
        x = self.feed_forward(x) + x
        return x


class DynamicPositionBias(nn.Module):
    """From https://github.com/lucidrains/x-transformers/"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        depth: int = 2,
        log_distance: bool = False,
        norm: bool = False,
    ):
        super().__init__()
        assert depth >= 1, "depth for dynamic position bias MLP must be >= 1"
        self.log_distance = log_distance

        self.mlp = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(1, dim),
                    nn.LayerNorm(dim) if norm else nn.Identity(),
                    nn.ReLU(),
                )
            ]
        )

        for _ in range(depth - 1):
            self.mlp.append(
                nn.Sequential(
                    nn.Linear(dim, dim),
                    nn.LayerNorm(dim) if norm else nn.Identity(),
                    nn.ReLU(),
                )
            )

        self.mlp.append(nn.Linear(dim, num_heads))

    def forward(self, qk_dots: Tensor) -> Tensor:
        n, device, dtype = qk_dots.shape[-1], qk_dots.device, qk_dots.dtype

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device=device)
        ctx_arange = torch.arange(n, device=device)
        indices = rearrange(seq_arange, "i -> i 1") - rearrange(ctx_arange, "j -> 1 j")
        indices += n - 1

        # input to continuous positions MLP
        pos = torch.arange(-n + 1, n, device=device, dtype=dtype)
        pos = rearrange(pos, "... -> ... 1")

        if self.log_distance:
            # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases
        bias = pos[indices]
        bias = rearrange(bias, "i j h -> h i j")
        return qk_dots + bias

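# Worked example for DynamicPositionBias with n = 3: the relative offsets
# i - j form [[0, -1, -2], [1, 0, -1], [2, 1, 0]]; adding n - 1 = 2 shifts
# them to [[2, 1, 0], [3, 2, 1], [4, 3, 2]], which index into the MLP output
# for positions [-2, -1, 0, 1, 2]. Each head thus gets one learned scalar
# bias per relative distance, added directly to the attention logits.
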
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time"""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered


def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(in_features=dim + 1, out_features=out_features),
    )


class ContinuousTransformer(nn.Module):
    def __init__(
        self,
        *,
        features: int,
        context_features: int,
        num_blocks: int,
        head_features: int = 64,
        num_heads: int = 8,
        multiplier: int = 4,
    ):
        super().__init__()
        self.features = features
        time_features = features * 2

        self.to_time = nn.Sequential(
            TimePositionalEmbedding(dim=features, out_features=time_features),
            nn.SiLU(),
            nn.Linear(in_features=time_features, out_features=time_features),
        )

        self.rel_pos = DynamicPositionBias(dim=features // 4, num_heads=num_heads)

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=features + context_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                )
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x: Tensor, t: Tensor, context: Tensor) -> Tensor:
        n = x.shape[1]
        # Concat context
        x = torch.cat([x, context], dim=2)
        # Concat time token
        t = rearrange(self.to_time(t), "b d -> b 1 d")
        x = torch.cat([x, t], dim=1)
        # Feed into transformer
        for block in self.blocks:
            x = block(x, rel_pos=self.rel_pos)
        # Remove extra time token and context features
        x = x[:, 0:n, 0 : self.features]
        return x


class TokenEmbedding(nn.Module):
    def __init__(self, num_tokens: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=num_tokens, embedding_dim=embedding_dim
        )

    def get_ids(self, x: Tensor) -> Tensor:
        # Round each vector in x to the id of its nearest embedding (L2 distance)
        b = x.shape[0]
        e = repeat(self.embedding.weight, "n d -> b n d", b=b)
        sim = -torch.cdist(x, e, p=2)
        indices = sim.argmax(dim=-1)
        return indices

    def forward(self, x: Tensor) -> Tensor:
        return torch.tanh(self.embedding(x))

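# Note on the token round trip: tokens are embedded through tanh, so training
# targets live in (-1, 1), matching the clamp in Diffusion.denoise_fn; after
# sampling, get_ids maps each denoised vector back to its nearest token id.
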

"""
Difformer
"""


class DifformerBase(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_layers: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        diffusion_sigma_distribution: Distribution,
        diffusion_sigma_data: float,
        num_tokens: Optional[int] = None,
    ):
        super().__init__()
        assert (
            embedding_dim % num_heads == 0
        ), "embedding_dim must be divisible by num_heads"
        self.has_embedding = exists(num_tokens)

        if self.has_embedding:
            assert exists(num_tokens)
            self.token_embedding = TokenEmbedding(
                num_tokens=num_tokens, embedding_dim=embedding_dim
            )

        self.transformer = ContinuousTransformer(
            features=embedding_dim,
            context_features=embedding_dim,
            num_blocks=num_layers,
            head_features=head_features,
            num_heads=num_heads,
            multiplier=multiplier,
        )

        self.diffusion = Diffusion(
            net=self.transformer,
            sigma_distribution=diffusion_sigma_distribution,
            sigma_data=diffusion_sigma_data,
        )

    def forward(
        self,
        mask: Tensor,
        tokens: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
    ) -> Tensor:
        assert_message = "Exactly one of tokens or embedding must be provided"
        assert exists(tokens) ^ exists(embedding), assert_message

        embedding_masked = None

        if exists(tokens):
            self.assert_exists_embedding()
            embedding = self.token_embedding(tokens)
            embedding_masked = self.token_embedding(tokens.masked_fill(~mask, 0))
        else:
            assert exists(embedding)
            mask = rearrange(mask, "b n -> b n 1")
            embedding_masked = embedding.masked_fill(~mask, 0)

        return self.diffusion(embedding, context=embedding_masked)

    def sample(
        self,
        num_steps: int,
        sigma_schedule: Schedule,
        sampler: Sampler,
        tokens: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        assert_message = "Exactly one of tokens or start embedding must be provided"
        assert exists(tokens) ^ exists(embedding), assert_message

        embedding_masked = embedding

        if exists(tokens):
            self.assert_exists_embedding()
            if exists(mask):
                embedding_masked = self.token_embedding(tokens.masked_fill(~mask, 0))
            else:
                embedding_masked = self.token_embedding(tokens)
        elif exists(mask):
            assert exists(embedding)
            mask = rearrange(mask, "b n -> b n 1")
            embedding_masked = embedding.masked_fill(~mask, 0)  # type: ignore

        assert exists(embedding_masked)

        noise = torch.randn_like(embedding_masked)
        # Sample unmasked embedding
        diffusion_sampler = DiffusionSampler(
            diffusion=self.diffusion,
            num_steps=num_steps,
            sampler=sampler,
            sigma_schedule=sigma_schedule,
        )
        embedding_sample = diffusion_sampler(noise, context=embedding_masked, **kwargs)

        # Convert back into tokens, if input is token based
        if exists(tokens):
            self.assert_exists_embedding()
            return self.token_embedding.get_ids(embedding_sample)

        return embedding_sample

    def assert_exists_embedding(self):
        assert_message = "num_tokens required in constructor if token based input"
        assert self.has_embedding, assert_message


class Difformer(DifformerBase):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            num_heads=8,
            head_features=64,
            multiplier=4,
            diffusion_sigma_distribution=LogNormalDistribution(-3.0, 1.0),
            diffusion_sigma_data=0.1,
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})

    def sample(self, *args, **kwargs):
        default_kwargs = dict(
            sigma_schedule=RhoSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            sampler=AEulerSampler(),
        )
        return super().sample(*args, **{**default_kwargs, **kwargs})
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

setup(
    name="difformer-pytorch",
    packages=find_packages(exclude=[]),
    version="0.0.6",
    license="MIT",
    description="Difformer - PyTorch",
    long_description_content_type="text/markdown",
    author="Flavio Schneider",
    author_email="archinetai@protonmail.com",
    url="https://github.com/archinetai/difformer-pytorch",
    keywords=["artificial intelligence", "deep learning", "transformer", "diffusion"],
    install_requires=[
        "torch>=1.6",
        "data-science-types>=0.2",
        "einops>=0.4",
        "einops-exts",  # rearrange_many, imported in difformer.py
        "typing-extensions",  # TypeGuard, imported in difformer.py
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)
--------------------------------------------------------------------------------