├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── chroma.png
├── chroma_pytorch
│   ├── __init__.py
│   ├── chroma_pytorch.py
│   └── semantic_conditioner.py
├── rfdiffusion.gif
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![figure 1 in paper](./chroma.png)
2 |
3 | *figure 1 in paper*
4 |
5 | ![RFDiffusion protein generation](./rfdiffusion.gif)
6 |
7 | *generating a protein that binds to the spike protein of coronavirus - Baker lab's concurrent RFDiffusion work*
8 |
9 | ## Chroma - Pytorch (wip)
10 |
11 | Implementation of Chroma, a generative model of proteins using DDPM and GNNs, in Pytorch. Concurrent work seems to suggest we have a slight lift-off applying denoising diffusion probabilistic models to protein design. This repository will also incorporate self-conditioning, applied successfully by the Baker lab in RFDiffusion.
12 |
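While the repository is a work in progress, the `Chroma` wrapper in `chroma_pytorch/chroma_pytorch.py` already runs as a plain denoising diffusion model with self-conditioning. Below is a minimal usage sketch of the current API, which is subject to change; note it is still image-shaped, so the random tensor merely stands in for whatever protein representation is eventually wired up, and `Unet` is imported from the module directly since only `Chroma` is re-exported from the package root:

```python
import torch
from chroma_pytorch import Chroma
from chroma_pytorch.chroma_pytorch import Unet  # not yet exported from the package root

model = Unet(dim = 64)

chroma = Chroma(
    model,
    image_size = 128,
    timesteps = 1000
)

data = torch.randn(2, 3, 128, 128)  # stand-in data; protein representations are not wired up yet

loss = chroma(data)
loss.backward()

# after much training

sampled = chroma.sample(batch_size = 2)  # (2, 3, 128, 128)
```
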
13 | Explanation by Stephan Heijl
14 |
15 | If you are interested in open-sourcing works like this in the wild, please consider joining OpenBioML
16 |
17 | ## Todo
18 |
19 | - [ ] use galactica
20 |
21 | ## Citations
22 |
23 | ```bibtex
24 | @misc{ingraham2022illuminating,
25 |     title   = {Illuminating protein space with a programmable generative model},
26 |     author  = {John Ingraham and Max Baranov and Zak Costello and Vincent Frappier and Ahmed Ismail and Shan Tie and Wujie Wang and Vincent Xue and Fritz Obermeyer and Andrew Beam and Gevorg Grigoryan},
27 |     year    = {2022},
28 |     url     = {https://cdn.generatebiomedicines.com/assets/ingraham2022.pdf}
29 | }
30 | ```
31 |
--------------------------------------------------------------------------------
/chroma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/chroma-pytorch/4bf51ba8d43b6297985742930bc5e4b174078aea/chroma.png
--------------------------------------------------------------------------------
/chroma_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from chroma_pytorch.chroma_pytorch import Chroma
2 |
--------------------------------------------------------------------------------
/chroma_pytorch/chroma_pytorch.py:
--------------------------------------------------------------------------------
6 | import math
7 | from pathlib import Path
8 | from random import random
9 | from functools import partial
10 | from multiprocessing import cpu_count
11 |
12 | import torch
13 | from torch import nn, einsum
14 | from torch.special import expm1
15 | import torch.nn.functional as F
16 | from torch.utils.data import Dataset, DataLoader
17 |
18 | from torch.optim import Adam
19 | from torchvision import transforms as T, utils
20 |
21 | from einops import rearrange, reduce, repeat
22 | from einops.layers.torch import Rearrange
23 |
24 | from tqdm.auto import tqdm
25 | from ema_pytorch import EMA
26 |
27 | from accelerate import Accelerator
28 |
29 | # helpers functions
30 |
31 | def exists(x):
32 | return x is not None
33 |
34 | def default(val, d):
35 | if exists(val):
36 | return val
37 | return d() if callable(d) else d
38 |
39 | def cycle(dl):
40 | while True:
41 | for data in dl:
42 | yield data
43 |
44 | def has_int_squareroot(num):
45 |     return math.isqrt(num) ** 2 == num # exact integer arithmetic, avoids float precision issues
46 |
47 | def num_to_groups(num, divisor):
48 | groups = num // divisor
49 | remainder = num % divisor
50 | arr = [divisor] * groups
51 | if remainder > 0:
52 | arr.append(remainder)
53 | return arr
54 |
55 | def convert_image_to(img_type, image):
56 | if image.mode != img_type:
57 | return image.convert(img_type)
58 | return image
59 |
60 | # small helper modules
61 |
62 | class Residual(nn.Module):
63 | def __init__(self, fn):
64 | super().__init__()
65 | self.fn = fn
66 |
67 | def forward(self, x, *args, **kwargs):
68 | return self.fn(x, *args, **kwargs) + x
69 |
70 | def Upsample(dim, dim_out = None):
71 | return nn.Sequential(
72 | nn.Upsample(scale_factor = 2, mode = 'nearest'),
73 | nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
74 | )
75 |
76 | def Downsample(dim, dim_out = None):
77 | return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
78 |
79 | class LayerNorm(nn.Module):
80 | def __init__(self, dim):
81 | super().__init__()
82 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
83 |
84 | def forward(self, x):
85 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
86 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
87 | mean = torch.mean(x, dim = 1, keepdim = True)
88 | return (x - mean) * (var + eps).rsqrt() * self.g
89 |
90 | class PreNorm(nn.Module):
91 | def __init__(self, dim, fn):
92 | super().__init__()
93 | self.fn = fn
94 | self.norm = LayerNorm(dim)
95 |
96 | def forward(self, x):
97 | x = self.norm(x)
98 | return self.fn(x)
99 |
100 | # positional embeds
101 |
102 | class LearnedSinusoidalPosEmb(nn.Module):
103 | def __init__(self, dim):
104 | super().__init__()
105 | assert (dim % 2) == 0
106 | half_dim = dim // 2
107 | self.weights = nn.Parameter(torch.randn(half_dim))
108 |
109 | def forward(self, x):
110 | x = rearrange(x, 'b -> b 1')
111 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
112 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
113 | fouriered = torch.cat((x, fouriered), dim = -1)
114 | return fouriered
115 |
116 | # building block modules
117 |
118 | class Block(nn.Module):
119 | def __init__(self, dim, dim_out, groups = 8):
120 | super().__init__()
121 | self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
122 | self.norm = nn.GroupNorm(groups, dim_out)
123 | self.act = nn.SiLU()
124 |
125 | def forward(self, x, scale_shift = None):
126 | x = self.proj(x)
127 | x = self.norm(x)
128 |
129 | if exists(scale_shift):
130 | scale, shift = scale_shift
131 | x = x * (scale + 1) + shift
132 |
133 | x = self.act(x)
134 | return x
135 |
136 | class ResnetBlock(nn.Module):
137 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
138 | super().__init__()
139 | self.mlp = nn.Sequential(
140 | nn.SiLU(),
141 | nn.Linear(time_emb_dim, dim_out * 2)
142 | ) if exists(time_emb_dim) else None
143 |
144 | self.block1 = Block(dim, dim_out, groups = groups)
145 | self.block2 = Block(dim_out, dim_out, groups = groups)
146 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
147 |
148 | def forward(self, x, time_emb = None):
149 |
150 | scale_shift = None
151 | if exists(self.mlp) and exists(time_emb):
152 | time_emb = self.mlp(time_emb)
153 | time_emb = rearrange(time_emb, 'b c -> b c 1 1')
154 | scale_shift = time_emb.chunk(2, dim = 1)
155 |
156 | h = self.block1(x, scale_shift = scale_shift)
157 |
158 | h = self.block2(h)
159 |
160 | return h + self.res_conv(x)
161 |
162 | class LinearAttention(nn.Module):
163 | def __init__(self, dim, heads = 4, dim_head = 32):
164 | super().__init__()
165 | self.scale = dim_head ** -0.5
166 | self.heads = heads
167 | hidden_dim = dim_head * heads
168 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
169 |
170 | self.to_out = nn.Sequential(
171 | nn.Conv2d(hidden_dim, dim, 1),
172 | LayerNorm(dim)
173 | )
174 |
175 | def forward(self, x):
176 | b, c, h, w = x.shape
177 | qkv = self.to_qkv(x).chunk(3, dim = 1)
178 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
179 |
180 | q = q.softmax(dim = -2)
181 | k = k.softmax(dim = -1)
182 |
183 | q = q * self.scale
184 | v = v / (h * w)
185 |
186 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
187 |
188 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
189 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
190 | return self.to_out(out)
191 |
192 | class Attention(nn.Module):
193 | def __init__(self, dim, heads = 4, dim_head = 32):
194 | super().__init__()
195 | self.scale = dim_head ** -0.5
196 | self.heads = heads
197 | hidden_dim = dim_head * heads
198 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
199 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
200 |
201 | def forward(self, x):
202 | b, c, h, w = x.shape
203 | qkv = self.to_qkv(x).chunk(3, dim = 1)
204 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
205 |
206 | q = q * self.scale
207 |
208 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
209 | attn = sim.softmax(dim = -1)
210 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
211 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
212 | return self.to_out(out)
213 |
214 | # model
215 |
216 | class Unet(nn.Module):
217 | def __init__(
218 | self,
219 | dim,
220 | init_dim = None,
221 | dim_mults=(1, 2, 4, 8),
222 | channels = 3,
223 | resnet_block_groups = 8,
224 | learned_sinusoidal_dim = 16
225 | ):
226 | super().__init__()
227 |
228 | # determine dimensions
229 |
230 | self.channels = channels
231 |
232 | input_channels = channels * 2
233 |
234 | init_dim = default(init_dim, dim)
235 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
236 |
237 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
238 | in_out = list(zip(dims[:-1], dims[1:]))
239 |
240 | block_klass = partial(ResnetBlock, groups = resnet_block_groups)
241 |
242 | # time embeddings
243 |
244 | time_dim = dim * 4
245 |
246 | sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
247 | fourier_dim = learned_sinusoidal_dim + 1
248 |
249 | self.time_mlp = nn.Sequential(
250 | sinu_pos_emb,
251 | nn.Linear(fourier_dim, time_dim),
252 | nn.GELU(),
253 | nn.Linear(time_dim, time_dim)
254 | )
255 |
256 | # layers
257 |
258 | self.downs = nn.ModuleList([])
259 | self.ups = nn.ModuleList([])
260 | num_resolutions = len(in_out)
261 |
262 | for ind, (dim_in, dim_out) in enumerate(in_out):
263 | is_last = ind >= (num_resolutions - 1)
264 |
265 | self.downs.append(nn.ModuleList([
266 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
267 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
268 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
269 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
270 | ]))
271 |
272 | mid_dim = dims[-1]
273 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
274 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
275 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
276 |
277 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
278 | is_last = ind == (len(in_out) - 1)
279 |
280 | self.ups.append(nn.ModuleList([
281 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
282 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
283 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
284 | Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
285 | ]))
286 |
287 |         self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim = time_dim) # the residual r carries init_dim channels
288 | self.final_conv = nn.Conv2d(dim, channels, 1)
289 |
290 | def forward(self, x, time, x_self_cond = None):
291 |
292 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
293 | x = torch.cat((x_self_cond, x), dim = 1)
294 |
295 | x = self.init_conv(x)
296 | r = x.clone()
297 |
298 | t = self.time_mlp(time)
299 |
300 | h = []
301 |
302 | for block1, block2, attn, downsample in self.downs:
303 | x = block1(x, t)
304 | h.append(x)
305 |
306 | x = block2(x, t)
307 | x = attn(x)
308 | h.append(x)
309 |
310 | x = downsample(x)
311 |
312 | x = self.mid_block1(x, t)
313 | x = self.mid_attn(x)
314 | x = self.mid_block2(x, t)
315 |
316 | for block1, block2, attn, upsample in self.ups:
317 | x = torch.cat((x, h.pop()), dim = 1)
318 | x = block1(x, t)
319 |
320 | x = torch.cat((x, h.pop()), dim = 1)
321 | x = block2(x, t)
322 | x = attn(x)
323 |
324 | x = upsample(x)
325 |
326 | x = torch.cat((x, r), dim = 1)
327 |
328 | x = self.final_res_block(x, t)
329 | return self.final_conv(x)
330 |
331 | # chroma class
332 |
333 | def log(t, eps = 1e-20):
334 | return torch.log(t.clamp(min = eps))
335 |
336 | def right_pad_dims_to(x, t):
337 | padding_dims = x.ndim - t.ndim
338 | if padding_dims <= 0:
339 | return t
340 | return t.view(*t.shape, *((1,) * padding_dims))
341 |
342 | def beta_linear_log_snr(t):
343 | return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
344 |
345 | def alpha_cosine_log_snr(t, s: float = 0.008):
346 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
347 |
348 | def log_snr_to_alpha_sigma(log_snr):
349 | return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
350 |
351 | class Chroma(nn.Module):
352 | def __init__(
353 | self,
354 | model,
355 | *,
356 | image_size,
357 | timesteps = 1000,
358 | use_ddim = False,
359 | noise_schedule = 'cosine',
360 | time_difference = 0.
361 | ):
362 | super().__init__()
363 | self.model = model
364 | self.channels = self.model.channels
365 |
366 | self.image_size = image_size
367 |
368 | if noise_schedule == "linear":
369 | self.log_snr = beta_linear_log_snr
370 | elif noise_schedule == "cosine":
371 | self.log_snr = alpha_cosine_log_snr
372 | else:
373 | raise ValueError(f'invalid noise schedule {noise_schedule}')
374 |
375 | self.timesteps = timesteps
376 | self.use_ddim = use_ddim
377 |
378 | # proposed in the paper, summed to time_next
379 | # as a way to fix a deficiency in self-conditioning and lower FID when the number of sampling timesteps is < 400
380 |
381 | self.time_difference = time_difference
382 |
383 | @property
384 | def device(self):
385 | return next(self.model.parameters()).device
386 |
387 | def get_sampling_timesteps(self, batch, *, device):
388 | times = torch.linspace(1., 0., self.timesteps + 1, device = device)
389 | times = repeat(times, 't -> b t', b = batch)
390 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
391 | times = times.unbind(dim = -1)
392 | return times
393 |
394 | @torch.no_grad()
395 | def ddpm_sample(self, shape, time_difference = None):
396 | batch, device = shape[0], self.device
397 |
398 | time_difference = default(time_difference, self.time_difference)
399 |
400 | time_pairs = self.get_sampling_timesteps(batch, device = device)
401 |
402 | img = torch.randn(shape, device=device)
403 |
404 | x_start = None
405 |
406 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):
407 |
408 | # add the time delay
409 |
410 |             time_next = (time_next - time_difference).clamp(min = 0.)
411 |
412 | noise_cond = self.log_snr(time)
413 |
414 | # get predicted x0
415 |
416 | x_start = self.model(img, noise_cond, x_start)
417 |
418 | # clip x0
419 |
420 | x_start.clamp_(-1., 1.)
421 |
422 | # get log(snr)
423 |
424 | log_snr = self.log_snr(time)
425 | log_snr_next = self.log_snr(time_next)
426 | log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
427 |
428 | # get alpha sigma of time and next time
429 |
430 | alpha, sigma = log_snr_to_alpha_sigma(log_snr)
431 | alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
432 |
433 | # derive posterior mean and variance
434 |
435 | c = -expm1(log_snr - log_snr_next)
436 |
437 | mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
438 | variance = (sigma_next ** 2) * c
439 | log_variance = log(variance)
440 |
441 | # get noise
442 |
443 | noise = torch.where(
444 | rearrange(time_next > 0, 'b -> b 1 1 1'),
445 | torch.randn_like(img),
446 | torch.zeros_like(img)
447 | )
448 |
449 | img = mean + (0.5 * log_variance).exp() * noise
450 |
451 | return img
452 |
453 | @torch.no_grad()
454 | def ddim_sample(self, shape, time_difference = None):
455 | batch, device = shape[0], self.device
456 |
457 | time_difference = default(time_difference, self.time_difference)
458 |
459 | time_pairs = self.get_sampling_timesteps(batch, device = device)
460 |
461 | img = torch.randn(shape, device = device)
462 |
463 | x_start = None
464 |
465 | for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):
466 |
467 |             # add the time delay (must be applied before the noise levels are computed, or it has no effect)
468 |
469 |             times_next = (times_next - time_difference).clamp(min = 0.)
470 |
471 |             # get times and noise levels
472 |
473 |             log_snr = self.log_snr(times)
474 |             log_snr_next = self.log_snr(times_next)
475 |
476 |             padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
477 |
478 |             alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
479 |             alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)
480 |
481 | # predict x0
482 |
483 | x_start = self.model(img, log_snr, x_start)
484 |
485 | # clip x0
486 |
487 | x_start.clamp_(-1., 1.)
488 |
489 | # get predicted noise
490 |
491 | pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)
492 |
493 | # calculate x next
494 |
495 | img = x_start * alpha_next + pred_noise * sigma_next
496 |
497 | return img
498 |
499 | @torch.no_grad()
500 | def sample(self, batch_size = 16):
501 | image_size, channels = self.image_size, self.channels
502 | sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
503 | return sample_fn((batch_size, channels, image_size, image_size))
504 |
505 | def forward(self, img, *args, **kwargs):
506 | batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
507 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
508 |
509 | # sample random times
510 |
511 | times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)
512 |
513 | # noise sample
514 |
515 | noise = torch.randn_like(img)
516 |
517 | noise_level = self.log_snr(times)
518 | padded_noise_level = right_pad_dims_to(img, noise_level)
519 | alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level)
520 |
521 | noised_img = alpha * img + sigma * noise
522 |
523 | # if doing self-conditioning, 50% of the time, predict x_start from current set of times
524 | # and condition with unet with that
525 | # this technique will slow down training by 25%, but seems to lower FID significantly
526 |
527 | self_cond = None
528 | if random() < 0.5:
529 | with torch.no_grad():
530 | self_cond = self.model(noised_img, noise_level).detach_()
531 |
532 | # predict and take gradient step
533 |
534 | pred = self.model(noised_img, noise_level, self_cond)
535 |
536 | return F.mse_loss(pred, img)
537 |
538 | # trainer class
539 |
540 | class Trainer(object):
541 | def __init__(
542 | self,
543 | diffusion_model,
544 | folder,
545 | *,
546 | train_batch_size = 16,
547 | gradient_accumulate_every = 1,
548 | augment_horizontal_flip = True,
549 | train_lr = 1e-4,
550 | train_num_steps = 100000,
551 | ema_update_every = 10,
552 | ema_decay = 0.995,
553 | adam_betas = (0.9, 0.99),
554 | save_and_sample_every = 1000,
555 | num_samples = 25,
556 | results_folder = './results',
557 | amp = False,
558 | fp16 = False,
559 | split_batches = True,
560 | convert_image_to = None
561 | ):
562 | super().__init__()
563 |
564 | self.accelerator = Accelerator(
565 | split_batches = split_batches,
566 | mixed_precision = 'fp16' if fp16 else 'no'
567 | )
568 |
569 | self.accelerator.native_amp = amp
570 |
571 | self.model = diffusion_model
572 |
573 | assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
574 | self.num_samples = num_samples
575 | self.save_and_sample_every = save_and_sample_every
576 |
577 | self.batch_size = train_batch_size
578 | self.gradient_accumulate_every = gradient_accumulate_every
579 |
580 | self.train_num_steps = train_num_steps
581 | self.image_size = diffusion_model.image_size
582 |
583 | # dataset and dataloader
584 |
585 |         self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to) # assumes a concrete image-folder Dataset implementation - torch.utils.data.Dataset itself is abstract
586 | dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
587 |
588 | dl = self.accelerator.prepare(dl)
589 | self.dl = cycle(dl)
590 |
591 | # optimizer
592 |
593 | self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
594 |
595 | # for logging results in a folder periodically
596 |
597 | if self.accelerator.is_main_process:
598 | self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
599 |
600 | self.results_folder = Path(results_folder)
601 | self.results_folder.mkdir(exist_ok = True)
602 |
603 | # step counter state
604 |
605 | self.step = 0
606 |
607 | # prepare model, dataloader, optimizer with accelerator
608 |
609 | self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
610 |
611 | def save(self, milestone):
612 | if not self.accelerator.is_local_main_process:
613 | return
614 |
615 | data = {
616 | 'step': self.step,
617 | 'model': self.accelerator.get_state_dict(self.model),
618 | 'opt': self.opt.state_dict(),
619 | 'ema': self.ema.state_dict(),
620 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
621 | }
622 |
623 | torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
624 |
625 | def load(self, milestone):
626 | data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))
627 |
628 | model = self.accelerator.unwrap_model(self.model)
629 | model.load_state_dict(data['model'])
630 |
631 | self.step = data['step']
632 | self.opt.load_state_dict(data['opt'])
633 | self.ema.load_state_dict(data['ema'])
634 |
635 | if exists(self.accelerator.scaler) and exists(data['scaler']):
636 | self.accelerator.scaler.load_state_dict(data['scaler'])
637 |
638 | def train(self):
639 | accelerator = self.accelerator
640 | device = accelerator.device
641 |
642 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
643 |
644 | while self.step < self.train_num_steps:
645 |
646 | total_loss = 0.
647 |
648 | for _ in range(self.gradient_accumulate_every):
649 | data = next(self.dl).to(device)
650 |
651 | with self.accelerator.autocast():
652 | loss = self.model(data)
653 | loss = loss / self.gradient_accumulate_every
654 | total_loss += loss.item()
655 |
656 | self.accelerator.backward(loss)
657 |
658 | pbar.set_description(f'loss: {total_loss:.4f}')
659 |
660 | accelerator.wait_for_everyone()
661 |
662 | self.opt.step()
663 | self.opt.zero_grad()
664 |
665 | accelerator.wait_for_everyone()
666 |
667 | if accelerator.is_main_process:
668 | self.ema.to(device)
669 | self.ema.update()
670 |
671 | if self.step != 0 and self.step % self.save_and_sample_every == 0:
672 | self.ema.ema_model.eval()
673 |
674 | with torch.no_grad():
675 | milestone = self.step // self.save_and_sample_every
676 | batches = num_to_groups(self.num_samples, self.batch_size)
677 | all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
678 |
679 | all_images = torch.cat(all_images_list, dim = 0)
680 | utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
681 | self.save(milestone)
682 |
683 | self.step += 1
684 | pbar.update(1)
685 |
686 | accelerator.print('training complete')
687 |
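688 | # example usage (a sketch; the path is hypothetical - assumes `chroma` is a Chroma instance
689 | # and that `Dataset` has been given a concrete image-folder implementation, per the note in
690 | # __init__ above):
691 | #
692 | # trainer = Trainer(chroma, './path/to/training/images', train_batch_size = 16)
693 | # trainer.train()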
--------------------------------------------------------------------------------
/chroma_pytorch/semantic_conditioner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
4 | from transformers import AutoTokenizer, AutoModelForMaskedLM, logging
5 | from tf_bind_transformer.cache_utils import cache_fn, run_once
6 |
7 | logging.set_verbosity_error()
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def map_values(fn, dictionary):
13 | return {k: fn(v) for k, v in dictionary.items()}
14 |
15 | CONTEXT_EMBED_USE_CPU = os.getenv('CONTEXT_EMBED_USE_CPU', None) is not None
16 |
17 | if CONTEXT_EMBED_USE_CPU:
18 | print('calculating context embed only on cpu')
19 |
20 | MODELS = dict(
21 | pubmed = dict(
22 | dim = 768,
23 | path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
24 | )
25 | )
26 |
27 | GLOBAL_VARIABLES = dict(model = None, tokenizer = None)
28 |
29 | def get_contextual_dim(model_name):
30 | assert model_name in MODELS
31 | return MODELS[model_name]['dim']
32 |
33 | @run_once('init_transformer')
34 | def init_transformer(model_name):
35 | path = MODELS[model_name]['path']
36 | GLOBAL_VARIABLES['tokenizer'] = AutoTokenizer.from_pretrained(path)
37 |
38 | model = AutoModelForMaskedLM.from_pretrained(path)
39 |
40 | if not CONTEXT_EMBED_USE_CPU:
41 | model = model.cuda()
42 |
43 | GLOBAL_VARIABLES['model'] = model
44 |
45 | @torch.no_grad()
46 | def tokenize_text(
47 | text,
48 | max_length = 256,
49 | model_name = 'pubmed',
50 | hidden_state_index = -1,
51 | return_cls_token = True
52 | ):
53 | init_transformer(model_name)
54 |
55 | model = GLOBAL_VARIABLES['model']
56 | tokenizer = GLOBAL_VARIABLES['tokenizer']
57 |
58 | encoding = tokenizer.batch_encode_plus(
59 | [text],
60 | add_special_tokens = True,
61 | padding = True,
62 | truncation = True,
63 | max_length = max_length,
64 | return_attention_mask = True,
65 | return_tensors = 'pt'
66 | )
67 |
68 | if not CONTEXT_EMBED_USE_CPU:
69 | encoding = map_values(lambda t: t.cuda(), encoding)
70 |
71 | model.eval()
72 | with torch.no_grad():
73 | outputs = model(**encoding, output_hidden_states = True)
74 |
75 | hidden_state = outputs.hidden_states[hidden_state_index][0]
76 |
77 | if return_cls_token:
78 | return hidden_state[0]
79 |
80 | return hidden_state.mean(dim = 0)
81 |
82 | def get_text_repr(
83 | texts,
84 | *,
85 | device,
86 | max_length = 256,
87 | model_name = 'pubmed',
88 | hidden_state_index = -1,
89 | return_cls_token = True,
90 | ):
91 | assert model_name in MODELS, f'{model_name} not found in available text transformers to use'
92 |
93 | if isinstance(texts, str):
94 | texts = [texts]
95 |
96 | get_context_repr_fn = cache_fn(tokenize_text, path = f'contexts/{model_name}')
97 |
98 | representations = [get_context_repr_fn(text, max_length = max_length, model_name = model_name, hidden_state_index = hidden_state_index, return_cls_token = return_cls_token) for text in texts]
99 |
100 | return torch.stack(representations).to(device)
101 |
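102 | # example usage (a sketch; the query text is hypothetical - the model runs on cuda unless
103 | # the CONTEXT_EMBED_USE_CPU environment variable is set):
104 | #
105 | # embeds = get_text_repr(['a protein with high thermostability'], device = torch.device('cpu'))
106 | # embeds.shape # (1, 768) with the default 'pubmed' model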
--------------------------------------------------------------------------------
/rfdiffusion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/chroma-pytorch/4bf51ba8d43b6297985742930bc5e4b174078aea/rfdiffusion.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'chroma-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.1',
7 | license='MIT',
8 | description = 'Chroma - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/chroma-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'denoising diffusion',
17 | 'protein design'
18 | ],
19 | install_requires=[
20 | 'einops>=0.6',
21 | 'invariant-point-attention',
22 | 'torch>=1.6',
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.6',
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------