├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── chroma.png
├── chroma_pytorch
│   ├── __init__.py
│   ├── chroma_pytorch.py
│   └── semantic_conditioner.py
├── rfdiffusion.gif
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 | 
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 | 
11 | name: Upload Python Package
12 | 
13 | on:
14 |   release:
15 |     types: [published]
16 | 
17 | jobs:
18 |   deploy:
19 | 
20 |     runs-on: ubuntu-latest
21 | 
22 |     steps:
23 |     - uses: actions/checkout@v2
24 |     - name: Set up Python
25 |       uses: actions/setup-python@v2
26 |       with:
27 |         python-version: '3.x'
28 |     - name: Install dependencies
29 |       run: |
30 |         python -m pip install --upgrade pip
31 |         pip install build
32 |     - name: Build package
33 |       run: python -m build
34 |     - name: Publish package
35 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 |       with:
37 |         user: __token__
38 |         password: ${{ secrets.PYPI_API_TOKEN }}
39 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | figure 1 in paper
4 | 
5 | 
6 | 
7 | *generating a protein that binds to the spike protein of the coronavirus - the Baker lab's concurrent RFDiffusion work*
8 | 
9 | ## Chroma - Pytorch (wip)
10 | 
11 | Implementation of Chroma, a generative model of proteins combining denoising diffusion probabilistic models (DDPMs) with graph neural networks (GNNs), in Pytorch. Concurrent work suggests that denoising diffusion is starting to take off when applied to protein design. This repository will also incorporate self-conditioning, applied successfully by the Baker lab in RFDiffusion.
12 | 
13 | Explanation by Stephan Heijl
14 | 
15 | If you are interested in open-sourcing works like these in the wild, please consider joining OpenBioML
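16 | 
17 | ## Usage
18 | 
19 | A minimal sketch of the current training and sampling interface, assuming the placeholder `Unet` denoiser that ships in `chroma_pytorch/chroma_pytorch.py` (for now the inputs are image-shaped tensors carried over from the bit-diffusion scaffold; the protein-specific graph backbone has yet to land):
20 | 
21 | ```python
22 | import torch
23 | from chroma_pytorch import Chroma
24 | from chroma_pytorch.chroma_pytorch import Unet
25 | 
26 | # denoiser network
27 | model = Unet(
28 |     dim = 64,
29 |     dim_mults = (1, 2, 4, 8)
30 | )
31 | 
32 | # continuous-time diffusion wrapper, cosine log(snr) schedule by default
33 | chroma = Chroma(
34 |     model,
35 |     image_size = 128,
36 |     timesteps = 1000
37 | )
38 | 
39 | # one training step on a stand-in batch
40 | data = torch.randn(2, 3, 128, 128)
41 | loss = chroma(data)
42 | loss.backward()
43 | 
44 | # after much training, sample with ddpm (or ddim if use_ddim = True)
45 | sampled = chroma.sample(batch_size = 4) # (4, 3, 128, 128)
46 | ```
47 | 
48 | Text embeddings for eventual semantic conditioning can be pulled from the PubMedBERT wrapper in `chroma_pytorch/semantic_conditioner.py` (requires the `transformers` and `tf-bind-transformer` packages, is not yet wired into the diffusion model, and needs the `CONTEXT_EMBED_USE_CPU` env variable set to run without a gpu):
49 | 
50 | ```python
51 | import torch
52 | from chroma_pytorch.semantic_conditioner import get_text_repr
53 | 
54 | text_embeds = get_text_repr(
55 |     ['a protein that binds the spike protein of the coronavirus'],
56 |     device = torch.device('cpu')
57 | ) # (1, 768) - cls token embedding from PubMedBERT
58 | ```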
59 | 
60 | ## Todo
61 | 
62 | - [ ] use galactica
63 | 
64 | ## Citations
65 | 
66 | ```bibtex
67 | @misc{ingraham2022illuminating,
68 |     title   = {Illuminating protein space with a programmable generative model},
69 |     author  = {John Ingraham and Max Baranov and Zak Costello and Vincent Frappier and Ahmed Ismail and Shan Tie and Wujie Wang and Vincent Xue and Fritz Obermeyer and Andrew Beam and Gevorg Grigoryan},
70 |     year    = {2022},
71 |     url     = {https://cdn.generatebiomedicines.com/assets/ingraham2022.pdf}
72 | }
73 | ```
--------------------------------------------------------------------------------
/chroma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/chroma-pytorch/4bf51ba8d43b6297985742930bc5e4b174078aea/chroma.png
--------------------------------------------------------------------------------
/chroma_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from chroma_pytorch.chroma_pytorch import Chroma
2 | 
--------------------------------------------------------------------------------
/chroma_pytorch/chroma_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | 
4 | from einops import rearrange, repeat
5 | 
6 | import math
7 | from pathlib import Path
8 | from random import random
9 | from functools import partial
10 | from multiprocessing import cpu_count
11 | 
12 | import torch
13 | from torch import nn, einsum
14 | from torch.special import expm1
15 | import torch.nn.functional as F
16 | from torch.utils.data import Dataset, DataLoader
17 | 
18 | from torch.optim import Adam
19 | from torchvision import transforms as T, utils
20 | 
21 | from einops import rearrange, reduce, repeat
22 | from einops.layers.torch import Rearrange
23 | 
24 | from tqdm.auto import tqdm
25 | from ema_pytorch import EMA
26 | 
27 | from accelerate import Accelerator
28 | 
29 | # helper functions
30 | 
31 | def exists(x):
32 |     return x is not None
33 | 
34 | def default(val, d):
35 |     if exists(val):
36 |         return val
37 |     return d() if callable(d) else d
38 | 
39 | def cycle(dl):
40 |     while True:
41 |         for data in dl:
42 |             yield data
43 | 
44 | def has_int_squareroot(num):
45 |     return (math.sqrt(num) ** 2) == num
46 | 
47 | def num_to_groups(num, divisor):
48 |     groups = num // divisor
49 |     remainder = num % divisor
50 |     arr = [divisor] * groups
51 |     if remainder > 0:
52 |         arr.append(remainder)
53 |     return arr
54 | 
55 | def convert_image_to(img_type, image):
56 |     if image.mode != img_type:
57 |         return image.convert(img_type)
58 |     return image
59 | 
60 | # small helper modules
61 | 
62 | class Residual(nn.Module):
63 |     def __init__(self, fn):
64 |         super().__init__()
65 |         self.fn = fn
66 | 
67 |     def forward(self, x, *args, **kwargs):
68 |         return self.fn(x, *args, **kwargs) + x
69 | 
70 | def Upsample(dim, dim_out = None):
71 |     return nn.Sequential(
72 |         nn.Upsample(scale_factor = 2, mode = 'nearest'),
73 |         nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
74 |     )
75 | 
76 | def Downsample(dim, dim_out = None):
77 |     return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
78 | 
79 | class LayerNorm(nn.Module):
80 |     def __init__(self, dim):
81 |         super().__init__()
82 |         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
83 | 
84 |     def forward(self, x):
85 |         eps = 1e-5 if x.dtype == torch.float32 else 1e-3
86 |         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
87 |         mean = torch.mean(x, dim = 1, keepdim = True)
88 |         return (x - mean) * (var + eps).rsqrt() * self.g
89 | 
90 | class PreNorm(nn.Module):
91 |     def __init__(self, dim, fn):
92 |         super().__init__()
93 |         self.fn = fn
94 |         self.norm = LayerNorm(dim)
95 | 
96 |     def forward(self, x):
97 |         x = self.norm(x)
98 |         return self.fn(x)
99 | 
100 | # positional embeds
101 | 
102 | class LearnedSinusoidalPosEmb(nn.Module):
103 |     def __init__(self, dim):
104 |         super().__init__()
105 |         assert (dim % 2) == 0
106 |         half_dim = dim // 2
107 |         self.weights = nn.Parameter(torch.randn(half_dim))
108 | 
109 |     def forward(self, x):
110 |         x = rearrange(x, 'b -> b 1')
111 |         freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
112 |         fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
113 |         fouriered = torch.cat((x, fouriered), dim = -1)
114 |         return fouriered
115 | 
116 | # building block modules
117 | 
118 | class Block(nn.Module):
119 |     def __init__(self, dim, dim_out, groups = 8):
120 |         super().__init__()
121 |         self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
122 |         self.norm = nn.GroupNorm(groups, dim_out)
123 |         self.act = nn.SiLU()
124 | 
125 |     def forward(self, x, scale_shift = None):
126 |         x = self.proj(x)
127 |         x = self.norm(x)
128 | 
129 |         if exists(scale_shift):
130 |             scale, shift = scale_shift
131 |             x = x * (scale + 1) + shift
132 | 
133 |         x = self.act(x)
134 |         return x
135 | 
136 | class ResnetBlock(nn.Module):
137 |     def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
138 |         super().__init__()
139 |         self.mlp = nn.Sequential(
140 |             nn.SiLU(),
141 |             nn.Linear(time_emb_dim, dim_out * 2)
142 |         ) if exists(time_emb_dim) else None
143 | 
144 |         self.block1 = Block(dim, dim_out, groups = groups)
145 |         self.block2 = Block(dim_out, dim_out, groups = groups)
146 |         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
147 | 
148 |     def forward(self, x, time_emb = None):
149 | 
150 |         scale_shift = None
151 |         if exists(self.mlp) and exists(time_emb):
152 |             time_emb = self.mlp(time_emb)
153 |             time_emb = rearrange(time_emb, 'b c -> b c 1 1')
154 |             scale_shift = time_emb.chunk(2, dim = 1)
155 | 
156 |         h = self.block1(x, scale_shift = scale_shift)
157 | 
158 |         h = self.block2(h)
159 | 
160 |         return h + self.res_conv(x)
161 | 
162 | class LinearAttention(nn.Module):
163 |     def __init__(self, dim, heads = 4, dim_head = 32):
164 |         super().__init__()
165 |         self.scale = dim_head ** -0.5
166 |         self.heads = heads
167 |         hidden_dim = dim_head * heads
168 |         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
169 | 
170 |         self.to_out = nn.Sequential(
171 |             nn.Conv2d(hidden_dim, dim, 1),
172 |             LayerNorm(dim)
173 |         )
174 | 
175 |     def forward(self, x):
176 |         b, c, h, w = x.shape
177 |         qkv = self.to_qkv(x).chunk(3, dim = 1)
178 |         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
179 | 
180 |         q = q.softmax(dim = -2)
181 |         k = k.softmax(dim = -1)
182 | 
183 |         q = q * self.scale
184 |         v = v / (h * w)
185 | 
186 |         context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
187 | 
188 |         out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
189 |         out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
190 |         return self.to_out(out)
191 | 
192 | class Attention(nn.Module):
193 |     def __init__(self, dim, heads = 4, dim_head = 32):
194 |         super().__init__()
195 |         self.scale = dim_head ** -0.5
196 |         self.heads = heads
197 |         hidden_dim = dim_head * heads
198 |         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
199 |         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
200 | 
201 |     def forward(self, x):
202 |         b, c, h, w = x.shape
203 |         qkv = self.to_qkv(x).chunk(3, dim = 1)
204 |         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
205 | 
206 |         q = q * self.scale
207 | 
208 |         sim = einsum('b h d i, b h d j -> b h i j', q, k)
209 |         attn = sim.softmax(dim = -1)
210 |         out = einsum('b h i j, b h d j -> b h i d', attn, v)
211 |         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
212 |         return self.to_out(out)
213 | 
214 | # model
215 | 
216 | class Unet(nn.Module):
217 |     def __init__(
218 |         self,
219 |         dim,
220 |         init_dim = None,
221 |         dim_mults=(1, 2, 4, 8),
222 |         channels = 3,
223 |         resnet_block_groups = 8,
224 |         learned_sinusoidal_dim = 16
225 |     ):
226 |         super().__init__()
227 | 
228 |         # determine dimensions
229 | 
230 |         self.channels = channels
231 | 
232 |         input_channels = channels * 2
233 | 
234 |         init_dim = default(init_dim, dim)
235 |         self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
236 | 
237 |         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
238 |         in_out = list(zip(dims[:-1], dims[1:]))
239 | 
240 |         block_klass = partial(ResnetBlock, groups = resnet_block_groups)
241 | 
242 |         # time embeddings
243 | 
244 |         time_dim = dim * 4
245 | 
246 |         sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
247 |         fourier_dim = learned_sinusoidal_dim + 1
248 | 
249 |         self.time_mlp = nn.Sequential(
250 |             sinu_pos_emb,
251 |             nn.Linear(fourier_dim, time_dim),
252 |             nn.GELU(),
253 |             nn.Linear(time_dim, time_dim)
254 |         )
255 | 
256 |         # layers
257 | 
258 |         self.downs = nn.ModuleList([])
259 |         self.ups = nn.ModuleList([])
260 |         num_resolutions = len(in_out)
261 | 
262 |         for ind, (dim_in, dim_out) in enumerate(in_out):
263 |             is_last = ind >= (num_resolutions - 1)
264 | 
265 |             self.downs.append(nn.ModuleList([
266 |                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
267 |                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
268 |                 Residual(PreNorm(dim_in, LinearAttention(dim_in))),
269 |                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
270 |             ]))
271 | 
272 |         mid_dim = dims[-1]
273 |         self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
274 |         self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
275 |         self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
276 | 
277 |         for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
278 |             is_last = ind == (len(in_out) - 1)
279 | 
280 |             self.ups.append(nn.ModuleList([
281 |                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
282 |                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
283 |                 Residual(PreNorm(dim_out, LinearAttention(dim_out))),
284 |                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
285 |             ]))
286 | 
287 |         self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
288 |         self.final_conv = nn.Conv2d(dim, channels, 1)
289 | 
290 |     def forward(self, x, time, x_self_cond = None):
291 | 
292 |         x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
293 |         x = torch.cat((x_self_cond, x), dim = 1)
294 | 
295 |         x = self.init_conv(x)
296 |         r = x.clone()
297 | 
298 |         t = self.time_mlp(time)
299 | 
300 |         h = []
301 | 
302 |         for block1, block2, attn, downsample in self.downs:
303 |             x = block1(x, t)
304 |             h.append(x)
305 | 
306 |             x = block2(x, t)
307 |             x = attn(x)
308 |             h.append(x)
309 | 
310 |             x = downsample(x)
311 | 
312 |         x = self.mid_block1(x, t)
313 |         x = self.mid_attn(x)
314 |         x = self.mid_block2(x, t)
315 | 
316 |         for block1, block2, attn, upsample in self.ups:
317 |             x = torch.cat((x, h.pop()), dim = 1)
318 |             x = block1(x, t)
319 | 
320 |             x = torch.cat((x, h.pop()), dim = 1)
321 |             x = block2(x, t)
322 |             x = attn(x)
323 | 
324 |             x = upsample(x)
325 | 
326 |         x = torch.cat((x, r), dim = 1)
327 | 
328 |         x = self.final_res_block(x, t)
329 |         return self.final_conv(x)
330 | 
331 | # chroma class
332 | 
333 | def log(t, eps = 1e-20):
334 |     return torch.log(t.clamp(min = eps))
335 | 
336 | def right_pad_dims_to(x, t):
337 |     padding_dims = x.ndim - t.ndim
338 |     if padding_dims <= 0:
339 |         return t
340 |     return t.view(*t.shape, *((1,) * padding_dims))
341 | 
342 | def beta_linear_log_snr(t):
343 |     return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
344 | 
345 | def alpha_cosine_log_snr(t, s: float = 0.008):
346 |     return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
347 | 
348 | def log_snr_to_alpha_sigma(log_snr):
349 |     return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
350 | 
351 | class Chroma(nn.Module):
352 |     def __init__(
353 |         self,
354 |         model,
355 |         *,
356 |         image_size,
357 |         timesteps = 1000,
358 |         use_ddim = False,
359 |         noise_schedule = 'cosine',
360 |         time_difference = 0.
361 |     ):
362 |         super().__init__()
363 |         self.model = model
364 |         self.channels = self.model.channels
365 | 
366 |         self.image_size = image_size
367 | 
368 |         if noise_schedule == "linear":
369 |             self.log_snr = beta_linear_log_snr
370 |         elif noise_schedule == "cosine":
371 |             self.log_snr = alpha_cosine_log_snr
372 |         else:
373 |             raise ValueError(f'invalid noise schedule {noise_schedule}')
374 | 
375 |         self.timesteps = timesteps
376 |         self.use_ddim = use_ddim
377 | 
378 |         # proposed in the paper, summed to time_next
379 |         # as a way to fix a deficiency in self-conditioning and lower FID when the number of sampling timesteps is < 400
380 | 
381 |         self.time_difference = time_difference
382 | 
383 |     @property
384 |     def device(self):
385 |         return next(self.model.parameters()).device
386 | 
387 |     def get_sampling_timesteps(self, batch, *, device):
388 |         times = torch.linspace(1., 0., self.timesteps + 1, device = device)
389 |         times = repeat(times, 't -> b t', b = batch)
390 |         times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
391 |         times = times.unbind(dim = -1)
392 |         return times
393 | 
394 |     @torch.no_grad()
395 |     def ddpm_sample(self, shape, time_difference = None):
396 |         batch, device = shape[0], self.device
397 | 
398 |         time_difference = default(time_difference, self.time_difference)
399 | 
400 |         time_pairs = self.get_sampling_timesteps(batch, device = device)
401 | 
402 |         img = torch.randn(shape, device = device)
403 | 
404 |         x_start = None
405 | 
406 |         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):
407 | 
408 |             # add the time delay
409 | 
410 |             time_next = (time_next - time_difference).clamp(min = 0.)
411 | 
412 |             noise_cond = self.log_snr(time)
413 | 
414 |             # get predicted x0
415 | 
416 |             x_start = self.model(img, noise_cond, x_start)
417 | 
418 |             # clip x0
419 | 
420 |             x_start.clamp_(-1., 1.)
421 | 
422 |             # get log(snr)
423 | 
424 |             log_snr = self.log_snr(time)
425 |             log_snr_next = self.log_snr(time_next)
426 |             log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
427 | 
428 |             # get alpha sigma of time and next time
429 | 
430 |             alpha, sigma = log_snr_to_alpha_sigma(log_snr)
431 |             alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
432 | 
433 |             # derive posterior mean and variance
434 | 
435 |             c = -expm1(log_snr - log_snr_next)
436 | 
437 |             mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
438 |             variance = (sigma_next ** 2) * c
439 |             log_variance = log(variance)
440 | 
441 |             # get noise
442 | 
443 |             noise = torch.where(
444 |                 rearrange(time_next > 0, 'b -> b 1 1 1'),
445 |                 torch.randn_like(img),
446 |                 torch.zeros_like(img)
447 |             )
448 | 
449 |             img = mean + (0.5 * log_variance).exp() * noise
450 | 
451 |         return img
452 | 
453 |     @torch.no_grad()
454 |     def ddim_sample(self, shape, time_difference = None):
455 |         batch, device = shape[0], self.device
456 | 
457 |         time_difference = default(time_difference, self.time_difference)
458 | 
459 |         time_pairs = self.get_sampling_timesteps(batch, device = device)
460 | 
461 |         img = torch.randn(shape, device = device)
462 | 
463 |         x_start = None
464 | 
465 |         for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):
466 | 
467 |             # add the time delay, before deriving the noise levels, mirroring ddpm_sample above
468 | 
469 |             times_next = (times_next - time_difference).clamp(min = 0.)
470 | 
471 |             # get times and noise levels
472 | 
473 |             log_snr = self.log_snr(times)
474 |             log_snr_next = self.log_snr(times_next)
475 | 
476 |             padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
477 | 
478 |             alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
479 |             alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)
480 | 
481 |             # predict x0
482 | 
483 |             x_start = self.model(img, log_snr, x_start)
484 | 
485 |             # clip x0
486 | 
487 |             x_start.clamp_(-1., 1.)
488 | 
489 |             # get predicted noise
490 | 
491 |             pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)
492 | 
493 |             # calculate x next
494 | 
495 |             img = x_start * alpha_next + pred_noise * sigma_next
496 | 
497 |         return img
498 | 
499 |     @torch.no_grad()
500 |     def sample(self, batch_size = 16):
501 |         image_size, channels = self.image_size, self.channels
502 |         sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
503 |         return sample_fn((batch_size, channels, image_size, image_size))
504 | 
505 |     def forward(self, img, *args, **kwargs):
506 |         batch, c, h, w, device, img_size = *img.shape, img.device, self.image_size
507 |         assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
508 | 
509 |         # sample random times
510 | 
511 |         times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)
512 | 
513 |         # noise sample
514 | 
515 |         noise = torch.randn_like(img)
516 | 
517 |         noise_level = self.log_snr(times)
518 |         padded_noise_level = right_pad_dims_to(img, noise_level)
519 |         alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level)
520 | 
521 |         noised_img = alpha * img + sigma * noise
522 | 
523 |         # if doing self-conditioning, 50% of the time, predict x_start from current set of times
524 |         # and condition the unet with that
525 |         # this technique will slow down training by 25%, but seems to lower FID significantly
526 | 
527 |         self_cond = None
528 |         if random() < 0.5:
529 |             with torch.no_grad():
530 |                 self_cond = self.model(noised_img, noise_level).detach_()
531 | 
532 |         # predict and take gradient step
533 | 
534 |         pred = self.model(noised_img, noise_level, self_cond)
535 | 
536 |         return F.mse_loss(pred, img)
537 | 
538 | # trainer class
539 | 
540 | class Trainer(object):
541 |     def __init__(
542 |         self,
543 |         diffusion_model,
544 |         folder,
545 |         *,
546 |         train_batch_size = 16,
547 |         gradient_accumulate_every = 1,
548 |         augment_horizontal_flip = True,
549 |         train_lr = 1e-4,
550 |         train_num_steps = 100000,
551 |         ema_update_every = 10,
552 |         ema_decay = 0.995,
553 |         adam_betas = (0.9, 0.99),
554 |         save_and_sample_every = 1000,
555 |         num_samples = 25,
556 |         results_folder = './results',
557 |         amp = False,
558 |         fp16 = False,
559 |         split_batches = True,
560 |         convert_image_to = None
561 |     ):
562 |         super().__init__()
563 | 
564 |         self.accelerator = Accelerator(
565 |             split_batches = split_batches,
566 |             mixed_precision = 'fp16' if fp16 else 'no'
567 |         )
568 | 
569 |         self.accelerator.native_amp = amp
570 | 
571 |         self.model = diffusion_model
572 | 
573 |         assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
574 |         self.num_samples = num_samples
575 |         self.save_and_sample_every = save_and_sample_every
576 | 
577 |         self.batch_size = train_batch_size
578 |         self.gradient_accumulate_every = gradient_accumulate_every
579 | 
580 |         self.train_num_steps = train_num_steps
581 |         self.image_size = diffusion_model.image_size
582 | 
583 |         # dataset and dataloader
584 | 
585 |         self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to) # NOTE: torch.utils.data.Dataset is abstract and accepts none of these arguments - a concrete dataset class still needs to be written here
586 |         dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
587 | 
588 |         dl = self.accelerator.prepare(dl)
589 |         self.dl = cycle(dl)
590 | 
591 |         # optimizer
592 | 
593 |         self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
594 | 
595 |         # for logging results in a folder periodically
596 | 
597 |         if self.accelerator.is_main_process:
598 |             self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
599 | 
600 |             self.results_folder = Path(results_folder)
601 |             self.results_folder.mkdir(exist_ok = True)
602 | 
603 |         # step counter state
604 | 
605 |         self.step = 0
606 | 
607 |         # prepare model, dataloader, optimizer with accelerator
608 | 
609 |         self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
610 | 
611 |     def save(self, milestone):
612 |         if not self.accelerator.is_local_main_process:
613 |             return
614 | 
615 |         data = {
616 |             'step': self.step,
617 |             'model': self.accelerator.get_state_dict(self.model),
618 |             'opt': self.opt.state_dict(),
619 |             'ema': self.ema.state_dict(),
620 |             'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
621 |         }
622 | 
623 |         torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
624 | 
625 |     def load(self, milestone):
626 |         data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))
627 | 
628 |         model = self.accelerator.unwrap_model(self.model)
629 |         model.load_state_dict(data['model'])
630 | 
631 |         self.step = data['step']
632 |         self.opt.load_state_dict(data['opt'])
633 |         self.ema.load_state_dict(data['ema'])
634 | 
635 |         if exists(self.accelerator.scaler) and exists(data['scaler']):
636 |             self.accelerator.scaler.load_state_dict(data['scaler'])
637 | 
638 |     def train(self):
639 |         accelerator = self.accelerator
640 |         device = accelerator.device
641 | 
642 |         with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
643 | 
644 |             while self.step < self.train_num_steps:
645 | 
646 |                 total_loss = 0.
647 | 
648 |                 for _ in range(self.gradient_accumulate_every):
649 |                     data = next(self.dl).to(device)
650 | 
651 |                     with self.accelerator.autocast():
652 |                         loss = self.model(data)
653 |                         loss = loss / self.gradient_accumulate_every
654 |                         total_loss += loss.item()
655 | 
656 |                     self.accelerator.backward(loss)
657 | 
658 |                 pbar.set_description(f'loss: {total_loss:.4f}')
659 | 
660 |                 accelerator.wait_for_everyone()
661 | 
662 |                 self.opt.step()
663 |                 self.opt.zero_grad()
664 | 
665 |                 accelerator.wait_for_everyone()
666 | 
667 |                 if accelerator.is_main_process:
668 |                     self.ema.to(device)
669 |                     self.ema.update()
670 | 
671 |                     if self.step != 0 and self.step % self.save_and_sample_every == 0:
672 |                         self.ema.ema_model.eval()
673 | 
674 |                         with torch.no_grad():
675 |                             milestone = self.step // self.save_and_sample_every
676 |                             batches = num_to_groups(self.num_samples, self.batch_size)
677 |                             all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
678 | 
679 |                         all_images = torch.cat(all_images_list, dim = 0)
680 |                         utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
681 |                         self.save(milestone)
682 | 
683 |                 self.step += 1
684 |                 pbar.update(1)
685 | 
686 |         accelerator.print('training complete')
687 | 
--------------------------------------------------------------------------------
/chroma_pytorch/semantic_conditioner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | 
4 | from transformers import AutoTokenizer, AutoModelForMaskedLM, logging
5 | from tf_bind_transformer.cache_utils import cache_fn, run_once
6 | 
7 | logging.set_verbosity_error()
8 | 
9 | def exists(val):
10 |     return val is not None
11 | 
12 | def map_values(fn, dictionary):
13 |     return {k: fn(v) for k, v in dictionary.items()}
14 | 
15 | CONTEXT_EMBED_USE_CPU = os.getenv('CONTEXT_EMBED_USE_CPU', None) is not None
16 | 
17 | if CONTEXT_EMBED_USE_CPU:
18 |     print('calculating context embed only on cpu')
19 | 
20 | MODELS = dict(
21 |     pubmed = dict(
22 |         dim = 768,
23 |         path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
24 |     )
25 | )
26 | 
27 | GLOBAL_VARIABLES = dict(model = None, tokenizer = None)
28 | 
29 | def get_contextual_dim(model_name):
30 |     assert model_name in MODELS
31 |     return MODELS[model_name]['dim']
32 | 
33 | @run_once('init_transformer')
34 | def init_transformer(model_name):
35 |     path = MODELS[model_name]['path']
36 |     GLOBAL_VARIABLES['tokenizer'] = AutoTokenizer.from_pretrained(path)
37 | 
38 |     model = AutoModelForMaskedLM.from_pretrained(path)
39 | 
40 |     if not CONTEXT_EMBED_USE_CPU:
41 |         model = model.cuda()
42 | 
43 |     GLOBAL_VARIABLES['model'] = model
44 | 
45 | @torch.no_grad()
46 | def tokenize_text(
47 |     text,
48 |     max_length = 256,
49 |     model_name = 'pubmed',
50 |     hidden_state_index = -1,
51 |     return_cls_token = True
52 | ):
53 |     init_transformer(model_name)
54 | 
55 |     model = GLOBAL_VARIABLES['model']
56 |     tokenizer = GLOBAL_VARIABLES['tokenizer']
57 | 
58 |     encoding = tokenizer.batch_encode_plus(
59 |         [text],
60 |         add_special_tokens = True,
61 |         padding = True,
62 |         truncation = True,
63 |         max_length = max_length,
64 |         return_attention_mask = True,
65 |         return_tensors = 'pt'
66 |     )
67 | 
68 |     if not CONTEXT_EMBED_USE_CPU:
69 |         encoding = map_values(lambda t: t.cuda(), encoding)
70 | 
71 |     model.eval()
72 |     with torch.no_grad():
73 |         outputs = model(**encoding, output_hidden_states = True)
74 | 
75 |     hidden_state = outputs.hidden_states[hidden_state_index][0]
76 | 
77 |     if return_cls_token:
78 |         return hidden_state[0]
79 | 
80 |     return hidden_state.mean(dim = 0)
81 | 
82 | def get_text_repr(
83 |     texts,
84 |     *,
85 |     device,
86 |     max_length = 256,
87 |     model_name = 'pubmed',
88 |     hidden_state_index = -1,
89 |     return_cls_token = True,
90 | ):
91 |     assert model_name in MODELS, f'{model_name} not found in available text transformers to use'
92 | 
93 |     if isinstance(texts, str):
94 |         texts = [texts]
95 | 
96 |     get_context_repr_fn = cache_fn(tokenize_text, path = f'contexts/{model_name}')
97 | 
98 |     representations = [get_context_repr_fn(text, max_length = max_length, model_name = model_name, hidden_state_index = hidden_state_index, return_cls_token = return_cls_token) for text in texts]
99 | 
100 |     return torch.stack(representations).to(device)
101 | 
--------------------------------------------------------------------------------
/rfdiffusion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/chroma-pytorch/4bf51ba8d43b6297985742930bc5e4b174078aea/rfdiffusion.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |   name = 'chroma-pytorch',
5 |   packages = find_packages(exclude=[]),
6 |   version = '0.0.1',
7 |   license='MIT',
8 |   description = 'Chroma - Pytorch',
9 |   author = 'Phil Wang',
10 |   author_email = 'lucidrains@gmail.com',
11 |   long_description_content_type = 'text/markdown',
12 |   url = 'https://github.com/lucidrains/chroma-pytorch',
13 |   keywords = [
14 |     'artificial intelligence',
15 |     'deep learning',
16 |     'denoising diffusion',
17 |     'protein design'
18 |   ],
19 |   install_requires=[
20 |     'accelerate',
21 |     'einops>=0.6',
22 |     'ema-pytorch',
23 |     'invariant-point-attention',
24 |     'tf-bind-transformer',
25 |     'torch>=1.6',
26 |     'torchvision',
27 |     'tqdm',
28 |     'transformers'
29 |   ],
30 |   classifiers=[
31 |     'Development Status :: 4 - Beta',
32 |     'Intended Audience :: Developers',
33 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
34 |     'License :: OSI Approved :: MIT License',
35 |     'Programming Language :: Python :: 3.6',
36 |   ],
37 | )
38 | 
--------------------------------------------------------------------------------