├── coatiLDM
    ├── __init__.py
    ├── common
    │   ├── __init__.py
    │   ├── utils.py
    │   ├── s3.py
    │   ├── ema.py
    │   └── fd.py
    ├── data
    │   ├── __init__.py
    │   ├── rank_data.py
    │   ├── dist_datapipe.py
    │   ├── decoding.py
    │   ├── datapipe.py
    │   └── transforms.py
    ├── models
    │   ├── __init__.py
    │   ├── coati
    │   │   ├── __init__.py
    │   │   ├── tokenizers
    │   │   │   ├── __init__.py
    │   │   │   ├── trie_tokenizer.py
    │   │   │   └── trie.py
    │   │   ├── io.py
    │   │   ├── trie_tokenizer.py
    │   │   ├── transformer_only.py
    │   │   └── basic_transformer.py
    │   ├── score_models
    │   │   ├── __init__.py
    │   │   ├── flow_wrapper.py
    │   │   ├── ranknet.py
    │   │   ├── score_model.py
    │   │   ├── due_dflow.py
    │   │   ├── due_cg_model.py
    │   │   └── non_conv_unet.py
    │   ├── diffusion_models
    │   │   ├── __init__.py
    │   │   ├── ddpm_lightweight.py
    │   │   ├── schedulers.py
    │   │   ├── particle_guidance.py
    │   │   ├── dflow.py
    │   │   ├── flow_matching.py
    │   │   ├── ddim_sample_routines.py
    │   │   └── ddpm_sample_routines.py
    │   └── io.py
    ├── trainers
    │   ├── __init__.py
    │   ├── train_ranknet.py
    │   ├── train_cg_due.py
    │   ├── train_cg_resnet.py
    │   ├── lfm_direct.py
    │   └── ldm_direct.py
    └── constants.py
├── MANIFEST.in
├── coati.jpg
├── coati_ldm.gif
├── requirements.txt
├── setup.py
└── README.md

/coatiLDM/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/coatiLDM/common/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/coatiLDM/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/coatiLDM/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/coatiLDM/models/coati/__init__.py:
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coatiLDM/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coatiLDM/models/score_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coatiLDM/models/diffusion_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include coatiLDM/models/coati/tokenizers/vocabs/* -------------------------------------------------------------------------------- /coati.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/terraytherapeutics/COATI-LDM/HEAD/coati.jpg -------------------------------------------------------------------------------- /coati_ldm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/terraytherapeutics/COATI-LDM/HEAD/coati_ldm.gif -------------------------------------------------------------------------------- /coatiLDM/models/score_models/flow_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ODEWrapper(nn.Module): 6 | 7 | def __init__(self, score_net): 8 | super(ODEWrapper, self).__init__() 9 | self.score_net = score_net 10 | 11 | def forward(self, t, x): 12 | device = next(self.score_net.parameters()).device 13 | t = t * torch.ones(len(x), device=device) 14 | return 
self.score_net(x, t) 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.0 2 | torchdata==0.7.0 3 | torchdiffeq==0.2.3 4 | numpy==1.24.4 5 | altair==5.4.1 6 | rdkit-pypi==2022.9.5 7 | pandas>1.0 8 | jupyter==1.1.1 9 | matplotlib==3.9.2 10 | scipy==1.14.1 11 | scikit-learn==1.5.1 12 | boto3==1.35.7 13 | botocore==1.35.7 14 | tqdm==4.66.5 15 | torchinfo==1.8.0 16 | einops==0.8.0 17 | smart-open==7.0.4 18 | seaborn==0.13.2 19 | zuko==1.2.0 20 | s3fs==2024.6.0 21 | due @ git+https://github.com/y0ast/DUE.git -------------------------------------------------------------------------------- /coatiLDM/models/coati/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List 5 | 6 | from coatiLDM.common.s3 import cache_read 7 | 8 | from .smiles_vocab import tokenizer_vocabs 9 | 10 | # absolute path to the vocabulary folder 11 | VOCAB_PATH = Path(__file__).parent / "vocabs" 12 | 13 | 14 | def load_vocab(vocab_name: str) -> Dict[str, List[str]]: 15 | with open(VOCAB_PATH / f"{vocab_name}.json", "r") as f: 16 | return json.load(f) 17 | 18 | 19 | def get_vocab(vocab_name: str) -> Dict[str, List[str]]: 20 | try: 21 | return tokenizer_vocabs[vocab_name] 22 | except KeyError: 23 | print("vocab_name not found in tokenizer_vocabs, trying to load from file") 24 | 25 | try: 26 | return load_vocab(vocab_name) 27 | except: 28 | raise ValueError(f"vocab_name {vocab_name} not found in vocabs folder") 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | __version__ = "0.1.0" 6 | 7 | with 
open("README.md") as readme_file: 8 | readme = readme_file.read() 9 | 10 | with open("requirements.txt") as req_file: 11 | # Exclude GitHub dependencies for `install_requires` 12 | requirements = [ 13 | line for line in req_file.read().splitlines() if not line.startswith("git+") 14 | ] 15 | 16 | setup( 17 | author="Ben Kaufman, Edward Williams, Carl Underkoffler, Ryan Pederson, Miles Wang-Henderson, John Parkhill", 18 | author_email="bkaufman@terraytx.com", 19 | python_requires=">=3.10", 20 | classifiers=[ 21 | "Development Status :: 1 - Pre-Alpha", 22 | "Intended Audience :: Developers", 23 | "Natural Language :: English", 24 | "Programming Language :: Python :: 3.10", 25 | ], 26 | description="COATI Diffusion", 27 | install_requires=requirements, 28 | packages=find_packages(), 29 | long_description=readme, 30 | include_package_data=True, 31 | keywords="diffusion", 32 | name="coatiLDM", 33 | version=__version__, 34 | zip_safe=False, 35 | ) 36 | -------------------------------------------------------------------------------- /coatiLDM/models/diffusion_models/ddpm_lightweight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DDPMScoreNetTrainer(torch.nn.Module): 5 | def __init__(self, score_net): 6 | super().__init__() 7 | self.score_net = score_net 8 | self.scheduler = score_net.scheduler 9 | 10 | def forward(self, x, cond=None, loss_weight=None): 11 | batch_size = x.shape[0] 12 | device = next(self.score_net.parameters()).device 13 | T = torch.randint( 14 | low=0, high=self.scheduler.timesteps, size=(batch_size,), device=device 15 | ) 16 | noise = torch.randn((batch_size, self.score_net.x_dim), device=device) 17 | noisy_samples = ( 18 | self.scheduler.bar_alpha(T).sqrt() * x 19 | + (1.0 - self.scheduler.bar_alpha(T)).sqrt() * noise 20 | ) 21 | extracted_noise = self.score_net(noisy_samples, t=T.float(), cond=cond) 22 | pre_weight = torch.pow(noise - extracted_noise, 2.0).mean(-1) 23 | if not 
loss_weight is None: 24 | assert loss_weight.shape[0] == batch_size 25 | assert loss_weight.dim() == 1 26 | return (pre_weight * loss_weight).mean() 27 | return pre_weight.mean() 28 | -------------------------------------------------------------------------------- /coatiLDM/data/rank_data.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | 3 | import numpy as np 4 | import torch 5 | from smart_open import open 6 | from torch.utils.data.datapipes.iter import IterableWrapper 7 | 8 | rng = np.random.default_rng(42) 9 | 10 | 11 | def collate_rank_batch(batch): 12 | x_i = torch.tensor(np.vstack([entry["smiles_i_enc"] for entry in batch])).float() 13 | x_j = torch.tensor(np.vstack([entry["smiles_j_enc"] for entry in batch])).float() 14 | label = torch.tensor(np.array([entry["label"] for entry in batch])) 15 | 16 | return x_i, x_j, label 17 | 18 | 19 | def make_rank_pipes(datapath, train_prob: float = 0.9, bsize: int = 32): 20 | print(f"loading data from {datapath}") 21 | with open(datapath, "rb") as inf: 22 | data_records = pkl.load(inf) 23 | 24 | rng = np.random.default_rng(42) 25 | 26 | # split into train/test partition 27 | is_train = list(rng.random((len(data_records))) < train_prob) 28 | 29 | train_recs = [rec for train_data, rec in zip(is_train, data_records) if train_data] 30 | test_recs = [ 31 | rec for train_data, rec in zip(is_train, data_records) if not train_data 32 | ] 33 | 34 | train_pipe = ( 35 | IterableWrapper(train_recs, deepcopy=False) 36 | .batch(bsize) 37 | .collate(collate_rank_batch) 38 | ) 39 | 40 | test_pipe = ( 41 | IterableWrapper(test_recs, deepcopy=False) 42 | .batch(bsize) 43 | .collate(collate_rank_batch) 44 | ) 45 | 46 | return train_pipe, test_pipe 47 | -------------------------------------------------------------------------------- /coatiLDM/data/dist_datapipe.py: -------------------------------------------------------------------------------- 1 | import os 2 | from 
torch.utils.data.datapipes.iter import FileLister, Shuffler, IterableWrapper 3 | from torch.utils.data.datapipes.datapipe import IterDataPipe 4 | from torchdata.datapipes.iter import FileLister, InMemoryCacheHolder 5 | from torch.utils.data.datapipes._decorator import functional_datapipe 6 | import pickle 7 | 8 | 9 | @functional_datapipe("unstack_pickv2") 10 | class UnstackPickles(IterDataPipe): 11 | def __init__(self, dp, keep_fields=None) -> None: 12 | super().__init__() 13 | self.dp = dp 14 | self.keep_fields = keep_fields 15 | 16 | def __iter__(self): 17 | for X in self.dp: 18 | # print('loading... ',X) 19 | with open(X, "rb") as f: 20 | raw_rows = pickle.load(f) 21 | if not self.keep_fields is None: 22 | yield [ 23 | {key: row[key] for key in row if key in self.keep_fields} 24 | for row in raw_rows 25 | ] 26 | else: 27 | yield raw_rows 28 | 29 | 30 | def get_dist_pipe(data_path, cache_mask=["*.pkl"], keep_fields=["smiles"]): 31 | pipe = ( 32 | FileLister( 33 | root=data_path, 34 | recursive=False, 35 | masks=cache_mask, 36 | ) 37 | .shuffle() 38 | .unstack_pickv2(keep_fields=keep_fields) 39 | .unbatch() 40 | .sharding_filter() 41 | .shuffle(buffer_size=100_000) 42 | # .in_memory_cache(size=50_000) 43 | ) 44 | return pipe 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # COATI-LDM 2 | 3 | This repository contains the code and data associated with the paper preprint [Latent Diffusion for Conditional Generation of Molecules](https://www.biorxiv.org/content/10.1101/2024.08.22.609169). 4 | 5 | ![](coati_ldm.gif) 6 | 7 | ## Installation 8 | 9 | To install the required dependencies and the package itself, run the following commands from the base directory: 10 | 11 | ```bash 12 | pip install . 13 | ``` 14 | 15 | ## Examples 16 | 17 | runnable notebooks for paper figures can be found in `figure_notebooks/`. 
18 | 19 | general examples for training and using models can be found in `example_notebooks/`. 20 | 21 | ## Models 22 | 23 | model artifacts from the paper can be pulled down from the paths in `coatiLDM/constants.py`. 24 | 25 | 26 | 27 | ## COATI 28 | 29 | For more details and examples using the COATI model visit [COATI](https://github.com/terraytherapeutics/COATI) and the associated paper below. 30 | 31 | icon 32 | 33 | ## Cite 34 | 35 | ``` 36 | @article{kaufman2024latent, 37 | title={Latent Diffusion For Conditional Generation of Molecules}, 38 | author={Kaufman, Benjamin and Williams, Edward C and Pederson, Ryan and Underkoffler, Carl and Panjwani, Zahid and Wang-Henderson, Miles and Mardirossian, Narbe and Katcher, Matthew H and Strater, Zack and Grandjean, Jean-Marc and others}, 39 | journal={bioRxiv}, 40 | pages={2024--08}, 41 | year={2024}, 42 | publisher={Cold Spring Harbor Laboratory} 43 | } 44 | ``` 45 | 46 | ``` 47 | @article{kaufman2024coati, 48 | title={Coati: Multimodal contrastive pretraining for representing and traversing chemical space}, 49 | author={Kaufman, Benjamin and Williams, Edward C and Underkoffler, Carl and Pederson, Ryan and Mardirossian, Narbe and Watson, Ian and Parkhill, John}, 50 | journal={Journal of Chemical Information and Modeling}, 51 | volume={64}, 52 | number={4}, 53 | pages={1145--1157}, 54 | year={2024}, 55 | publisher={ACS Publications} 56 | } 57 | ``` -------------------------------------------------------------------------------- /coatiLDM/constants.py: -------------------------------------------------------------------------------- 1 | FIGURE_DATA_PATH = "s3://terray-public/coatiLDM/figure_data" 2 | DIFFUSION_MODELS = { 3 | "diffusion_large": "s3://terray-public/coatiLDM/models/general_coati_diffuser.pt", 4 | "hcaii_diff": "s3://terray-public/coatiLDM/models/hcaii_diff.pt", 5 | "logp_diff": "s3://terray-public/coatiLDM/models/logp_diff.pt", 6 | "tpsa_diff": "s3://terray-public/coatiLDM/models/tpsa_diff.pt", 7 | 
"logp_tpsa_diff": "s3://terray-public/coatiLDM/models/logp_tpsa_diff.pt", 8 | "uncond_diff": "s3://terray-public/coatiLDM/models/uncond_diff.pt", 9 | "qed_diff": "s3://terray-public/coatiLDM/models/qed_opt_diffuser.pt", 10 | } 11 | 12 | FLOW_MODELS = { 13 | "flow_large": "s3://terray-public/coatiLDM/models/general_coati_flow.pt", 14 | "logp_flow": "s3://terray-public/coatiLDM/models/logp_flow.pt", 15 | "tpsa_flow": "s3://terray-public/coatiLDM/models/tpsa_flow.pt", 16 | "hcaii_flow": "s3://terray-public/coatiLDM/models/hcaii_flow.pt", 17 | "logp_tpsa_flow": "s3://terray-public/coatiLDM/models/logp_tpsa_flow.pt", 18 | "uncond_flow": "s3://terray-public/coatiLDM/models/uncond_flow.pt", 19 | "uncond_for_dflow": "s3://terray-public/coatiLDM/models/uncond_for_dflow.pt", 20 | } 21 | 22 | COATI2_DOCS = { 23 | "general_doc": "s3://terray-public/coatiLDM/models/general_doc.pt", 24 | "qed_doc": "s3://terray-public/coatiLDM/models/qed_doc.pt", 25 | } 26 | 27 | CLASSIFIER_GUIDE_DOCS = { 28 | "tpsa": "s3://terray-public/coatiLDM/models/tpsa_cg.pt", 29 | "logp": "s3://terray-public/coatiLDM/models/logp_cg.pt", 30 | "hcaii": "s3://terray-public/coatiLDM/models/hcaii_cg.pt", 31 | "qed": "s3://terray-public/coatiLDM/models/guide.pt", 32 | } 33 | 34 | DFLOW_CLASSIFIER_DOCS = { 35 | "hcaii": "s3://terray-public/coatiLDM/models/dflow_hcaii_due.pt", 36 | "logp": "s3://terray-public/coatiLDM/models/dflow_logp_due.pt", 37 | "tpsa": "s3://terray-public/coatiLDM/models/dflow_tpsa_due.pt", 38 | } 39 | 40 | QED_OPT_DOCS = { 41 | "score_model": "s3://terray-public/coatiLDM/models/qed_opt_diffuser.pt", 42 | "guide": "s3://terray-public/coatiLDM/models/guide.pt", 43 | } 44 | -------------------------------------------------------------------------------- /coatiLDM/models/score_models/ranknet.py: -------------------------------------------------------------------------------- 1 | # models and such. Basic MLP. use DUE later. 
2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class RankNet(nn.Module): 9 | def __init__( 10 | self, 11 | input_size: int = 2048, 12 | hidden_size: int = 256, 13 | n_layers: int = 3, 14 | dropout_p: float = 0.0, 15 | ) -> None: 16 | """Basic RankNet implementation. Pairs of samples are classified 17 | according to sigmoid(s_i - s_j) where s_i, s_j are scores learned 18 | during training. 19 | 20 | Args: 21 | input_size (int, optional): Descriptor size for each sample. Defaults to 2048. 22 | hidden_size (int, optional): Number of neurons in hidden layers. Defaults to 256. 23 | n_layers (int, optional): Number of hidden layers. Defaults to 3. 24 | dropout_p (float, optional): Dropout probability. Defaults to 0.0. 25 | """ 26 | super(RankNet, self).__init__() 27 | self.encoder = nn.Sequential( 28 | nn.Linear(input_size, hidden_size), nn.Dropout(dropout_p), nn.ReLU() 29 | ) 30 | 31 | for _ in range(n_layers): 32 | self.encoder.append(nn.Linear(hidden_size, hidden_size)) 33 | self.encoder.append(nn.Dropout(dropout_p)) 34 | self.encoder.append(nn.ReLU()) 35 | self.encoder.append(nn.Linear(hidden_size, 1)) 36 | 37 | def forward( 38 | self, x_i: torch.Tensor, x_j: torch.Tensor, sigmoid: bool = False 39 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 40 | assert x_i.size() == x_j.size() 41 | 42 | score_i, score_j = self.encoder(x_i), self.encoder(x_j) 43 | out = score_i - score_j 44 | if sigmoid: 45 | out = torch.sigmoid(out) 46 | return score_i, score_j, out 47 | 48 | def score(self, x: torch.Tensor) -> torch.Tensor: 49 | """Scores sample `x` 50 | 51 | Args: 52 | x: input fingerprints // (n_samples, n_feat) 53 | """ 54 | with torch.inference_mode(): 55 | return self.encoder(x) 56 | -------------------------------------------------------------------------------- /coatiLDM/models/score_models/score_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 
3 | from torch import nn 4 | from torch.nn.utils.parametrizations import weight_norm 5 | 6 | 7 | class SwiGLU(nn.Module): 8 | def forward(self, x): 9 | x, gate = x.chunk(2, dim=-1) 10 | return torch.nn.functional.silu(gate) * x 11 | 12 | 13 | class SwiGLUNet(nn.Module): 14 | def __init__( 15 | self, d_in, d_out, residual=False, dropout=0.0, use_weight_norm=False, bias=True 16 | ): 17 | 18 | super().__init__() 19 | self.residual = residual 20 | self.net = nn.Sequential( 21 | nn.LayerNorm(d_in), 22 | torch.nn.Dropout(p=dropout), 23 | # should this one be weight-normed as well? (vs just the second) 24 | ( 25 | weight_norm(nn.Linear(d_in, 2 * d_out, bias=bias), dim=None) 26 | if use_weight_norm 27 | else nn.Linear(d_in, 2 * d_out, bias=bias) 28 | ), 29 | SwiGLU(), 30 | ( 31 | weight_norm(nn.Linear(d_out, d_out, bias=bias), dim=None) 32 | if use_weight_norm 33 | else nn.Linear(d_out, d_out, bias=bias) 34 | ), 35 | ) 36 | 37 | def forward(self, x): 38 | if self.residual: 39 | return self.net(x) + x 40 | else: 41 | return self.net(x) 42 | 43 | 44 | def get_time_embedding( 45 | timesteps, 46 | embedding_dim: int, 47 | dtype=torch.float32, 48 | max_timescale=10_000, 49 | min_timescale=1, 50 | max_time=1.0, 51 | ): 52 | # Adapted from tensor2tensor and VDM codebase. 
53 | 54 | timesteps *= ( 55 | 1000.0 / max_time 56 | ) # In DDPM the time step is in [0, 1000], in BFN [0, 1] 57 | num_timescales = embedding_dim // 2 58 | inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n)) 59 | -np.log10(min_timescale), 60 | -np.log10(max_timescale), 61 | num_timescales, 62 | device=timesteps.device, 63 | ) 64 | emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2) 65 | return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D) 66 | -------------------------------------------------------------------------------- /coatiLDM/data/decoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rdkit import Chem 3 | import numpy as np 4 | 5 | 6 | def force_decode_valid_batch_efficient( 7 | V, 8 | encoder, 9 | tokenizer, 10 | max_attempts=64, 11 | inv_temp=1.5, 12 | k=2000, 13 | noise_scale=0.0, 14 | chiral=True, 15 | silent=False, 16 | ): 17 | 18 | # logger.debug(f"Running decode with max {max_attempts} attempts") 19 | device = V.device 20 | assert V.device == next(encoder.parameters()).device 21 | 22 | mols = ["" for _ in range(V.shape[0])] 23 | indices = list(range(V.shape[0])) 24 | vectors = V.detach().clone() 25 | 26 | for _ in range(max_attempts): 27 | with torch.no_grad(): 28 | assert vectors.dim() == 2 29 | if chiral: 30 | regen_smiles = encoder.hcoati_to_2d_batch( 31 | vectors, tokenizer, inv_temp=inv_temp, k=k, noise_scale=noise_scale 32 | ) 33 | else: 34 | regen_smiles = encoder.hclip_to_2d_batch( 35 | vectors, tokenizer, inv_temp=inv_temp, k=k, noise_scale=noise_scale 36 | ) 37 | fail_flag = [1 for k in range(vectors.shape[0])] 38 | for j in range(vectors.shape[0]): 39 | smiles = regen_smiles[j] 40 | if smiles == "C": 41 | continue 42 | try: 43 | mol = Chem.MolFromSmiles(smiles) 44 | if mol is None: 45 | raise Exception 46 | mols[indices[j]] = smiles 47 | fail_flag[j] = 0 48 | except Exception as e: 49 | continue 50 | vectors = ( 51 | vectors[ 52 
| torch.tensor(fail_flag, dtype=torch.bool, device=vectors.device) 53 | ] 54 | .clone() 55 | .detach() 56 | ) 57 | indices = [indices[j] for j in range(len(indices)) if fail_flag[j]] 58 | if not silent: 59 | print(len(indices), " remaining ") 60 | if (len(indices)) == 0: 61 | break 62 | 63 | return mols 64 | -------------------------------------------------------------------------------- /coatiLDM/models/score_models/due_dflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from due.dkl import DKL, GP 4 | from due.fc_resnet import FCResNet 5 | import smart_open 6 | import pickle 7 | 8 | 9 | def load_inference_due_for_dflow( 10 | load_as, 11 | input_dim=512, 12 | n_inducing_points=60, 13 | depth=4, 14 | remove_spectral_norm=False, 15 | passed_state_dict=None, 16 | ): 17 | """just loads model for inference. Doesn't require associated dataset. can remove norm if desired 18 | 19 | Args: 20 | load_as (str): model path 21 | input_dim (int): input dimensions. 22 | n_inducing_points(int, optional): need this to load correctly, uses default from basic_due 23 | depth (int, optional): default shared with basic_due, but can also be modified 24 | remove_spectral_norm (bool, optional): remove spectral norm if taking gradients. Defaults to False. 
25 | """ 26 | 27 | features = 256 28 | num_outputs = 1 29 | spectral_normalization = True 30 | coeff = 0.95 31 | n_power_iterations = 2 32 | dropout_rate = 0.00 33 | 34 | feature_extractor = FCResNet( 35 | input_dim=input_dim, 36 | features=features, 37 | depth=depth, 38 | spectral_normalization=spectral_normalization, 39 | coeff=coeff, 40 | n_power_iterations=n_power_iterations, 41 | dropout_rate=dropout_rate, 42 | ) 43 | 44 | kernel = "RBF" 45 | # The following will be loaded just need right shapes 46 | initial_inducing_points = torch.zeros((n_inducing_points, features)) 47 | initial_lengthscale = torch.tensor(0.5) 48 | 49 | gp = GP( 50 | num_outputs=num_outputs, 51 | initial_lengthscale=initial_lengthscale, 52 | initial_inducing_points=initial_inducing_points, 53 | kernel=kernel, 54 | ) 55 | model = DKL(feature_extractor, gp) 56 | 57 | if passed_state_dict: 58 | read = passed_state_dict 59 | else: 60 | with smart_open.open(load_as, "rb") as f_in: 61 | read = torch.load(f_in) 62 | # read = torch.load(load_as) 63 | model.load_state_dict(read) 64 | 65 | if remove_spectral_norm: 66 | model.feature_extractor.first = torch.nn.utils.remove_spectral_norm( 67 | model.feature_extractor.first 68 | ) 69 | 70 | model.eval() 71 | return model 72 | -------------------------------------------------------------------------------- /coatiLDM/models/io.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from coatiLDM.models.diffusion_models.schedulers import DDPMScheduler 3 | from coatiLDM.models.score_models.non_conv_unet import NonConvUNet 4 | from coatiLDM.models.diffusion_models.flow_matching import ScoreNetCondVF 5 | from coatiLDM.models.score_models.due_cg_model import DueCG, save_due 6 | import gpytorch 7 | import torch 8 | import numpy as np 9 | from smart_open import open 10 | 11 | implemented_score_models = ["non_conv_unet"] 12 | 13 | 14 | def load_score_model(model_type, params, state_dict, device="cpu"): 15 | 16 | 
if model_type == "non_conv_unet": 17 | score_model = NonConvUNet(**params) 18 | 19 | else: 20 | raise Exception( 21 | f"bad score model currently implemented: {implemented_score_models}" 22 | ) 23 | 24 | score_model.load_state_dict(state_dict) 25 | return score_model.to(device).eval() 26 | 27 | 28 | def load_score_model_from_model_doc(doc_url, device="cpu"): 29 | with open(doc_url, "rb") as f_in: 30 | model_doc = pickle.loads(f_in.read(), encoding="UTF-8") 31 | model_kwargs = model_doc["score_model_params"] 32 | train_args = model_doc["train_args"] 33 | model = load_score_model( 34 | train_args["score_model"], model_kwargs, model_doc["model"], device=device 35 | ) 36 | return model.eval(), train_args, model_doc["norm_summary"] 37 | 38 | 39 | def load_flow_model_from_model_doc(doc_url, device="cpu"): 40 | model, train_args, norm_summary = load_score_model_from_model_doc( 41 | doc_url, device=device 42 | ) 43 | return ScoreNetCondVF(model), train_args, norm_summary 44 | 45 | 46 | def load_due_cg(due_params, state_dict, device="cpu"): 47 | due = DueCG(**due_params) 48 | if not due.initalized: 49 | dummy_size = due.n_inducing_points + 10 50 | dummy_x = torch.zeros((dummy_size, due.x_dim)) 51 | dummy_t = torch.zeros((dummy_size,)) 52 | due.initalize_model(dummy_x, dummy_t) 53 | due.load_state_dict(state_dict) 54 | due = due.to(device) 55 | due = due.eval() 56 | 57 | return due 58 | 59 | 60 | def load_due_cg_from_model_doc( 61 | doc_url, 62 | remove_spectral_norm=True, 63 | device="cpu", 64 | ): 65 | with open(doc_url, "rb") as f_in: 66 | model_doc = pickle.loads(f_in.read(), encoding="UTF-8") 67 | model_kwargs = model_doc["model_kwargs"] 68 | due = load_due_cg(model_kwargs, model_doc["model"], device=device) 69 | if remove_spectral_norm: 70 | due.feature_extractor.first = torch.nn.utils.remove_spectral_norm( 71 | due.feature_extractor.first 72 | ) 73 | return due.eval() 74 | -------------------------------------------------------------------------------- 
/coatiLDM/data/datapipe.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data.datapipes.iter import FileLister, Shuffler, IterableWrapper 3 | from torch.utils.data.datapipes.datapipe import IterDataPipe 4 | from torchdata.datapipes.iter import FileLister, InMemoryCacheHolder 5 | from torch.utils.data.datapipes._decorator import functional_datapipe 6 | import pickle 7 | 8 | 9 | @functional_datapipe("unstack_picklesv2") 10 | class UnstackPickles(IterDataPipe): 11 | def __init__(self, dp) -> None: 12 | super().__init__() 13 | self.dp = dp 14 | 15 | def __iter__(self): 16 | for X in self.dp: 17 | # print('loading... ',X) 18 | with open(X, "rb") as f: 19 | raw_rows = pickle.load(f) 20 | yield raw_rows 21 | 22 | 23 | def get_cache_pipe( 24 | cache_dir, masks=["*chunk*.pkl"], mem_cache=False, mem_cache_size=200000 25 | ): 26 | pipe = ( 27 | FileLister( 28 | root=cache_dir, 29 | recursive=False, 30 | masks=masks, 31 | ) 32 | .shuffle() 33 | # .open_files(mode="rb") 34 | .unstack_picklesv2() 35 | .unbatch() 36 | # .in_memory_cache(size=10_000) 37 | .shuffle(buffer_size=200_000) 38 | ) 39 | if mem_cache: 40 | pipe = pipe.in_memory_cache(size=mem_cache_size) 41 | return pipe 42 | 43 | 44 | def get_dist_pipe(data_path, cache_mask=["*.pkl"]): 45 | pipe = ( 46 | FileLister( 47 | root=data_path, 48 | recursive=False, 49 | masks=cache_mask, 50 | ) 51 | .shuffle() 52 | .unstack_picklesv2() 53 | .unbatch() 54 | .sharding_filter() 55 | .shuffle(buffer_size=20_000) 56 | .in_memory_cache(size=10_000) 57 | ) 58 | return pipe 59 | 60 | 61 | def get_base_pipe(data_path, load_type, cache_mask=["*chunk*.pkl"]): 62 | 63 | if load_type == "pickle": 64 | encoded_data = pickle.load(open(data_path, "rb")) 65 | base_pipe = IterableWrapper(encoded_data, deepcopy=False) 66 | elif load_type == "cache": 67 | base_pipe = get_cache_pipe(data_path, masks=cache_mask) 68 | elif load_type == "torch": 69 | encoded_data = 
pickle.load(data_path) 70 | base_pipe = IterableWrapper(encoded_data, deepcopy=False) 71 | # this is dumb, but I want this to work with unified train routines 72 | # and I'm godless. feel free to reformulate - Ben. 73 | elif load_type == "train_val_dict": 74 | split_dict = pickle.load(open(data_path, "rb")) 75 | encoded_data = split_dict["train"] 76 | base_pipe = IterableWrapper(encoded_data, deepcopy=False) 77 | else: 78 | raise ValueError("Unknown load type, choose from: [pickle, torch]") 79 | return base_pipe 80 | -------------------------------------------------------------------------------- /coatiLDM/trainers/train_ranknet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm, trange 3 | import torch.nn.functional as F 4 | import pickle as pkl 5 | from coatiLDM.common.utils import makedir, utc_epoch_now 6 | import os 7 | 8 | 9 | def get_reg_loss( 10 | score_i: torch.Tensor, score_j: torch.Tensor, regularization_factor: float = 1e-6 11 | ) -> torch.Tensor: 12 | """Returns regularization loss for the scores ||s||^2 / batch_size 13 | and scales it by `regularization_factor` 14 | """ 15 | batch_size = score_i.size(0) 16 | reg_loss = ( 17 | regularization_factor 18 | * (torch.norm(score_i) ** 2 + torch.norm(score_j) ** 2) 19 | / batch_size 20 | ) 21 | return reg_loss 22 | 23 | 24 | # target is whether or not smiles_j was chosen. 25 | # lower score is better - logit is higher if score_j < score_i. 
26 | def get_loss(score_i, score_j, target): 27 | logit = (score_i - score_j).squeeze() 28 | return F.binary_cross_entropy_with_logits(logit, target.float()) 29 | 30 | 31 | def train_ranknet(train_pipe, model, lr=3e-4, epochs=10, device="cuda:0"): 32 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 33 | 34 | # stats = [] 35 | 36 | for epoch in range(epochs): 37 | print(f"epoch {epoch}") 38 | t = tqdm(train_pipe) 39 | for i, (emb_i, emb_j, labels) in enumerate(t): 40 | emb_i = emb_i.to(device) 41 | emb_j = emb_j.to(device) 42 | labels = labels.to(device) 43 | optimizer.zero_grad() 44 | (score_i, score_j, logit) = model(emb_i, emb_j) 45 | loss = get_loss(score_i, score_j, labels) + get_reg_loss(score_i, score_j) 46 | loss.backward() 47 | optimizer.step() 48 | t.set_description(f"Loss: {loss}") 49 | # if i % LOG_EVERY: 50 | # stats.append({"epoch": epoch, "batch": i, "loss": loss.cpu().detach()}) 51 | 52 | # with open(os.path.join(prefix, "stats.pkl"), "wb") as outf: 53 | # pkl.dump(stats, outf) 54 | 55 | 56 | def infer_pipe(test_pipe, model, device="cuda:0"): 57 | test_entries = [] 58 | for emb_i, emb_j, labels in test_pipe: 59 | emb_i = emb_i.to(device) 60 | emb_j = emb_j.to(device) 61 | labels = labels.to(device) 62 | 63 | with torch.inference_mode(): 64 | score_i, score_j, logits = model(emb_i, emb_j) 65 | 66 | for idx in range(emb_i.size()[0]): 67 | test_entries.append( 68 | { 69 | "score_i": float(score_i[idx]), 70 | "score_j": float(score_j[idx]), 71 | "logit": float(logits[idx]), 72 | "label": int(labels[idx]), 73 | } 74 | ) 75 | 76 | return test_entries 77 | -------------------------------------------------------------------------------- /coatiLDM/models/diffusion_models/schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from tqdm.auto import tqdm 5 | 6 | EPSILON = 1e-7 7 | DEFAULT_BETA_START = 1e-4 8 | DEFAULT_BETA_END = 0.02 9 | 10 | 
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Args:
        timesteps: Number of diffusion steps; one beta is returned per step.
        s: Small offset added to t / T inside the cosine.

    Returns:
        np.ndarray of ``timesteps`` betas, clipped to [0, 0.99].
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Clip so the final steps (where alphas_cumprod approaches 0) stay usable.
    return np.clip(betas, a_min=0, a_max=0.99)


class DDPMScheduler(torch.nn.Module):
    """
    Precomputed DDPM noise schedule: betas, alphas and cumulative alpha products.

    The tables are registered as buffers, so they follow the module through
    ``.to(device)`` and are part of its ``state_dict``.

    Args:
        schedule: "linear" or "cosine".
        timesteps: Number of diffusion steps (length of every table).
        beta_start, beta_end: Endpoints of the linear schedule (unused for "cosine").

    Raises:
        ValueError: if ``schedule`` is not recognised.
    """

    def __init__(self, schedule, timesteps, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.diff = "ddpm"
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.schedule = schedule
        if self.schedule == "linear":
            self.register_buffer(
                "all_betas",
                torch.linspace(self.beta_start, self.beta_end, self.timesteps),
            )
        elif self.schedule == "cosine":
            self.register_buffer(
                "all_betas",
                torch.tensor(cosine_beta_schedule(self.timesteps), dtype=torch.float),
            )
        else:
            raise ValueError("unknown noise schedule type")
        self.register_buffer("all_alphas", 1.0 - self.all_betas)
        self.register_buffer(
            "all_bar_alphas", torch.cumprod(self.all_alphas, 0).clamp(0.0, 1.0)
        )

    def beta(self, T):
        """
        Exactly the beta schedule of Ho & Abbeel.

        Args:
            T: integer tensor of timestep indices, shape (batch_size,).

        Returns:
            Tensor of shape (batch_size, 1).
        """
        beta = self.all_betas[T.long()]
        return beta.unsqueeze(-1)

    def alpha(self, T):
        """Return ``1 - beta`` at timestep indices ``T`` with a trailing unit dim."""
        return self.all_alphas[T.long()].unsqueeze(-1)

    def bar_alpha(self, T):
        """Return the cumulative alpha product at timestep indices ``T`` with a trailing unit dim."""
        return self.all_bar_alphas[T.long()].clamp(0.0, 1.0).unsqueeze(-1)

    def is_same(self, other):
        """
        Check if two instances of DDPMScheduler are functionally the same.

        Args:
            other: Another instance of DDPMScheduler.

        Returns:
            True if all values in all_betas are equal for both instances, False otherwise.

        Raises:
            ValueError: if ``other`` is not a DDPMScheduler.
        """
        # Ensure that the other instance is of the same class
        if not isinstance(other, DDPMScheduler):
            raise ValueError(
                "Comparison is only supported between instances of DDPMScheduler."
            )

        # Check if all_betas are equal for both instances
        return torch.equal(self.all_betas, other.all_betas)


# ---- /coatiLDM/models/diffusion_models/particle_guidance.py ----
import numpy as np
import torch
from tqdm.auto import tqdm


def similarity_guidance_gradient(x):
    """
    Produces the gradient of the K(X, X') similarity kernel.

    Here we use the Euclidean norm for similarity/distance, with an RBF kernel
    whose bandwidth is derived from the mean pairwise distance.

    Args:
        x: (num_latents, dim) batch of particles.

    Returns:
        Tensor shaped like ``x``. It is all zeros when there are fewer than two
        particles, or when all particles coincide.
    """
    with torch.no_grad():
        num_latents = x.shape[0]
        if num_latents < 2:
            # A lone particle has nothing to repel. The bandwidth formula below
            # would also divide by (num_latents - 1) == 0 and by log(1) == 0.
            return torch.zeros_like(x)
        diff = x.unsqueeze(1) - x.unsqueeze(0)
        distance = torch.norm(diff, p=2, dim=-1, keepdim=True)
        h_t = (
            distance.mean(dim=1, keepdim=True) * num_latents / (num_latents - 1)
        ) ** 2 / np.log(num_latents)
        # Coincident particles give a zero bandwidth; clamp so the result is a
        # zero gradient instead of 0 / 0 = NaN.
        h_t = h_t.clamp_min(torch.finfo(h_t.dtype).tiny)
        weights = torch.exp(-(distance**2 / h_t))
        grad_phi = 2 * weights * diff / h_t * 2
        grad_phi = grad_phi.sum(dim=1)
    return -grad_phi


def low_memory_cosine_guidance_gradient(vectors):
    """
    Per-row gradient of a cosine-similarity repulsion term, computed one row at a time.

    For each row k, the gradient w.r.t. ``vectors[k]`` of
    ``mean(exp(10 * erfinv(cos(vectors[k], rows))))`` is taken. Here ``rows`` is
    a copy of the batch with row k zeroed, so a row is not compared with itself.
    Looping trades speed for lower peak memory.
    """
    with torch.enable_grad():
        tore = []
        B = vectors.clone().detach().requires_grad_(False)
        CS = torch.nn.CosineSimilarity()
        for k in range(vectors.shape[0]):
            A = vectors[k].clone().detach().requires_grad_(True)
            B_ = B.clone()
            B_[k] = 0.0
            tore.append(
                torch.autograd.grad(
                    # clamp keeps erfinv finite (erfinv(+-1) is infinite)
                    (10.0 * torch.erfinv(CS(A, B_).clamp(-0.999, 0.999))).exp().mean(),
                    [A],
                )[0].detach()
            )
            del A
        return torch.stack(tore, 0)


def cosine_guidance_gradient(input_batch):
    """
    Gradient of the summed absolute pairwise cosine similarity of a batch.

    Args:
        input_batch (torch.Tensor): Input batch of vectors, shape (batch_size, vector_size).

    Returns:
        torch.Tensor: d/dx of sum_{i != j} |cos(x_i, x_j)|, shape (batch_size, vector_size).
        The result is not attached to any autograd graph.
    """
    with torch.enable_grad():
        cloned_batch = input_batch.clone().detach().requires_grad_(True)

        # Pairwise (batch_size, batch_size) cosine-similarity matrix. dim=-1 is
        # required: the default dim=1 would reduce over the broadcast *batch*
        # axis and produce a (batch_size, vector_size) array instead.
        sim_adj = torch.nn.functional.cosine_similarity(
            cloned_batch.unsqueeze(1), cloned_batch.unsqueeze(0), dim=-1
        )

        # Ignore self-similarity.
        sim_adj.fill_diagonal_(0)

        total_similarity = sim_adj.abs().sum()

        # No create_graph: callers only need the gradient values, and building a
        # second-order graph would keep the whole computation alive in memory.
        (cos_dist_grad,) = torch.autograd.grad(total_similarity, cloned_batch)

    return cos_dist_grad


def cosine_guidance_updated(A):
    """
    Gradient of the mean off-diagonal pairwise cosine similarity of the rows of ``A``.

    The similarity matrix is clamped to [-0.9999, 0.9999] before averaging. The
    mean is taken over all N*N entries, with the diagonal zeroed.
    """
    with torch.enable_grad():
        A_ = A.clone().detach().requires_grad_(True)
        A_nrm = A_ / ((A_ * A_).sum(-1, keepdims=True).sqrt())
        # The mask is a constant; it does not need to track gradients.
        no_diag = torch.einsum("ij,kj->ik", A_nrm, A_nrm) * (
            1.0 - torch.eye(A_.shape[0], device=A_.device)
        )

        return torch.autograd.grad(no_diag.clamp(-0.9999, 0.9999).mean(), [A_])[0]


# ---- /coatiLDM/models/coati/io.py ----
# model loading function for a molclip model.
from io import BytesIO
import pickle

import torch

from coatiLDM.common.s3 import cache_read
from coatiLDM.models.coati.transformer_only import COATI_Smiles_Inference
from coatiLDM.models.coati.trie_tokenizer import TrieTokenizer
from coatiLDM.models.coati.tokenizers import get_vocab


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Calls to ``torch.storage._load_from_bytes`` are redirected to ``torch.load``
    with ``map_location="cpu"``.
    """

    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(BytesIO(b), map_location="cpu")
        return super().find_class(module, name)


def load_coati2(
    doc_url: str,
    device: str = "cpu",
    freeze: bool = True,
    old_architecture=False,
    force_cpu=False,  # needed to deserialize on some cpu-only machines
):
    """
    Load a COATI SMILES inference model and its tokenizer from a pickled artifact.

    Args:
        doc_url: Location of the artifact, opened through ``cache_read``.
        device: Device the model is built for and moved to.
        freeze: If True, set ``requires_grad = False`` on every parameter.
        old_architecture: Sets ``model_kwargs["old_architecture"] = True``.
            NOTE(review): that key is not forwarded to ``COATI_Smiles_Inference``
            below, so this flag currently has no effect on the built model.
        force_cpu: Unpickle with ``CPU_Unpickler`` so tensor storages land on CPU.

    Returns:
        Tuple ``(model, tokenizer)``.

    Security: the artifact is deserialized with ``pickle``, which can execute
    arbitrary code. Only load artifacts from trusted sources.
    """
    print(f"Loading model from {doc_url}")

    with cache_read(doc_url, "rb") as f_in:
        if force_cpu:
            model_doc = CPU_Unpickler(f_in, encoding="UTF-8").load()
        else:
            model_doc = pickle.loads(f_in.read(), encoding="UTF-8")
    model_kwargs = model_doc["model_kwargs"]

    # Strip only a *leading* "module." (presumably added by a DataParallel/DDP
    # wrapper at training time). str.replace would also remove any "module."
    # occurring later inside a parameter name.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in model_doc["model"].items()
    }

    tokenizer_vocab = model_doc["train_args"]["tokenizer_vocab"]
    print(f"Loading tokenizer {tokenizer_vocab} from {doc_url}")

    if old_architecture:
        model_kwargs["old_architecture"] = True

    # Let's just be explicit: these are the constructor values for the reduced model.
    updated_kwargs = {
        "n_layer_xformer": model_kwargs["n_layer_xformer"],
        "n_hidden_xformer": model_kwargs["n_hidden_xformer"],
        "embed_dim": model_kwargs["embed_dim"],
        "n_head": model_kwargs["n_head"],
        "n_seq": model_kwargs["n_seq"],
        "mlp_dropout": model_kwargs["mlp_dropout"],
        "enc_to_coati": model_kwargs["enc_to_coati"],
        "n_direct_clr": model_kwargs["n_direct_clr"],
        "n_tok": model_kwargs["n_tok"],
        "biases": model_kwargs["biases"],
        # Always build on the requested device. Reading model_kwargs["device"]
        # raised KeyError for artifacts that did not store one.
        "device": device,
        "dtype": model_kwargs["dtype"],
    }

    model = COATI_Smiles_Inference(**updated_kwargs)

    # strict=False tolerates checkpoint keys the inference model does not define
    # (and vice versa).
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.device = device
    tokenizer = TrieTokenizer(n_seq=model_kwargs["n_seq"], **get_vocab(tokenizer_vocab))

    if freeze:
        print("Freezing encoder")
        n_params = 0
        for param in model.parameters():
            param.requires_grad = False
            n_params += param.numel()
        print(f"{n_params} params frozen!")
    return model, tokenizer


# ---- /coatiLDM/common/utils.py ----
import itertools
import os
import shutil
import datetime
from datetime import timezone
from rdkit import Chem
from rdkit.Chem.AllChem import (
    GetMorganFingerprintAsBitVect,
)
import numpy as np

import torch


def makedir(path: str, isfile: bool = False):
    """
    Creates a directory given a path to either a directory or file.
    If a directory is provided, creates that directory. If a file is provided (i.e. isfile == True),
    creates the parent directory for that file.
    :param path: Path to a directory or file.
    :param isfile: Whether the provided path is a directory or file.
    """
    if isfile:
        path = os.path.dirname(path)
    if path != "":
        os.makedirs(path, exist_ok=True)


def rmdir(path: str):
    """
    Recursively deletes a directory tree, best-effort.
    Any failure (e.g. the path does not exist) is printed and swallowed, not raised.
    :param path: Path to the directory to remove.
    """
    try:
        shutil.rmtree(path)
    except Exception as Ex:
        print("rmdir failure", Ex)


def utc_epoch_now():
    """Return the current time as a POSIX timestamp (seconds since the Unix epoch)."""
    # Request an aware "now" in UTC. The previous
    # `datetime.now().replace(tzinfo=utc)` relabelled *local* wall-clock time as
    # UTC, skewing the result by the machine's UTC offset.
    return datetime.datetime.now(timezone.utc).timestamp()


def uniform_sample_in_range(sample_shape, a, b):
    """Draw a NumPy array of shape ``sample_shape`` uniformly from [a, b)."""
    return ((b - a) * torch.rand(sample_shape) + a).numpy()


def batch_iterable(iterable, n=128):
    """Yield successive lists of at most ``n`` items from ``iterable``."""
    # Always take an iterator. Calling islice on a re-iterable such as a tuple
    # or range restarts from its beginning every time, which never terminates.
    iterator = iter(iterable)
    while True:
        batch = list(itertools.islice(iterator, n))
        if not batch:
            break
        yield batch


def mol_to_morgan(
    smiles: str,
    radius: int = 3,
    n_bits: int = 2048,
    chiral: bool = False,
    features: bool = False,
) -> np.ndarray:
    """
    Morgan fingerprint of a SMILES string as a 0/1 uint8 NumPy array of length ``n_bits``.

    NOTE(review): an unparsable SMILES makes ``Chem.MolFromSmiles`` return None,
    which RDKit then rejects with its own argument error.
    """
    mol = Chem.MolFromSmiles(smiles)
    return np.frombuffer(
        GetMorganFingerprintAsBitVect(
            mol,
            radius=radius,
            nBits=n_bits,
            useChirality=chiral,
            useFeatures=features,
        )
        .ToBitString()
        .encode(),
        "u1",
    ) - ord("0")


def tanimoto_distance_torch(A, B):
    """
    Pairwise Tanimoto distance (1 - Tanimoto similarity) between rows of A and B.

    Args:
        A: array-like of shape (n, d); converted to float32 on ``B.device``.
        B: tensor of shape (m, d). Presumably float32, since ``torch.mm`` needs
            matching dtypes.

    Returns:
        Tensor of shape (n, m).
    """
    # as_tensor avoids an extra copy (and a warning) when A is already a tensor.
    A = torch.as_tensor(A, dtype=torch.float32, device=B.device)
    dot_products = torch.mm(A, B.T)
    norm_A = torch.sum(A**2, dim=1)
    norm_B = torch.sum(B**2, dim=1)
    distances = 1 - dot_products / (norm_A[:, None] + norm_B[None, :] - dot_products)
    return distances


def colored_background(r: int, g: int, b: int, text):
    """
    Wrap ``text`` in a 24-bit ANSI background-colour escape sequence.

    r,g,b integers between 0,255
    """
    return f"\033[48;2;{r};{g};{b}m{text}\033[0m"
# ---- /coatiLDM/models/diffusion_models/dflow.py ----
# load data
import torch
from torchdiffeq import odeint


def dflow(
    x_start,
    target_set,
    vec_field_net,
    target_net,
    learning_rate=1.0,
    decode_steps=100,
    opt_steps=5,
    device="cuda:0",
):
    """
    Performs the DFlow optimization: it optimizes the ODE's initial point so the
    decoded output matches a target.

    Args:
        x_start (torch.Tensor): Starting point; cloned and used as the optimized variable.
        target_set (torch.Tensor): Target values compared with ``target_net(x).mean``.
        vec_field_net (nn.Module): Vector field, called as ``vec_field_net(t, x)`` by odeint.
        target_net (nn.Module): Target network; its output must expose ``.mean``.
        learning_rate (float): Learning rate for the LBFGS optimizer.
        decode_steps (int): Number of time points in [0, 1] for the ODE solve.
        opt_steps (int): Number of LBFGS ``step`` calls (each with up to 5 iterations).
        device (str): Device to run the computation on.

    Returns:
        torch.Tensor: Detached odeint solution for the optimized initial point.
        NOTE(review): odeint returns the solution at *every* time in ``t``, so
        this (and the loss) covers the whole trajectory, not just t = 1. Confirm
        this is intended.
    """
    t = torch.linspace(0, 1, steps=decode_steps, device=device)
    target_tensor = target_set.clone().to(device)
    x_0 = x_start.clone().detach().requires_grad_(True)  # the optimized variable

    optimizer = torch.optim.LBFGS([x_0], lr=learning_rate, max_iter=5)

    def closure():
        optimizer.zero_grad()
        x_1 = odeint(vec_field_net, x_0, t, method="midpoint")
        mse_loss = torch.pow(target_net(x_1).mean - target_tensor, 2).mean()
        mse_loss.backward(retain_graph=True)
        return mse_loss

    for _ in range(opt_steps):
        optimizer.step(closure)

    # The result is detached anyway, so skip recording an autograd graph.
    with torch.no_grad():
        x_1 = odeint(vec_field_net, x_0, t, method="midpoint")
    return x_1.detach()


def dflow_multi(
    x_start,
    target_sets,
    vec_field_net,
    target_nets,
    learning_rate=1.0,
    decode_steps=100,
    opt_steps=5,
    device="cuda:0",
):
    """
    Performs the DFlow optimization for multiple targets (sum of per-target MSE losses).

    Args:
        x_start (torch.Tensor): Starting point; cloned and used as the optimized variable.
        target_sets (list of torch.Tensor): Targets, one per entry of ``target_nets``.
        vec_field_net (nn.Module): Vector field, called as ``vec_field_net(t, x)`` by odeint.
        target_nets (list of nn.Module): Target networks; outputs must expose ``.mean``.
        learning_rate (float): Learning rate for the LBFGS optimizer.
        decode_steps (int): Number of time points in [0, 1] for the ODE solve.
        opt_steps (int): Number of LBFGS ``step`` calls.
        device (str): Device to run the computation on.

    Returns:
        torch.Tensor: Detached odeint solution for the optimized initial point
        (see the trajectory note on ``dflow``).
    """
    t = torch.linspace(0, 1, steps=decode_steps, device=device)
    target_tensors = [val.clone().to(device) for val in target_sets]
    x_0 = x_start.clone().detach().requires_grad_(True)

    optimizer = torch.optim.LBFGS([x_0], lr=learning_rate, max_iter=5)

    def closure():
        optimizer.zero_grad()
        x_1 = odeint(vec_field_net, x_0, t, method="midpoint")
        mse_losses = [
            torch.pow(net(x_1).mean - target_tensor, 2).mean()
            for net, target_tensor in zip(target_nets, target_tensors)
        ]
        total_loss = sum(mse_losses)
        total_loss.backward(retain_graph=True)
        return total_loss

    for _ in range(opt_steps):
        optimizer.step(closure)

    with torch.no_grad():
        x_1 = odeint(vec_field_net, x_0, t, method="midpoint")
    result = x_1.detach()
    del x_1, x_0, target_tensors, optimizer
    torch.cuda.empty_cache()  # Clear CUDA cache

    return result


# ---- /coatiLDM/trainers/train_cg_due.py ----
import torch
from coatiLDM.models.score_models.due_cg_model import DueCG, save_due
from gpytorch.mlls import VariationalELBO
from tqdm.auto import tqdm

from coatiLDM.models.diffusion_models.schedulers import DDPMScheduler
from coatiLDM.data.datapipe import get_base_pipe
from coatiLDM.data.transforms import cg_xform_routine
from coatiLDM.common.utils import makedir, utc_epoch_now
import argparse
import os


def _str_to_bool(value):
    """argparse ``type`` for flags: ``bool("False")`` is True, so parse the text explicitly."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def train_cg_due(datapipe, due_params, n_samples, lr=1e-3, epochs=100, device="cuda:0"):
    """
    Fit a DueCG (deep-kernel GP) guidance model on noised samples from ``datapipe``.

    The first batches, up to at least ``n_samples`` rows, initialise the GP's
    inducing points and lengthscale. Training then maximises the variational
    ELBO for ``epochs`` passes over ``datapipe``.

    Returns:
        The trained DueCG model (on ``device``).
    """
    train_samples = []
    Ts = []
    total = 0
    for batch in datapipe:
        train_samples.append(batch["noised_samples"])
        Ts.append(batch["T"])
        total += len(batch["noised_samples"])
        if total >= n_samples:
            # Enough rows collected; don't drain the rest of the pipe.
            break
    if not train_samples:
        raise ValueError("datapipe yielded no batches to initialise the model")
    train_samples = torch.cat(train_samples, dim=0)[:n_samples]
    train_Ts = torch.cat(Ts, dim=0)[:n_samples]
    model = DueCG(**due_params).to("cpu")
    model.initalize_model(X=train_samples, T=train_Ts)
    model = model.to(device)
    elbo = VariationalELBO(model.likelihood, model.gp, num_data=total)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        progress = tqdm(datapipe, desc=f"Epoch {epoch}, Loss: ")
        model.train()
        for batch in progress:
            feats = batch["noised_samples"].to(device)
            targets = batch["target"].to(device)
            timesteps = batch["T"].to(device)
            optimizer.zero_grad()
            preds = model(feats, timesteps)
            loss = -elbo(preds, targets)
            loss.backward()
            optimizer.step()
            progress.set_description(f"Epoch {epoch}, Loss: {loss.item():.2f}")
    return model


def train_cg(args):
    """Build the noised datapipe from ``args``, train a DueCG model and pickle it to ``args.model_dir``."""
    base_pipe = get_base_pipe(args.data_path, args.load_type)

    scheduler = DDPMScheduler(
        schedule=args.schedule, timesteps=args.timesteps, beta_start=1e-4, beta_end=0.02
    )
    sched_bar_alphas = scheduler.all_bar_alphas.clone().detach().cpu()
    datapipe = (
        base_pipe.shuffle()
        .batch(args.batch_size)
        .collate(
            lambda batch: cg_xform_routine(
                batch,
                x_field=args.x_field,
                scalar_field=args.scalar_field,
                timesteps=args.timesteps,
                bar_alphas=sched_bar_alphas,
            )
        )
    )

    print("obtaining test batch... ")
    test_batch = next(iter(datapipe))
    x_dim = test_batch["unnoised_samples"].shape[-1]

    due_params = {
        "scheduler": scheduler,
        "scalar_name": args.scalar_field,
        "time_embed_dim": args.time_dim,
        "train_data_sample": None,  # their implementation only uses 1k samples
        "x_dim": x_dim,
        "depth": 4,
        "num_outputs": 1,
        "spectral_normalization": True,
        "n_inducing_points": args.n_inducing_points,
        "soft_norm_coeff": args.soft_norm_coeff,
        "n_power_iterations": args.n_power_iterations,
        "dropout_rate": args.dropout_rate,
        "kernel": "RBF",
    }
    model = train_cg_due(
        datapipe,
        due_params,
        args.n_samples,
        lr=args.lr,
        epochs=args.num_epochs,
        device=args.device,
    )

    model.eval()
    model = model.to("cpu")
    output_path = os.path.join(
        args.model_dir, f"{args.exp_name}_{args.run_name}_final.pkl"
    )
    serialized_artifact = save_due(due_params, model)

    # Make sure the target directory exists before writing.
    makedir(output_path, isfile=True)
    with open(output_path, "wb") as f_out:
        f_out.write(serialized_artifact)

    print("saved model to: ", output_path)
    return model


def do_args(mock_args=False):
    """Parse CLI arguments for DueCG training (``mock_args=True`` returns the defaults)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default=None)
    parser.add_argument("--exp_name", type=str, default="cg_model")
    parser.add_argument("--run_name", type=str, default=str(int(utc_epoch_now())))
    parser.add_argument("--model_dir", type=str, default="./")
    parser.add_argument("--load_type", type=str, default="pickle")
    parser.add_argument("--n_samples", type=int, default=10000)
    parser.add_argument("--x_field", type=str, default="normd_vector")
    parser.add_argument("--no_noise", type=_str_to_bool, default=False)

    # type=str: argparse applies `type` to string defaults too, so type=list
    # turned "normd_logp" into a list of characters.
    parser.add_argument("--scalar_field", type=str, default="normd_logp")
    parser.add_argument("--dropout_rate", type=float, default=0.03)
    parser.add_argument("--n_inducing_points", type=int, default=60)
    parser.add_argument("--soft_norm_coeff", type=float, default=0.95)
    parser.add_argument("--n_power_iterations", type=int, default=2)

    parser.add_argument("--time_dim", type=int, default=16)
    parser.add_argument("--timesteps", type=int, default=1000)

    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=896)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--schedule", type=str, default="linear")
    if mock_args:
        args = parser.parse_args(args=[])
    else:
        args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = do_args()
    train_cg(args)


# ---- /coatiLDM/trainers/train_cg_resnet.py ----
import torch
# NOTE(review): no score_models/resnet.py appears in the repository tree; confirm this import path.
from coatiLDM.models.score_models.resnet import ResAttnNetWithTime
from gpytorch.mlls import VariationalELBO
from tqdm.auto import tqdm
from coatiLDM.models.diffusion_models.schedulers import DDPMScheduler
from coatiLDM.data.datapipe import get_base_pipe
from coatiLDM.data.transforms import cg_xform_routine
from coatiLDM.common.utils import makedir, utc_epoch_now
import argparse
import os


def _str_to_bool(value):
    """argparse ``type`` for flags: ``bool("False")`` is True, so parse the text explicitly."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def train_resnet(
    datapipe,
    resnet_params,
    n_samples,
    lr=1e-3,
    epochs=100,
    mode="regression",
    device="cuda:0",
):
    """
    Train a time-conditioned ResAttnNet guidance model.

    Args:
        datapipe: Yields dicts with "noised_samples", "target" and "T".
        resnet_params: Constructor params; "t_emb_dim" is split off and passed separately.
        n_samples: Unused; kept for interface parity with ``train_cg_due``.
        mode: "regression" (MSE); any other value uses BCE-with-logits.
    """
    t_embed_dim = resnet_params["t_emb_dim"]
    resnet_params = {k: v for k, v in resnet_params.items() if k != "t_emb_dim"}
    model = ResAttnNetWithTime(resnet_params, t_emb_dim=t_embed_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if mode == "regression":
        loss_fn = torch.nn.functional.mse_loss
    else:
        # loss for binary classification
        loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
    for epoch in range(epochs):
        progress = tqdm(datapipe, desc=f"Epoch {epoch}, Loss: ")
        model.train()
        for batch in progress:
            feats = batch["noised_samples"]
            targets = batch["target"]
            timesteps = batch["T"]
            optimizer.zero_grad()
            preds = model(feats.to(device), timesteps.to(device))
            loss = loss_fn(preds, targets.view(preds.shape).to(device))
            loss.backward()
            optimizer.step()
            progress.set_description(f"Epoch {epoch}, Loss: {loss.item():.2f}")
    return model


def train_resnet_with_time(args):
    """Build the noised datapipe from ``args``, train the ResNet guidance model and ``torch.save`` it."""
    base_pipe = get_base_pipe(args.data_path, args.load_type)

    scheduler = DDPMScheduler(
        schedule=args.schedule, timesteps=args.timesteps, beta_start=1e-4, beta_end=0.02
    )
    sched_bar_alphas = scheduler.all_bar_alphas.clone().detach().cpu()
    datapipe = (
        base_pipe.shuffle()
        .batch(args.batch_size)
        .collate(
            lambda batch: cg_xform_routine(
                batch,
                x_field=args.x_field,
                scalar_field=args.scalar_field,
                # NOTE(review): hard-coded, whereas train_cg_due passes
                # args.timesteps; confirm which convention cg_xform_routine expects.
                timesteps=999,
                bar_alphas=sched_bar_alphas,
            )
        )
    )

    print("obtaining test batch... ")
    test_batch = next(iter(datapipe))
    x_dim = test_batch["unnoised_samples"].shape[-1]

    resnet_params = {
        "input_dim": x_dim,
        "hidden_dim": args.hidden_dim,
        "activation": args.activation,
        "n_heads": args.n_heads,
        "specnorm": args.specnorm,
        "depth": args.depth,
        "mup": False,
        "t_emb_dim": args.time_dim,
        "output_dim": 1,
    }
    model = train_resnet(
        datapipe,
        resnet_params,
        args.n_samples,
        lr=args.lr,
        epochs=args.num_epochs,
        mode=args.mode,
        device=args.device,
    )
    model.eval()
    model = model.to("cpu")
    output_path = os.path.join(args.model_dir, f"{args.exp_name}_{args.run_name}_final")
    # Make sure the target directory exists before writing.
    makedir(output_path + ".pt", isfile=True)
    torch.save(model, output_path + ".pt")
    print("saved model to: ", output_path + ".pt")
    return model


def do_args(mock_args=False):
    """Parse CLI arguments for ResNet guidance training (``mock_args=True`` returns the defaults)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default=None)
    parser.add_argument("--exp_name", type=str, default="cg_model")
    parser.add_argument("--run_name", type=str, default=str(int(utc_epoch_now())))
    parser.add_argument("--model_dir", type=str, default="./")
    parser.add_argument("--load_type", type=str, default="pickle")
    parser.add_argument("--n_samples", type=int, default=10000)
    parser.add_argument("--x_field", type=str, default="normd_vector")
    parser.add_argument("--no_noise", type=_str_to_bool, default=False)

    # type=str: type=list turned the default into a list of characters.
    parser.add_argument("--scalar_field", type=str, default="normd_logp")
    parser.add_argument("--dropout_rate", type=float, default=0.03)
    parser.add_argument("--n_inducing_points", type=int, default=60)
    parser.add_argument("--soft_norm_coeff", type=float, default=0.95)
    parser.add_argument("--n_power_iterations", type=int, default=2)

    parser.add_argument("--time_dim", type=int, default=16)
    parser.add_argument("--timesteps", type=int, default=1000)

    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=896)
    # 0.9 to match train_cg_due (the previous 9 looked like a typo; currently unused).
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--schedule", type=str, default="linear")
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--activation", type=str, default="silu")
    parser.add_argument("--n_heads", type=int, default=4)
    parser.add_argument("--specnorm", type=_str_to_bool, default=False)
    parser.add_argument("--mode", type=str, default="regression")
    parser.add_argument("--depth", type=int, default=12)
    if mock_args:
        args = parser.parse_args(args=[])
    else:
        args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = do_args()
    train_resnet_with_time(args)


# ---- /coatiLDM/models/score_models/due_cg_model.py ----
from typing import Dict, Any

import numpy as np
import torch
from due.dkl import DKL, GP, _get_initial_inducing_points, _get_initial_lengthscale
from due.fc_resnet import FCResNet
from gpytorch.likelihoods import GaussianLikelihood
from torch import nn
from coatiLDM.data.transforms import cg_xform_routine
from coatiLDM.models.score_models.score_model import get_time_embedding
import pickle
from smart_open import open
from coatiLDM.data.datapipe import get_base_pipe
import matplotlib.pyplot as plt
def initial_values( 18 | train_sample: torch.Tensor, feature_extractor: nn.Module, n_inducing_points: int 19 | ): 20 | with torch.no_grad(): 21 | f_X_samples = feature_extractor(train_sample) 22 | 23 | initial_inducing_points = _get_initial_inducing_points( 24 | f_X_samples.numpy(), n_inducing_points 25 | ) 26 | initial_lengthscale = _get_initial_lengthscale(f_X_samples) 27 | 28 | return initial_inducing_points, initial_lengthscale 29 | 30 | 31 | class DueCG(nn.Module): 32 | def __init__( 33 | self, 34 | scheduler, 35 | time_embed_dim, 36 | train_data_sample: torch.Tensor = None, # their implementation only uses 1k samples 37 | x_dim: int = 256, 38 | scalar_name: str = "logp", 39 | # features: int = 256, 40 | depth: int = 4, 41 | num_outputs: int = 1, 42 | spectral_normalization: bool = True, 43 | n_inducing_points=60, 44 | soft_norm_coeff: float = 0.95, 45 | n_power_iterations=2, 46 | dropout_rate=0.03, 47 | kernel="RBF", 48 | ) -> None: 49 | super().__init__() 50 | 51 | self.x_dim = x_dim 52 | self.scalar_name = scalar_name 53 | self.time_embed_dim = time_embed_dim 54 | self.input_dim = x_dim + time_embed_dim 55 | self.features = x_dim + time_embed_dim 56 | self.depth = depth 57 | self.num_outputs = num_outputs 58 | self.spectral_normalization = spectral_normalization 59 | self.n_inducing_points = n_inducing_points 60 | self.soft_norm_coeff = soft_norm_coeff 61 | self.n_power_iterations = n_power_iterations 62 | self.dropout_rate = dropout_rate 63 | self.kernel = kernel 64 | self.scheduler = scheduler 65 | self.initalized = False 66 | 67 | self.feature_extractor = FCResNet( 68 | input_dim=self.input_dim, 69 | features=self.features, 70 | depth=depth, 71 | spectral_normalization=spectral_normalization, 72 | coeff=soft_norm_coeff, 73 | n_power_iterations=n_power_iterations, 74 | dropout_rate=dropout_rate, 75 | ) 76 | self.gp = None 77 | self.dkl = None 78 | self.n_inducing_points = n_inducing_points 79 | self.num_outputs = num_outputs 80 | self.kernel = kernel 81 | 
self.initalized = False 82 | self.likelihood = GaussianLikelihood() 83 | 84 | if train_data_sample is not None: 85 | self.initalize_model(train_data_sample["X"], train_data_sample=["T"]) 86 | 87 | def noise_train_data_sample(self, train_data_sample): 88 | batch_size = train_data_sample.shape[0] 89 | T = torch.randint( 90 | low=0, high=self.scheduler.timesteps, size=(batch_size,), device=self.device 91 | ) 92 | t_embed = get_time_embedding(T.float(), self.time_embed_dim) 93 | noise = torch.randn((batch_size, self.x_dim), device=self.device) 94 | noisy_samples = ( 95 | self.scheduler.bar_alpha(T).sqrt() * train_data_sample 96 | + (1.0 - self.scheduler.bar_alpha(T)).sqrt() * noise 97 | ) 98 | samp_with_noise = torch.cat([noisy_samples, t_embed], dim=1) 99 | return samp_with_noise 100 | 101 | def initalize_model(self, X: torch.Tensor, T: torch.Tensor): 102 | with torch.no_grad(): 103 | t_embed = get_time_embedding(T.float(), self.time_embed_dim) 104 | train_data_sample = torch.cat([X, t_embed], dim=1) 105 | initial_inducing_points, initial_lengthscale = initial_values( 106 | train_data_sample, self.feature_extractor, self.n_inducing_points 107 | ) 108 | self.gp = GP( 109 | num_outputs=self.num_outputs, 110 | initial_lengthscale=initial_lengthscale, 111 | initial_inducing_points=initial_inducing_points, 112 | kernel=self.kernel, 113 | ) 114 | self.dkl = DKL(self.feature_extractor, self.gp) 115 | self.initalized = True 116 | 117 | @property 118 | def device(self): 119 | return next(self.parameters()).device 120 | 121 | def forward(self, x, T): 122 | t_embed = get_time_embedding(T.float(), self.time_embed_dim) 123 | full_rep = torch.cat([x, t_embed], dim=1) 124 | return self.dkl(full_rep) 125 | 126 | 127 | def save_due(params, model) -> bytes: 128 | scheduler_args = { 129 | "schedule": model.scheduler.schedule, 130 | "timesteps": model.scheduler.timesteps, 131 | "beta_start": model.scheduler.beta_start, 132 | "beta_end": model.scheduler.beta_end, 133 | } 134 | 
due_serialized = { 135 | "model_kwargs": {key: val for key, val in params.items() if key != "n_samples"}, 136 | # "scheduler_kwargs": scheduler_args, 137 | "model": model.to("cpu").state_dict(), 138 | } 139 | return pickle.dumps(due_serialized) 140 | 141 | 142 | def get_due_batch_pipe( 143 | pickle_path, due, x_field="emb_smiles", batch_size=2048, load_type="pickle" 144 | ): 145 | 146 | base_pipe = get_base_pipe(pickle_path, load_type) 147 | sched_bar_alphas = due.scheduler.all_bar_alphas.clone().detach().cpu() 148 | datapipe = ( 149 | base_pipe.shuffle() 150 | .batch(batch_size) 151 | .collate( 152 | lambda batch: cg_xform_routine( 153 | batch, 154 | x_field=x_field, 155 | scalar_field=due.scalar_name, 156 | timesteps=due.scheduler.timesteps, 157 | bar_alphas=sched_bar_alphas, 158 | ) 159 | ) 160 | ) 161 | 162 | return datapipe 163 | -------------------------------------------------------------------------------- /coatiLDM/models/coati/tokenizers/trie_tokenizer.py: -------------------------------------------------------------------------------- 1 | from coatiLDM.common.utils import colored_background 2 | from coatiLDM.models.coati.tokenizers.trie import Trie 3 | import torch 4 | from typing import Tuple, List 5 | 6 | 7 | class TrieTokenizer: 8 | """ 9 | Converts smiles+sentinel tokens into a list of integers. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | n_seq=256, # The dimension of the token embedding. 15 | smiles_tokens=[], 16 | special_tokens=[], 17 | side_tasks=True, 18 | ): 19 | self.n_seq = n_seq 20 | self.special_tokens = special_tokens 21 | self.smiles_tokens = smiles_tokens 22 | self.keys = self.special_tokens + self.smiles_tokens 23 | self.n_token = len(self.keys) # base number of tokens. 24 | self.vocab = {T.strip(): I for I, T in enumerate(self.keys)} 25 | 26 | # I am human, after all. 
27 | # These are tokens wrt, the model should be uniform (loss masked) 28 | self.stop_token = self.vocab["[STOP]"] 29 | self.pad_token = self.vocab["[PAD]"] 30 | 31 | self.clip_token = self.vocab["[CLIP]"] 32 | self.unk_token = self.vocab["[UNK]"] 33 | self.smiles_token = self.vocab["[SMILES]"] 34 | self.suffix_token = self.vocab["[SUFFIX]"] 35 | self.middle_token = self.vocab["[MIDDLE]"] 36 | if side_tasks: 37 | self.graph_token = self.vocab["[GRAPH]"] 38 | self.formula_token = self.vocab["[FORMULA]"] 39 | self.set_token = self.vocab["[SET]"] 40 | 41 | self.smiles_trie = Trie() 42 | self.special_trie = Trie() 43 | for k in self.special_tokens: 44 | self.special_trie.add(k) 45 | for k in self.smiles_tokens: 46 | self.smiles_trie.add(k) 47 | 48 | def pre_tokenize(self, text): 49 | """ 50 | Splits the special tokens first. 51 | """ 52 | split0 = self.special_trie.split(text) 53 | tokens = [] 54 | for T in split0: 55 | if T in self.special_tokens: 56 | tokens.append(T) 57 | else: 58 | tokens.extend(self.smiles_trie.split(T)) 59 | return tokens 60 | 61 | def tokenize_text( 62 | self, text: str, pad: bool = True, range_check: bool = True 63 | ) -> List[int]: 64 | """ 65 | Tokenizes a single row. 66 | """ 67 | try: 68 | tore = [self.vocab[T] for T in self.pre_tokenize(text)] 69 | if len(tore) > self.n_seq and range_check: 70 | raise Exception("Oversized String", len(tore)) 71 | if pad: 72 | tore = tore + [ 73 | self.vocab["[PAD]"] for k in range(self.n_seq - len(tore)) 74 | ] 75 | except Exception as Ex: 76 | print("tokenize text exception... 
", text, Ex, self.pre_tokenize(text)) 77 | raise Ex 78 | return tore 79 | 80 | def batch_smiles( 81 | self, smiles_batch: List[str], device: str = "cpu", skip_failed: bool = False 82 | ) -> Tuple[torch.Tensor, List[int]]: 83 | token_stack = [] 84 | bad_idxs = [] 85 | for idx, smi in enumerate(smiles_batch): 86 | try: 87 | ttext = self.tokenize_text( 88 | "[SMILES]" + smi + "[STOP]", pad=False, range_check=False 89 | ) 90 | except KeyError as e: 91 | if skip_failed: # filling with a dummy string, and adding to bad_idxs 92 | ttext = self.tokenize_text( 93 | "[SMILES]" + "C" + "[STOP]", pad=False, range_check=False 94 | ) 95 | bad_idxs.append(idx) 96 | else: 97 | raise e 98 | 99 | if len(ttext) <= self.n_seq: 100 | t = torch.zeros(self.n_seq, dtype=torch.long, device=device) 101 | t[: len(ttext)] = torch.tensor(ttext) 102 | token_stack.append(t) 103 | else: 104 | bad_idxs.append(idx) 105 | 106 | new_smi_batch = torch.stack(token_stack, 0) 107 | new_smi_batch = new_smi_batch[:, : (new_smi_batch.sum(0) > 0).sum()] 108 | return new_smi_batch, bad_idxs 109 | 110 | def decode( 111 | self, 112 | ints, 113 | special=True, 114 | end_at_stop=True, 115 | de_fim=True, 116 | color_loss=None, # Provides colored likelihoods in blue 117 | ): 118 | """ 119 | Detokenizes a single row. 120 | 121 | Args: 122 | ints: a list of token integers 123 | special: decode special tokens? (if False they are mapped to '') 124 | de_fim: undo fill-in-middle 125 | Returns: 126 | a string of decoded tokens. 
127 | """ 128 | if not len(ints): 129 | return "" 130 | assert type(ints[0]) == int 131 | if end_at_stop and self.stop_token in ints: 132 | ints = ints[: ints.index(self.stop_token) + 1] 133 | 134 | if not color_loss is None: 135 | assert len(color_loss) >= len(ints) 136 | max_loss = max(color_loss) 137 | min_loss = min(color_loss) 138 | strings = [ 139 | colored_background( 140 | int((color_loss[i] - min_loss) / (max_loss - min_loss) * 255), 141 | 128, 142 | 128, 143 | self.keys[I], 144 | ) 145 | for i, I in enumerate(ints) 146 | if I > 0 147 | ] 148 | else: 149 | strings = [self.keys[I] for I in ints if I > 0] 150 | 151 | if special: 152 | if de_fim and "[MIDDLE]" in strings and "[SUFFIX]" in strings: 153 | si = strings.index("[SUFFIX]") 154 | mi = strings.index("[MIDDLE]") 155 | return "".join( 156 | strings[:si] + strings[mi:-1] + strings[si:mi] + strings[-1:] 157 | ) 158 | else: 159 | return "".join(strings) 160 | else: 161 | if de_fim and "[MIDDLE]" in strings and "[SUFFIX]" in strings: 162 | si = strings.index("[SUFFIX]") 163 | mi = strings.index("[MIDDLE]") 164 | ordd = strings[:si] + strings[mi:-1] + strings[si:mi] + strings[-1:] 165 | return "".join([S for S in ordd if not S in self.special_tokens]) 166 | else: 167 | return "".join([S for S in strings if not S in self.special_tokens]) 168 | -------------------------------------------------------------------------------- /coatiLDM/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.interpolate import PchipInterpolator 4 | 5 | 6 | class cdf: 7 | # broadcasted/vectorized version of the CDF computation. 8 | # scales much better than using np.grid. 9 | def __init__(self, x, npts=800, max_bin=None, min_bin=None, verbose=False): 10 | 11 | if verbose: 12 | print( 13 | f"Computing empirical CDF from {x.shape[0]} points on {npts} grid points..." 
14 | ) 15 | 16 | X = np.nan_to_num(x) 17 | if min_bin is None: 18 | min_bin = X.min() 19 | if max_bin is None: 20 | max_bin = X.max() 21 | left_barrier = min_bin - (max_bin - min_bin) / 10.0 22 | right_barrier = max_bin + (max_bin - min_bin) / 10.0 23 | self.grid = np.linspace(left_barrier, right_barrier, npts) 24 | self.count = (X[:, None] < self.grid).sum(0) 25 | self.scdf = self.count / self.count[-1] + np.linspace( 26 | 0, 1e-9, self.count.shape[0] 27 | ) 28 | if verbose: 29 | print("done! Computing function approximations....") 30 | # compute spline approximations to cdfs 31 | self.invcdf = PchipInterpolator(self.scdf, self.grid) 32 | self.cdf = PchipInterpolator(self.grid, self.scdf) 33 | # compute smoothed cdf, useful for computing pdf 34 | w = self.grid.shape[0] // 200 35 | self.smoothed_scdf = np.convolve( 36 | self.scdf, np.hamming(w) / np.sum(np.hamming(w)), mode="same" 37 | ) 38 | self.smoothed_scdf[-w:] = 1.0 39 | self.smoothed_cdf = PchipInterpolator(self.grid, self.smoothed_scdf) 40 | 41 | if verbose: 42 | print("done!") 43 | 44 | def pdf(self, pts, smooth=True): 45 | """ 46 | The pdf of the cdf 47 | """ 48 | cdf_approx = self.smoothed_cdf if smooth else self.cdf 49 | return cdf_approx.derivative()(pts) 50 | 51 | def to_unit_interval(self, pts): 52 | return self.cdf(pts) 53 | 54 | def quantile_class_boundaries( 55 | self, bounds=np.array([0.1, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999, 0.99999]) 56 | ): 57 | return self.invcdf(bounds) 58 | 59 | def sample(self, n_sample=4000): 60 | ent = np.random.random(n_sample) 61 | return self.invcdf(ent) 62 | 63 | 64 | def embed_scalar( 65 | timesteps, 66 | embedding_dim: int = 16, 67 | dtype=torch.float32, 68 | max_timescale=10_000, 69 | min_timescale=1, 70 | max_time=1.0, 71 | ): 72 | # Adapted from tensor2tensor and VDM codebase. 
73 | assert timesteps.ndim == 1 74 | assert embedding_dim % 2 == 0 75 | timesteps *= ( 76 | 1000.0 / max_time 77 | ) # In DDPM the time step is in [0, 1000], in BFN [0, 1] 78 | num_timescales = embedding_dim // 2 79 | inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n)) 80 | -np.log10(min_timescale), 81 | -np.log10(max_timescale), 82 | num_timescales, 83 | device=timesteps.device, 84 | ) 85 | emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2) 86 | return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D) 87 | 88 | 89 | def safe_embed_scalar( 90 | timesteps, 91 | embedding_dim: int = 16, 92 | dtype=torch.float32, 93 | max_timescale=10_000, 94 | min_timescale=1, 95 | max_time=1.0, 96 | ): 97 | # Adapted from tensor2tensor and VDM codebase. 98 | assert timesteps.ndim == 1 99 | assert embedding_dim % 2 == 0 100 | transformed_timesteps = timesteps * ( 101 | 1000.0 / max_time 102 | ) # In DDPM the time step is in [0, 1000], in BFN [0, 1] 103 | num_timescales = embedding_dim // 2 104 | inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n)) 105 | -np.log10(min_timescale), 106 | -np.log10(max_timescale), 107 | num_timescales, 108 | device=timesteps.device, 109 | ) 110 | emb = transformed_timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2) 111 | return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D) 112 | 113 | 114 | def cg_xform_routine( 115 | batch, 116 | x_field="normd_vector", 117 | scalar_field="normd_logp", 118 | bar_alphas=None, 119 | timesteps=None, 120 | device=torch.device("cpu"), 121 | no_noise=False, 122 | ): 123 | batch_size = len(batch) 124 | assert batch_size > 0 125 | T = torch.randint(low=0, high=timesteps, size=(batch_size,), device=device) 126 | stacked = {} 127 | unnoised_sample = torch.tensor( 128 | np.stack([row[x_field] for row in batch], 0), device=device, dtype=torch.float 129 | ) 130 | noise = torch.randn((batch_size, unnoised_sample.shape[-1]), device=device) 131 | 
b_alph_reshape = bar_alphas[T.long()].clamp(0.0, 1.0).unsqueeze(-1) 132 | noisy_samples = ( 133 | b_alph_reshape.sqrt() * unnoised_sample + (1.0 - b_alph_reshape).sqrt() * noise 134 | ) 135 | stacked["unnoised_samples"] = unnoised_sample 136 | stacked["noised_samples"] = noisy_samples 137 | stacked["T"] = T 138 | if no_noise: 139 | stacked["noised_samples"] = unnoised_sample 140 | stacked["T"] = torch.zeros_like(T) 141 | C = torch.tensor( 142 | [row[scalar_field] for row in batch], device=device, dtype=torch.float 143 | ) 144 | # make uniform 145 | stacked["target"] = C 146 | return stacked 147 | 148 | 149 | def xform_basic( 150 | batch, 151 | x_field="emb_smiles", 152 | scalar_cond_fields=["logp"], 153 | cond_emb_dim=16, 154 | device=torch.device("cpu"), 155 | ): 156 | """ 157 | Stacks and vector embeds. assumes no normalization. 158 | """ 159 | batch_size = len(batch) 160 | assert batch_size > 0 161 | stacked = {} 162 | stacked["samples"] = torch.tensor( 163 | np.stack([row[x_field] for row in batch], 0), device=device, dtype=torch.float 164 | ) 165 | cond_vectors = [] 166 | if len(scalar_cond_fields): 167 | for c in scalar_cond_fields: 168 | if c in batch[0]: 169 | stacked[c] = torch.tensor( 170 | [row[c] for row in batch], device=device, dtype=torch.float 171 | ) 172 | C = torch.tensor( 173 | [row[c] for row in batch], device=device, dtype=torch.float 174 | ) 175 | cond_vectors.append(embed_scalar(C, embedding_dim=cond_emb_dim)) 176 | if len(cond_vectors): 177 | cond_vectors = torch.cat(cond_vectors, -1) 178 | stacked["cond_vector"] = cond_vectors 179 | else: 180 | stacked["cond_vector"] = None 181 | return stacked 182 | -------------------------------------------------------------------------------- /coatiLDM/models/coati/trie_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, List 3 | import json 4 | 5 | import torch 6 | 7 | from coatiLDM.common.utils import 
colored_background 8 | from coatiLDM.models.coati.tokenizers.trie import Trie 9 | 10 | 11 | class TrieTokenizer: 12 | """ 13 | Converts smiles+setinel tokens into a list of integers. 14 | 15 | Also applicable to a graph. 16 | Cannot accommodate > 150 atoms of a single type or 17 | 150 nodes in a graph. 18 | 19 | for specific interpretations of special tokens see fill_in_middle.py 20 | """ 21 | 22 | def __init__( 23 | self, 24 | n_seq=256, # The dimension of the token embedding. 25 | smiles_tokens=[], 26 | special_tokens=[], 27 | side_tasks=True, 28 | ): 29 | self.n_seq = n_seq 30 | self.special_tokens = special_tokens 31 | self.n_special = len(self.special_tokens) 32 | self.smiles_tokens = smiles_tokens 33 | self.keys = self.special_tokens + self.smiles_tokens 34 | self.n_token = len(self.keys) # base number of tokens. 35 | self.vocab = {T.strip(): I for I, T in enumerate(self.keys)} 36 | 37 | # I am human, after all. 38 | # These are tokens wrt, the model should be uniform (loss masked) 39 | self.stop_token = self.vocab["[STOP]"] 40 | self.pad_token = self.vocab["[PAD]"] 41 | 42 | self.clip_token = self.vocab["[CLIP]"] 43 | self.unk_token = self.vocab["[UNK]"] 44 | self.smiles_token = self.vocab["[SMILES]"] 45 | self.suffix_token = self.vocab["[SUFFIX]"] 46 | self.middle_token = self.vocab["[MIDDLE]"] 47 | self.mask_token = self.vocab["[MASK]"] 48 | if side_tasks: 49 | self.graph_token = self.vocab["[GRAPH]"] 50 | self.formula_token = self.vocab["[FORMULA]"] 51 | self.set_token = self.vocab["[SET]"] 52 | 53 | self.smiles_trie = Trie() 54 | self.special_trie = Trie() 55 | for k in self.special_tokens: 56 | self.special_trie.add(k) 57 | for k in self.smiles_tokens: 58 | self.smiles_trie.add(k) 59 | 60 | def pre_tokenize(self, text): 61 | """ 62 | Splits the special tokens first. 
63 | """ 64 | split0 = self.special_trie.split(text) 65 | tokens = [] 66 | for T in split0: 67 | if T in self.special_tokens: 68 | tokens.append(T) 69 | else: 70 | tokens.extend(self.smiles_trie.split(T)) 71 | return tokens 72 | 73 | def tokenize_text( 74 | self, text: str, pad: bool = True, range_check: bool = True 75 | ) -> List[int]: 76 | """ 77 | Tokenizes a single row. 78 | """ 79 | try: 80 | tore = [self.vocab[T] for T in self.pre_tokenize(text)] 81 | if len(tore) > self.n_seq and range_check: 82 | raise Exception("Oversized String", len(tore)) 83 | if pad: 84 | tore = tore + [ 85 | self.vocab["[PAD]"] for k in range(self.n_seq - len(tore)) 86 | ] 87 | except Exception as Ex: 88 | print("tokenize text exception... ", text, Ex, self.pre_tokenize(text)) 89 | raise Ex 90 | return tore 91 | 92 | def batch_smiles( 93 | self, smiles_batch: List[str], device: str = "cpu", skip_failed: bool = False 94 | ): 95 | token_stack = [] 96 | bad_idxs = [] 97 | for idx, smi in enumerate(smiles_batch): 98 | try: 99 | ttext = self.tokenize_text( 100 | "[SMILES]" + smi + "[STOP]", pad=False, range_check=False 101 | ) 102 | except KeyError as e: 103 | if skip_failed: # filling with a dummy string, and adding to bad_idxs 104 | ttext = self.tokenize_text( 105 | "[SMILES]" + "C" + "[STOP]", pad=False, range_check=False 106 | ) 107 | bad_idxs.append(idx) 108 | else: 109 | raise e 110 | 111 | if len(ttext) <= self.n_seq: 112 | t = torch.zeros(self.n_seq, dtype=torch.long, device=device) 113 | t[: len(ttext)] = torch.tensor(ttext) 114 | token_stack.append(t) 115 | else: 116 | bad_idxs.append(idx) 117 | 118 | new_smi_batch = torch.stack(token_stack, 0) 119 | new_smi_batch = new_smi_batch[:, : (new_smi_batch.sum(0) > 0).sum()] 120 | return new_smi_batch, bad_idxs 121 | 122 | def decode( 123 | self, 124 | ints, 125 | special=True, 126 | end_at_stop=True, 127 | de_fim=True, 128 | color_loss=None, # Provides colored likelihoods in blue 129 | ): 130 | """ 131 | Detokenizes a single row. 
132 | 133 | Args: 134 | ints: a list of token integers 135 | special: decode special tokens? (if False they are mapped to '') 136 | de_fim: undo fill-in-middle 137 | Returns: 138 | a string of decoded tokens. 139 | """ 140 | if not len(ints): 141 | return "" 142 | assert type(ints[0]) == int 143 | if end_at_stop and self.stop_token in ints: 144 | ints = ints[: ints.index(self.stop_token) + 1] 145 | 146 | if not color_loss is None: 147 | assert len(color_loss) >= len(ints) 148 | max_loss = max(color_loss) 149 | min_loss = min(color_loss) 150 | strings = [ 151 | colored_background( 152 | int((color_loss[i] - min_loss) / (max_loss - min_loss) * 255), 153 | 128, 154 | 128, 155 | self.keys[I], 156 | ) 157 | for i, I in enumerate(ints) 158 | if I > 0 159 | ] 160 | else: 161 | strings = [self.keys[I] for I in ints if I > 0] 162 | 163 | if special: 164 | if de_fim and "[MIDDLE]" in strings and "[SUFFIX]" in strings: 165 | si = strings.index("[SUFFIX]") 166 | mi = strings.index("[MIDDLE]") 167 | return "".join( 168 | strings[:si] + strings[mi:-1] + strings[si:mi] + strings[-1:] 169 | ) 170 | else: 171 | return "".join(strings) 172 | else: 173 | if de_fim and "[MIDDLE]" in strings and "[SUFFIX]" in strings: 174 | si = strings.index("[SUFFIX]") 175 | mi = strings.index("[MIDDLE]") 176 | ordd = strings[:si] + strings[mi:-1] + strings[si:mi] + strings[-1:] 177 | return "".join([S for S in ordd if not S in self.special_tokens]) 178 | else: 179 | return "".join([S for S in strings if not S in self.special_tokens]) 180 | -------------------------------------------------------------------------------- /coatiLDM/common/s3.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pickle 4 | import pandas as pd 5 | import pytz 6 | from urllib.parse import urlparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | import boto3 11 | from botocore import UNSIGNED 12 | from botocore.client import Config 13 | 14 | 
from coatiLDM.constants import FIGURE_DATA_PATH


def load_figure_file(figure_filename, local_dir, filetype="pkl", has_header=True):
    """
    Download (if needed) and load a figure-data file stored under FIGURE_DATA_PATH.

    Args:
        figure_filename: path of the file relative to ``FIGURE_DATA_PATH``.
        local_dir: local directory the file is cached in. Only the basename of
            the S3 key is kept locally (see ``sync_s3_file_to_local_dir``).
        filetype: ``"pkl"`` (pickle), ``"csv"`` (pandas) or ``"pt"`` (torch).
        has_header: only used for ``"csv"``; pass False when the first row is data.

    Returns:
        The unpickled object, a ``pandas.DataFrame``, or the ``torch.load`` result.

    Raises:
        ValueError: if ``filetype`` is not one of the supported types.
    """
    # Validate first so an unsupported filetype does not trigger a download.
    if filetype not in ("pkl", "csv", "pt"):
        raise ValueError(f"filetype {filetype} not supported")

    s3_path = os.path.join(FIGURE_DATA_PATH, figure_filename)
    bucket, prefix = split_s3_path(s3_path)
    # Use the path the sync actually wrote to. It keeps only the key's basename,
    # so re-joining ``figure_filename`` would point at the wrong file for
    # nested names like "sub/file.pkl".
    local_path = sync_s3_file_to_local_dir(bucket, prefix, local_dir=local_dir)

    if filetype == "pkl":
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(local_path, "rb") as f:
            return pickle.load(f)
    if filetype == "csv":
        # "infer" is read_csv's default header handling; None means no header row.
        return pd.read_csv(local_path, header="infer" if has_header else None)
    return torch.load(local_path)


def split_s3_path(s3_path):
    """
    Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    The key is returned without its leading "/".
    """
    components = urlparse(s3_path)
    # urlparse keeps the leading "/" on the path component; S3 keys have none.
    return components.netloc, components.path[1:]


def _make_unsigned_s3_resource():
    """Return an anonymous (unsigned) boto3 S3 resource in us-west-2."""
    session = boto3.Session()
    return session.resource(
        "s3", region_name="us-west-2", config=Config(signature_version=UNSIGNED)
    )


def _sync_s3_object(s3_obj, bucket_name, prefix, local_file_path, verbose):
    """
    Download ``s3_obj`` to ``local_file_path`` unless the local copy is current.

    The object is downloaded when the local file is missing, or re-downloaded
    when the S3 ``last_modified`` time is later than the local file's mtime.
    Both times are compared as timezone-aware UTC datetimes.

    Returns:
        ``local_file_path``.
    """
    if os.path.exists(local_file_path):
        local_file_dt = datetime.datetime.fromtimestamp(
            os.path.getmtime(local_file_path), tz=pytz.utc
        )
        s3_obj_dt = s3_obj.last_modified.astimezone(pytz.utc)

        if s3_obj_dt > local_file_dt:
            if verbose:
                print(
                    f"Re-downloading {prefix} from {bucket_name}, {s3_obj_dt} > {local_file_dt}"
                )
            s3_obj.download_file(local_file_path)
            if verbose:
                print(f"File updated successfully at {local_file_path}")
    else:
        if verbose:
            print(f"Downloading {prefix} from {bucket_name}")
        s3_obj.download_file(local_file_path)
        if verbose:
            print(f"File downloaded successfully to {local_file_path}")

    return local_file_path


def sync_s3_file_to_local_dir(bucket_name, prefix, local_dir="./", verbose=True):
    """
    Sync one S3 object into ``local_dir``, keeping only the key's basename.

    The file is saved as ``os.path.join(local_dir, prefix.split("/")[-1])``.
    It is downloaded if it does not exist locally, or if the S3 copy was
    modified after the local file. ``local_dir`` is created if necessary.

    Returns:
        The local file path.
    """
    s3 = _make_unsigned_s3_resource()

    # os.makedirs("") raises, so only create a directory when one is named.
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
    local_file_path = os.path.join(local_dir, prefix.split("/")[-1])

    s3_obj = s3.Object(bucket_name, prefix)
    return _sync_s3_object(s3_obj, bucket_name, prefix, local_file_path, verbose)


def sync_s3_to_local(bucket_name, prefix, verbose=True, home_dir=None):
    """
    Sync one S3 object into a local cache that mirrors the key's directory layout.

    The file is saved as ``os.path.join(home_dir, prefix)``. ``home_dir``
    defaults to the ``S3_CACHE_DIR`` environment variable, falling back to the
    user's home directory. The file is downloaded when it is missing locally,
    or when the S3 copy was modified after the local file.

    Returns:
        The local file path.
    """
    s3 = _make_unsigned_s3_resource()

    if home_dir is None:
        home_dir = os.getenv("S3_CACHE_DIR", os.path.expanduser("~"))

    local_file_path = os.path.join(home_dir, prefix)
    local_file_dir = os.path.dirname(local_file_path)
    # dirname is "" for a bare filename; os.makedirs("") would raise.
    if local_file_dir:
        os.makedirs(local_file_dir, exist_ok=True)

    s3_obj = s3.Object(bucket_name, prefix)
    return _sync_s3_object(s3_obj, bucket_name, prefix, local_file_path, verbose)


def copy_bucket_dir_from_s3(bucket_dir, dest_dir):
    """
    Copy every object under ``bucket_dir`` in the ``terray-public`` bucket.

    Each object is written to ``dest_dir + key``. This is plain string
    concatenation, so ``dest_dir`` should normally end with a path separator.

    Raises:
        Exception: if no objects exist under ``bucket_dir``.
    """
    # NOTE(review): unlike the sync helpers, this uses the default (signed)
    # boto3 client, so it presumably needs AWS credentials; confirm intent.
    s3_resource = boto3.resource("s3")
    bucket = s3_resource.Bucket("terray-public")

    # Materialize the listing once; each iteration of ``filter`` issues a
    # fresh (paginated) list request.
    objects = list(bucket.objects.filter(Prefix=bucket_dir))
    nfiles = len(objects)
    if nfiles < 1:
        print(objects)
        raise Exception(f"empty_s3 {bucket_dir}")

    print(f"copying {nfiles} files from {bucket_dir} to {dest_dir}")
    for obj in tqdm(objects, total=nfiles):
        local_path = dest_dir + obj.key  # save to same path
        local_dir = os.path.dirname(local_path)
        if local_dir:
            os.makedirs(local_dir, exist_ok=True)
        bucket.download_file(obj.key, local_path)


def download_from_s3(s3_path):
    """Sync ``s3://bucket/key`` into the local cache and return the local file path."""
    bucket_name, prefix = split_s3_path(s3_path)
    return sync_s3_to_local(bucket_name, prefix, verbose=True)


class cache_read:
    """
    Context manager that opens an S3 object for reading through the local cache.

    If ``s3_path`` names an existing local file, that file is opened directly.
    Otherwise ``s3_path`` must be a full ``s3://bucket/key`` URI. The object is
    synced with ``sync_s3_to_local`` and the cached copy is opened.

    Args:
        s3_path: local path or full S3 URI.
        mode: file open mode; must be one of ``VALID_MODES``.
        verbose: forwarded to ``sync_s3_to_local``.

    Raises:
        ValueError: if ``mode`` is not a read mode listed in ``VALID_MODES``.
    """

    VALID_MODES = ["rb", "r"]

    def __init__(self, s3_path, mode, verbose=True):
        if mode not in self.VALID_MODES:
            raise ValueError(f'"{mode}" not in {self.VALID_MODES}')
        self.s3_path = s3_path
        self.mode = mode
        self.local_file_path = None
        self.file = None
        self.verbose = verbose

    def __enter__(self):
        if os.path.isfile(self.s3_path):
            self.local_file_path = self.s3_path
        else:
            bucket_name, prefix = split_s3_path(self.s3_path)
            self.local_file_path = sync_s3_to_local(
                bucket_name, prefix, verbose=self.verbose
            )
        self.file = open(self.local_file_path, self.mode)
        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is not None:
            self.file.close()
            self.file = None
        # Implicitly returns None, so exceptions raised in the with-body propagate.


# --------------------------------------------------------------------------------
# /coatiLDM/models/score_models/non_conv_unet.py:
# --------------------------------------------------------------------------------
#
# A non-convolutional u-net.
3 | # 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | import math 8 | 9 | from coatiLDM.models.score_models.score_model import get_time_embedding, SwiGLUNet 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class BottleNeck(nn.Module): 16 | def __init__( 17 | self, 18 | in_dim=256, 19 | out_dim=256, 20 | const_dim=256, 21 | dropout=0.0, 22 | use_weight_norm=False, 23 | bias=True, 24 | ): 25 | super().__init__() 26 | self.in_dim = in_dim 27 | self.out_dim = out_dim 28 | self.const_dim = const_dim 29 | self.use_weight_norm = use_weight_norm 30 | self.wedge_in = SwiGLUNet( 31 | self.in_dim + self.const_dim, 32 | self.out_dim, 33 | dropout=dropout, 34 | use_weight_norm=use_weight_norm, 35 | bias=bias, 36 | ) 37 | self.wedge_out = SwiGLUNet( 38 | self.out_dim + self.const_dim, 39 | self.in_dim, 40 | dropout=dropout, 41 | use_weight_norm=use_weight_norm, 42 | bias=bias, 43 | ) 44 | 45 | def forward(self, x, const): 46 | """ 47 | Downsamples and residuals like a U-Net. 
48 | """ 49 | tocat = [const] 50 | z = self.wedge_in(torch.cat([x, *tocat], -1)) 51 | z2 = self.wedge_out(torch.cat([torch.nn.functional.silu(z), *tocat], -1)) 52 | out_ = x + z2 53 | return out_, z 54 | 55 | 56 | class Flat(nn.Module): 57 | def __init__( 58 | self, in_dim=256, const_dim=256, dropout=0.0, use_weight_norm=False, bias=True 59 | ): 60 | super().__init__() 61 | self.in_dim = in_dim 62 | self.const_dim = const_dim 63 | self.wedge_in = SwiGLUNet( 64 | self.in_dim + self.const_dim, 65 | self.in_dim, 66 | dropout=dropout, 67 | use_weight_norm=use_weight_norm, 68 | bias=bias, 69 | ) 70 | self.wedge_out = SwiGLUNet( 71 | self.in_dim + self.const_dim, 72 | self.in_dim, 73 | dropout=dropout, 74 | use_weight_norm=use_weight_norm, 75 | bias=bias, 76 | ) 77 | self.use_weight_norm = use_weight_norm 78 | 79 | def forward(self, x, const): 80 | tocat = [const] 81 | z = self.wedge_in(torch.cat([x, *tocat], -1)) 82 | z2 = self.wedge_out(torch.cat([x + torch.nn.functional.silu(z), *tocat], -1)) 83 | out = x + z2 84 | return out 85 | 86 | 87 | class Up(nn.Module): 88 | def __init__( 89 | self, 90 | in_dim=256, 91 | out_dim=256, 92 | const_dim=256, 93 | dropout=0.0, 94 | use_weight_norm=False, 95 | bias=True, 96 | ): 97 | super().__init__() 98 | self.in_dim = in_dim 99 | self.out_dim = out_dim 100 | self.const_dim = const_dim 101 | self.wedge_in = SwiGLUNet( 102 | self.in_dim + self.const_dim, 103 | self.out_dim, 104 | dropout=dropout, 105 | use_weight_norm=use_weight_norm, 106 | bias=bias, 107 | ) 108 | self.wedge_out = SwiGLUNet( 109 | self.out_dim + self.const_dim, 110 | self.out_dim, 111 | dropout=dropout, 112 | use_weight_norm=use_weight_norm, 113 | bias=bias, 114 | ) 115 | self.use_weight_norm = use_weight_norm 116 | 117 | def forward(self, x, const): 118 | tocat = [const] 119 | z = self.wedge_in(torch.cat([x, *tocat], -1)) 120 | z2 = self.wedge_out(torch.cat([torch.nn.functional.silu(z), *tocat], -1)) 121 | out = z + z2 122 | return out 123 | 124 | 125 | class 
OU(nn.Module): 126 | def __init__( 127 | self, 128 | x_dim=256, 129 | cond_dim=64, 130 | time_dim=64, 131 | ): 132 | super().__init__() 133 | # Allow learning of OU-like nearly constant DOFs. 134 | # score(X) \propto mu - X 135 | self.ou_const = nn.Parameter(torch.zeros(x_dim).normal_()) 136 | self.ou_dec = SwiGLUNet( 137 | time_dim + cond_dim, 138 | x_dim, 139 | residual=False, 140 | dropout=False, 141 | use_weight_norm=False, 142 | bias=True, 143 | ) 144 | 145 | def forward(self, x, t, cond=None): 146 | tocat = [t] 147 | if (not cond is None) and len(cond): 148 | tocat.append(cond) 149 | const = torch.cat(tocat, -1) 150 | fac = torch.nn.functional.softplus(self.ou_dec(const)) 151 | return (x - self.ou_const.unsqueeze(0)) * fac 152 | 153 | 154 | class NonConvUNet(nn.Module): 155 | def __init__( 156 | self, 157 | x_dim=256, 158 | cond_dim=1, 159 | time_max=1.0, 160 | time_dim=None, 161 | dropout=0.0, 162 | use_weight_norm=False, 163 | scheduler=None, 164 | bias=True, 165 | ): 166 | super().__init__() 167 | self.x_dim = x_dim 168 | self.cond_dim = cond_dim 169 | self.time_max = time_max 170 | self.time_dim = time_dim 171 | self.use_weight_norm = use_weight_norm 172 | self.scheduler = scheduler 173 | 174 | self.ou = OU(x_dim=x_dim, time_dim=time_dim, cond_dim=cond_dim) 175 | 176 | # bias is just set to true for the bottleneck layers. This aligned with whatever was going on in John branch pre-merge. 
177 | self.steps_down = nn.ModuleList( 178 | [ 179 | BottleNeck( 180 | x_dim, 181 | x_dim // 2, 182 | time_dim + cond_dim, 183 | dropout=dropout, 184 | use_weight_norm=use_weight_norm, 185 | bias=True, 186 | ), 187 | BottleNeck( 188 | x_dim // 2, 189 | x_dim // 4, 190 | time_dim + cond_dim, 191 | dropout=dropout, 192 | use_weight_norm=use_weight_norm, 193 | bias=True, 194 | ), 195 | BottleNeck( 196 | x_dim // 4, 197 | x_dim // 8, 198 | time_dim + cond_dim, 199 | dropout=dropout, 200 | use_weight_norm=use_weight_norm, 201 | bias=True, 202 | ), 203 | ] 204 | ) 205 | 206 | self.flat = nn.ModuleList( 207 | [ 208 | Flat( 209 | x_dim // 8, 210 | time_dim + cond_dim, 211 | dropout=dropout, 212 | use_weight_norm=use_weight_norm, 213 | bias=bias, 214 | ), 215 | Flat( 216 | x_dim, 217 | time_dim + cond_dim, 218 | dropout=dropout, 219 | use_weight_norm=use_weight_norm, 220 | bias=bias, 221 | ), 222 | ] 223 | ) 224 | 225 | self.steps_up = nn.ModuleList( 226 | [ 227 | Up( 228 | x_dim // 8, 229 | x_dim // 4, 230 | time_dim + cond_dim, 231 | dropout=dropout, 232 | use_weight_norm=use_weight_norm, 233 | bias=bias, 234 | ), 235 | Up( 236 | x_dim // 4, 237 | x_dim // 2, 238 | time_dim + cond_dim, 239 | dropout=dropout, 240 | use_weight_norm=use_weight_norm, 241 | bias=bias, 242 | ), 243 | Up( 244 | x_dim // 2, 245 | x_dim, 246 | time_dim + cond_dim, 247 | dropout=dropout, 248 | use_weight_norm=use_weight_norm, 249 | bias=bias, 250 | ), 251 | ] 252 | ) 253 | 254 | def forward(self, x, t, cond=None): 255 | time = get_time_embedding( 256 | t, max_time=self.time_max, embedding_dim=self.time_dim 257 | ) 258 | tocat = [time] 259 | if (not cond is None) and len(cond): 260 | tocat.append(cond) 261 | const = torch.cat(tocat, -1) 262 | 263 | z0 = self.ou(x, time, cond=cond) 264 | 265 | x0, d0 = self.steps_down[0](x, const) 266 | x1, d1 = self.steps_down[1](d0, const) 267 | x2, d2 = self.steps_down[2](d1, const) 268 | 269 | x3 = self.flat[0](d2, const) 270 | 271 | y2 = self.steps_up[0](x3, 
const) 272 | y1 = self.steps_up[1](y2 + x2, const) 273 | y0 = self.steps_up[2](x1 + y1, const) 274 | 275 | return self.flat[1](x0 + y0 + z0, const) 276 | -------------------------------------------------------------------------------- /coatiLDM/trainers/lfm_direct.py: -------------------------------------------------------------------------------- 1 | # 2 | # Distributed Latent Flow matching. 3 | # because why not? 4 | # 5 | 6 | import pickle, os, argparse, copy, datetime, gc 7 | import numpy as np 8 | import pandas as pd 9 | import json 10 | from datetime import timedelta 11 | 12 | import matplotlib.pyplot as plt 13 | from tqdm.auto import tqdm 14 | 15 | import torch 16 | from torch.optim import Adam 17 | import torchdata 18 | from torch.utils.data.datapipes.iter import IterableWrapper 19 | 20 | # Distributed stuff. 21 | from smart_open import open 22 | 23 | 24 | from coatiLDM.models.diffusion_models.flow_matching import ( 25 | CondOTFlowMatching, 26 | ScoreNetCondVF, 27 | ) 28 | from coatiLDM.common.utils import makedir, utc_epoch_now 29 | from coatiLDM.common.ema import ExponentialMovingAverage 30 | 31 | # from coatiLDM.data.transforms import xform_basic, inference_helper, embed_scalar 32 | from coatiLDM.data.transforms import xform_basic 33 | 34 | from coatiLDM.models.score_models.non_conv_unet import NonConvUNet 35 | 36 | 37 | def save_model(model, output_path, norm_summary=None, train_args=None): 38 | torch.save(model, output_path + ".pt") 39 | with open(output_path + ".pkl", "wb") as f: 40 | pickle.dump({"norm_summary": norm_summary, "train_args": train_args}, f) 41 | 42 | 43 | def serialize_score_model(score_model, norm_summary, train_args, score_model_params): 44 | model_state_dict = score_model.to("cpu").state_dict() 45 | model_serialized = { 46 | "model": model_state_dict, 47 | "norm_summary": norm_summary, 48 | "score_model_params": score_model_params, 49 | "train_args": train_args, 50 | } 51 | return pickle.dumps(model_serialized) 52 | 53 | 54 | def 
train_flow(args): 55 | makedir(args.model_dir) 56 | 57 | data_split_name = args.data_path.split("/")[-1].split(".")[0] 58 | tags = { 59 | "data_path": f"direct_{args.data_path}", 60 | "run_name": f"run__{args.run_name}", 61 | "diff_model": args.diff_type, 62 | "score_model": args.score_model, 63 | "data_split_name": data_split_name, 64 | } 65 | 66 | train_val_meta_dict = pickle.load(open(args.data_path, "rb")) 67 | try: 68 | norm_summary = train_val_meta_dict["cond_cdfs"] 69 | except: 70 | norm_summary = None 71 | 72 | if isinstance(train_val_meta_dict, dict): 73 | try: 74 | coati_doc = train_val_meta_dict["metadata"]["coati_doc"] 75 | except: 76 | coati_doc = None 77 | else: 78 | train_val_meta_dict = {"train": train_val_meta_dict} 79 | coati_doc = None 80 | # load it all into memory. 81 | 82 | base_pipe = IterableWrapper(train_val_meta_dict["train"], deepcopy=False) 83 | 84 | print("loaded data from " + args.data_path) 85 | print(f"batch_size:{args.batch_size}") 86 | print("device:", args.device) 87 | 88 | datapipe = ( 89 | base_pipe.shuffle() 90 | .batch(args.batch_size) 91 | .collate( 92 | lambda batch: xform_basic( 93 | batch, 94 | x_field=args.x_field, 95 | scalar_cond_fields=args.scalar_cond_fields, 96 | cond_emb_dim=args.dim_per_cond, 97 | device=args.device, 98 | ) 99 | ) 100 | ) 101 | 102 | print("obtaining test batch... ") 103 | test_batch = next(iter(datapipe)) 104 | 105 | x_dim = test_batch["samples"].shape[-1] 106 | if not test_batch["cond_vector"] is None: 107 | cond_dim = test_batch["cond_vector"].shape[-1] 108 | else: 109 | cond_dim = 0 110 | print("building model... 
") 111 | 112 | if args.score_model == "non_conv_unet": 113 | score_model_params = { 114 | "x_dim": x_dim, 115 | "cond_dim": cond_dim, 116 | "time_max": 1.0, 117 | "time_dim": args.time_dim, 118 | "dropout": 0.0, 119 | "scheduler": None, 120 | "bias": args.bias, 121 | "use_weight_norm": args.use_weight_norm, 122 | } 123 | score_model = NonConvUNet(**score_model_params) 124 | else: 125 | raise ValueError("only score_model == non_conv_unet is supported currently") 126 | 127 | print(f"Cond dim {cond_dim}") 128 | 129 | if args.diff_type == "flow_matching": 130 | flow_matching = CondOTFlowMatching() 131 | flow_model = ScoreNetCondVF(score_model).to(args.device) # Has no parameters. 132 | else: 133 | print(args.diff_type) 134 | raise Exception("bad diff_type") 135 | 136 | print("DEVICE: ", args.device) 137 | print("Flow Model: ") 138 | print(flow_model) 139 | 140 | train_args = vars(args) 141 | train_args["x_dim"] = x_dim 142 | train_args["cond_dim"] = cond_dim 143 | train_args["coati_doc"] = coati_doc 144 | 145 | optimizer = torch.optim.AdamW( 146 | flow_model.parameters(), lr=args.lr, weight_decay=args.weight_decay 147 | ) 148 | if args.ema: 149 | ema = ExponentialMovingAverage(flow_model.parameters(), decay=args.ema_const) 150 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma) 151 | num = 1 152 | ave_losses = [] 153 | 154 | for epoch in range(0, args.num_epochs + 1): 155 | losses = [] 156 | with tqdm(datapipe, desc="Epoch {} ".format(epoch), unit="batch") as tepoch: 157 | for batch in tepoch: 158 | optimizer.zero_grad() 159 | loss = flow_matching.loss( 160 | flow_model, batch["samples"], batch["cond_vector"] 161 | ) 162 | loss.backward() 163 | 164 | torch.nn.utils.clip_grad_norm_(flow_model.parameters(), 1.0) 165 | optimizer.step() 166 | if args.ema: 167 | ema.update() 168 | losses.append(loss.detach().cpu().numpy().item()) 169 | tepoch.set_postfix(loss=loss.item()) 170 | num += 1 171 | 172 | scheduler.step() 173 | ave = 0 174 | for loss 
in losses: 175 | ave += loss 176 | ave = ave / len(losses) 177 | ave_losses.append(ave) 178 | print("Epoch {}: Loss: {:.8f}".format(epoch, ave)) 179 | 180 | output_path = os.path.join(args.model_dir, f"{args.exp_name}_{args.run_name}_final") 181 | if args.ema: 182 | with ema.average_parameters(): 183 | print("serializing ema model") 184 | score_model_serialized = serialize_score_model( 185 | flow_model.score_net, norm_summary, train_args, score_model_params 186 | ) 187 | else: 188 | print("serializing non-ema model") 189 | score_model_serialized = serialize_score_model( 190 | flow_model.score_net, norm_summary, train_args, score_model_params 191 | ) 192 | 193 | with open(output_path + ".pkl", "wb") as f: 194 | f.write(score_model_serialized) 195 | 196 | return flow_model 197 | 198 | 199 | def do_args(mock_args=False): 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument("--data_path", type=str, default=None) 202 | parser.add_argument("--exp_name", type=str, default="flow") 203 | parser.add_argument("--run_name", type=str, default=str(int(utc_epoch_now()))) 204 | parser.add_argument("--model_dir", type=str, default="./") 205 | 206 | parser.add_argument("--device", type=str, default="cuda:0") 207 | 208 | parser.add_argument("--diff_type", type=str, default="flow_matching") 209 | parser.add_argument("--score_model", type=str, default="non_conv_unet") 210 | parser.add_argument( 211 | "--scalar_cond_fields", type=list, default=["INVSIGCDF", "logp"] 212 | ) 213 | parser.add_argument("--dim_per_cond", type=int, default=16) 214 | parser.add_argument("--use_weight_norm", type=bool, default=False) 215 | parser.add_argument("--bias", type=bool, default=True) 216 | parser.add_argument("--time_dim", type=int, default=16) 217 | parser.add_argument("--timesteps", type=int, default=1000) 218 | parser.add_argument("--beta_start", type=float, default=1e-4) 219 | parser.add_argument("--beta_end", type=float, default=0.02) 220 | parser.add_argument("--schedule", 
type=str, default="linear") 221 | parser.add_argument("--ema", type=bool, default=True) 222 | parser.add_argument("--ema_const", type=float, default=0.996) 223 | 224 | parser.add_argument("--torch_compile", type=bool, default=False) 225 | parser.add_argument("--restore_chkpt", type=str, default=None) 226 | parser.add_argument("--num_epochs", type=int, default=50) 227 | parser.add_argument("--skip_train", type=bool, default=False) 228 | parser.add_argument("--save_every_n_updates", type=int, default=1e6) 229 | parser.add_argument("--batch_size", type=int, default=896) 230 | parser.add_argument("--gamma", type=float, default=0.9) 231 | parser.add_argument("--weight_decay", type=float, default=1e-4) 232 | parser.add_argument("--lr", type=float, default=2e-4) 233 | if mock_args: 234 | args = parser.parse_args(args=[]) 235 | else: 236 | args = parser.parse_args() 237 | return args 238 | -------------------------------------------------------------------------------- /coatiLDM/models/coati/transformer_only.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from rdkit import RDLogger 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from coatiLDM.models.coati.smiles_xformer import ( 9 | SmilesTransformerConfig, 10 | RotarySmilesTransformer, 11 | ) 12 | 13 | 14 | RDLogger.DisableLog("rdApp.*") 15 | lg = RDLogger.logger() 16 | lg.setLevel(RDLogger.CRITICAL) 17 | 18 | 19 | class SwiGLUResNet(nn.Module): 20 | def __init__(self, d_in, d_out, dropout=0.0): 21 | """ 22 | 10/25 - added dropout. 
23 | """ 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | nn.LayerNorm(d_in), 27 | torch.nn.Dropout(p=dropout), 28 | nn.Linear(d_in, 2 * d_out), 29 | SwiGLU(), 30 | nn.Linear(d_out, d_out), 31 | ) 32 | 33 | def forward(self, x): 34 | return self.net(x) + x 35 | 36 | 37 | class SwiGLU(nn.Module): 38 | def forward(self, x): 39 | x, gate = x.chunk(2, dim=-1) 40 | return torch.nn.functional.silu(gate) * x 41 | 42 | 43 | class COATI_Smiles_Inference(nn.Module): 44 | """ 45 | A coati that can try to take advantage of the 46 | pseudoscalar signal from allegro. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | n_layer_xformer=16, 52 | n_hidden_xformer=256, 53 | embed_dim=256, 54 | n_head=16, 55 | n_seq=80, 56 | mlp_dropout=0.0, 57 | enc_to_coati="linear", # 'swiglu_mlp', 58 | n_direct_clr=64, # n_dim to take from the representation for the directCLR loss. 59 | n_tok=4, # I think this is a hack to pickle num toks processed 60 | biases=True, 61 | device=torch.device("cpu"), 62 | dtype=torch.float, 63 | ): 64 | super().__init__() 65 | 66 | self.embed_dim = embed_dim 67 | self.enc_to_coati = enc_to_coati 68 | self.n_direct_clr = n_direct_clr 69 | 70 | kwargs = { 71 | "n_layer": n_layer_xformer, 72 | "n_embd": n_hidden_xformer, 73 | "n_head": n_head, 74 | "n_seq": n_seq, 75 | "n_tok": n_tok, 76 | "device": device, 77 | "dtype": dtype, 78 | "biases": biases, 79 | } 80 | 81 | self.xformer_config = SmilesTransformerConfig(**kwargs) 82 | self.xformer = RotarySmilesTransformer(self.xformer_config) 83 | self.device = device 84 | 85 | if enc_to_coati == "linear": 86 | self.smiles_to_coati = nn.Sequential( 87 | nn.LayerNorm(self.embed_dim), 88 | nn.Linear(self.xformer.n_embd, self.embed_dim), 89 | ) 90 | # Make the common representation 91 | elif enc_to_coati == "swiglu_mlp": 92 | self.smiles_to_coati = nn.Sequential( 93 | nn.LayerNorm(self.xformer.n_embd), 94 | nn.Linear(self.xformer.n_embd, 2 * self.embed_dim), 95 | SwiGLU(), 96 | nn.Linear(self.embed_dim, self.embed_dim), 
97 | ) 98 | elif enc_to_coati == "swiglu_resnet": 99 | self.smiles_to_coati = SwiGLUResNet( 100 | self.xformer.n_embd, self.embed_dim, dropout=mlp_dropout 101 | ) 102 | 103 | self.coati_to_token = SwiGLUResNet(self.embed_dim, self.embed_dim) 104 | 105 | n_params_smiles = sum(p.numel() for p in self.xformer.parameters()) 106 | print(f"number of parameters Total: xformer: {n_params_smiles/1e6:.2f}M ") 107 | self.to(self.device) 108 | 109 | def encode_tokens(self, token_indices, tokenizer): 110 | assert token_indices.dim() == 2 111 | return self.smiles_to_coati(self.xformer.encode(token_indices, tokenizer)) 112 | 113 | def hcoati_to_2d( 114 | self, 115 | h_coati, 116 | tokenizer, 117 | fill_in_from="[SMILES]", 118 | noise_scale=0.0, 119 | do_suffix=False, 120 | inv_temp=2, 121 | k=100, 122 | ): 123 | """ 124 | Testing generation of SMILES (or GRAPH) 125 | from atoms and coords 126 | """ 127 | assert fill_in_from == "[SMILES]" or fill_in_from == "[GRAPH]" 128 | if noise_scale > 0: 129 | h_coati += torch.normal( 130 | mean=torch.zeros_like(h_coati), 131 | std=noise_scale * torch.ones_like(h_coati), 132 | ) 133 | h_token = self.coati_to_token(h_coati) 134 | # create a 'batch' to infer smiles. 
135 | if do_suffix: 136 | suffstr = "[SUFFIX][MIDDLE]" 137 | else: 138 | suffstr = "" 139 | token_prebatch = tokenizer.tokenize_text( 140 | "[CLIP][UNK]" + fill_in_from + suffstr, pad=False 141 | ) 142 | generation = self.xformer.generate_topk_with_inj( 143 | prefix=token_prebatch, 144 | stop_token=tokenizer.stop_token, 145 | inv_temp=inv_temp, 146 | k=k, 147 | inj_token=tokenizer.unk_token, 148 | inj_payload=h_token[0], 149 | ) 150 | if fill_in_from == "[SMILES]": 151 | return tokenizer.decode(generation, special=False) 152 | else: 153 | return tokenizer.decode(generation) 154 | 155 | def hcoati_to_2d_batch( 156 | self, 157 | h_coati: torch.Tensor, 158 | tokenizer, 159 | fill_in_from: str = "[SMILES]", 160 | noise_scale: float = 0.0, 161 | inv_temp: float = 2, 162 | k: int = 100, 163 | do_suffix=False, 164 | keep_special: bool = False, 165 | return_tokens: bool = False, 166 | ): 167 | """ 168 | Testing generation of SMILES (or GRAPH) 169 | from atoms and coords 170 | """ 171 | assert k > 1 172 | if noise_scale > 0: 173 | h_coati += torch.normal( 174 | mean=torch.zeros_like(h_coati), 175 | std=noise_scale * torch.ones_like(h_coati), 176 | ) 177 | h_token = self.coati_to_token(h_coati) 178 | if do_suffix: 179 | suffstr = "[SUFFIX][MIDDLE]" 180 | else: 181 | suffstr = "" 182 | token_prebatch = tokenizer.tokenize_text( 183 | "[CLIP][UNK]" + fill_in_from + suffstr, pad=False 184 | ) 185 | assert h_token.dim() == 2 186 | assert h_token.shape[-1] == self.xformer.n_embd 187 | generation = self.xformer.generate_top_k_with_inj_batch( 188 | prefix=token_prebatch, 189 | stop_token=tokenizer.stop_token, 190 | inv_temp=inv_temp, 191 | k=k, 192 | pad_token=tokenizer.pad_token, 193 | inj_token=tokenizer.unk_token, 194 | inj_payload=h_token, 195 | ) 196 | smiles_list = [ 197 | tokenizer.decode(token_out, special=keep_special) 198 | for token_out in generation 199 | ] 200 | 201 | if return_tokens: 202 | return smiles_list, generation 203 | 204 | return smiles_list 205 | 206 | def 
hcoati_to_2d_batch_beam( 207 | self, 208 | h_coati: torch.Tensor, 209 | tokenizer, 210 | fill_in_from: str = "[SMILES]", 211 | noise_scale: float = 0.0, 212 | max_len: int = 100, 213 | beam_width: int = 5, 214 | beam_iter_batch: int = 128, 215 | do_suffix=False, 216 | keep_special: bool = False, 217 | return_tokens: bool = False, 218 | force_stop=False, 219 | return_probs=False, 220 | ): 221 | """ 222 | Testing generation of SMILES (or GRAPH) 223 | from atoms and coords 224 | """ 225 | 226 | if noise_scale > 0: 227 | h_coati += torch.normal( 228 | mean=torch.zeros_like(h_coati), 229 | std=noise_scale * torch.ones_like(h_coati), 230 | ) 231 | h_token = self.coati_to_token(h_coati) 232 | if do_suffix: 233 | suffstr = "[SUFFIX][MIDDLE]" 234 | else: 235 | suffstr = "" 236 | token_prebatch = tokenizer.tokenize_text( 237 | "[CLIP][UNK]" + fill_in_from + suffstr, pad=False 238 | ) 239 | assert h_token.dim() == 2 240 | assert h_token.shape[-1] == self.xformer.n_embd 241 | generation, probs = self.xformer.generate_beam_search_batch( 242 | prefix=token_prebatch, 243 | predictions=max_len, 244 | beam_width=beam_width, 245 | batch_size=beam_iter_batch, 246 | inj_token=tokenizer.unk_token, 247 | inj_payload=h_token, 248 | stop_token=tokenizer.stop_token, 249 | force_stop=force_stop, 250 | progress_bar=1, 251 | ) 252 | 253 | smiles_list = [ 254 | [ 255 | tokenizer.decode(token_out, special=keep_special) 256 | for token_out in generation[i].tolist() 257 | ] 258 | for i in range(len(generation)) 259 | ] 260 | tore = {"smiles": smiles_list} 261 | if return_tokens: 262 | tore["tokens"] = generation 263 | if return_probs: 264 | tore["probs"] = probs 265 | 266 | return tore 267 | -------------------------------------------------------------------------------- /coatiLDM/trainers/ldm_direct.py: -------------------------------------------------------------------------------- 1 | # 2 | # Unified training for any CONTINUOUS 3 | # vector & set of conditions. 
4 | # 5 | 6 | import pickle, os, argparse, copy 7 | import numpy as np 8 | import pandas as pd 9 | import json 10 | 11 | import matplotlib.pyplot as plt 12 | from tqdm.auto import tqdm 13 | 14 | import torch 15 | from torch.optim import Adam 16 | import torchdata 17 | from torchdata.datapipes.iter import IterableWrapper, Mapper, Shuffler 18 | from smart_open import open 19 | from coatiLDM.common.fd import calc_fd 20 | 21 | from torch.utils.data.datapipes.iter import IterableWrapper 22 | 23 | from coatiLDM.models.diffusion_models.schedulers import DDPMScheduler 24 | from coatiLDM.common.utils import makedir, utc_epoch_now 25 | from coatiLDM.common.ema import ExponentialMovingAverage 26 | 27 | 28 | from coatiLDM.data.transforms import xform_basic 29 | 30 | # from coatiLDM.models.diffusion_models.ddpm import DDPM 31 | from coatiLDM.models.diffusion_models.ddpm_lightweight import DDPMScoreNetTrainer 32 | 33 | # has continuous order. 34 | 35 | from coatiLDM.models.score_models.non_conv_unet import NonConvUNet 36 | 37 | 38 | def save_model(model, output_path, norm_summary=None, train_args=None): 39 | torch.save(model, output_path + ".pt") 40 | with open(output_path + ".pkl", "wb") as f: 41 | pickle.dump({"norm_summary": norm_summary, "train_args": train_args}, f) 42 | 43 | 44 | def serialize_score_model(score_model, norm_summary, train_args, score_model_params): 45 | model_state_dict = score_model.to("cpu").state_dict() 46 | model_serialized = { 47 | "model": model_state_dict, 48 | "norm_summary": norm_summary, 49 | "score_model_params": score_model_params, 50 | "train_args": train_args, 51 | } 52 | return pickle.dumps(model_serialized) 53 | 54 | 55 | def train_diffusion(args): 56 | makedir(args.model_dir) 57 | 58 | data_split_name = args.data_path.split("/")[-1].split(".")[0] 59 | tags = { 60 | "data_path": f"direct_{args.data_path}", 61 | "run_name": f"run__{args.run_name}", 62 | "diff_model": args.diff_type, 63 | "score_model": args.score_model, 64 | 
"data_split_name": data_split_name, 65 | } 66 | 67 | makedir(args.model_dir) 68 | 69 | train_val_meta_dict = pickle.load(open(args.data_path, "rb")) 70 | 71 | try: 72 | norm_summary = train_val_meta_dict["cond_cdfs"] 73 | except: 74 | norm_summary = None 75 | 76 | if isinstance(train_val_meta_dict, dict): 77 | coati_doc = train_val_meta_dict["metadata"]["coati_doc"] 78 | else: 79 | train_val_meta_dict = {"train": train_val_meta_dict} 80 | coati_doc = None 81 | # load it all into memory. 82 | 83 | base_pipe = IterableWrapper(train_val_meta_dict["train"], deepcopy=False) 84 | 85 | print("loaded data from " + args.data_path) 86 | print(f"batch_size:{args.batch_size}") 87 | print("device:", args.device) 88 | 89 | datapipe = ( 90 | base_pipe.shuffle() 91 | .batch(args.batch_size) 92 | .collate( 93 | lambda batch: xform_basic( 94 | batch, 95 | x_field=args.x_field, 96 | scalar_cond_fields=args.scalar_cond_fields, 97 | cond_emb_dim=args.dim_per_cond, 98 | device=args.device, 99 | ) 100 | ) 101 | ) 102 | 103 | print("obtaining test batch... ") 104 | test_batch = next(iter(datapipe)) 105 | x_dim = test_batch["samples"].shape[-1] 106 | if not test_batch["cond_vector"] is None: 107 | cond_dim = test_batch["cond_vector"].shape[-1] 108 | else: 109 | cond_dim = 0 110 | print("building model... 
") 111 | 112 | scheduler_params = { 113 | "schedule": args.schedule, 114 | "timesteps": args.timesteps, 115 | "beta_start": args.beta_start, 116 | "beta_end": args.beta_end, 117 | } 118 | scheduler = DDPMScheduler(**scheduler_params) 119 | 120 | if args.score_model == "non_conv_unet": 121 | score_model_params = { 122 | "x_dim": x_dim, 123 | "cond_dim": cond_dim, 124 | "time_max": 1.0 if args.diff_type.count("bfn") > 0 else args.timesteps, 125 | "time_dim": args.time_dim, 126 | "dropout": args.dropout, 127 | "scheduler": scheduler, 128 | "bias": args.bias, 129 | "use_weight_norm": args.use_weight_norm, 130 | } 131 | score_model = NonConvUNet(**score_model_params) 132 | else: 133 | raise ValueError("only score_model == non_conv_unet is supported currently") 134 | 135 | print(f"Cond dim {cond_dim}") 136 | 137 | diff_model = DDPMScoreNetTrainer(score_model).to(args.device) 138 | 139 | print("DEVICE: ", args.device) 140 | print("Diffusion Model: ") 141 | print(diff_model) 142 | 143 | train_args = vars(args) 144 | train_args["x_dim"] = x_dim 145 | train_args["cond_dim"] = cond_dim 146 | train_args["coati_doc"] = coati_doc 147 | 148 | optimizer = torch.optim.AdamW( 149 | diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay 150 | ) 151 | if args.ema: 152 | ema = ExponentialMovingAverage(diff_model.parameters(), decay=args.ema_const) 153 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma) 154 | num = 1 155 | ave_losses = [] 156 | 157 | if args.skip_train: 158 | # return score_net and cdfs 159 | return diff_model.score_net, norm_summary 160 | 161 | print("starting training... 
") 162 | for epoch in range(0, args.num_epochs + 1): 163 | losses = [] 164 | with tqdm(datapipe, desc="Epoch {} ".format(epoch), unit="batch") as tepoch: 165 | for batch in tepoch: 166 | # print(batch['loss_weights']) 167 | optimizer.zero_grad() 168 | loss = diff_model( 169 | batch["samples"], 170 | cond=batch["cond_vector"], 171 | loss_weight=None, # Not doing this anymore 172 | ) 173 | loss.backward() 174 | torch.nn.utils.clip_grad_norm_(diff_model.parameters(), 1.0) 175 | optimizer.step() 176 | if args.ema: 177 | ema.update() 178 | losses.append(loss.item()) 179 | tepoch.set_postfix(loss=loss.item()) 180 | num += 1 181 | 182 | scheduler.step() 183 | ave = 0 184 | for loss in losses: 185 | ave += loss 186 | ave = ave / len(losses) 187 | ave_losses.append(ave) 188 | print("Epoch {}: Loss: {:.8f}".format(epoch, ave)) 189 | 190 | output_path = os.path.join(args.model_dir, f"{args.exp_name}_{args.run_name}_final") 191 | print(f"writing model and metadata to {output_path}") 192 | 193 | diff_model.eval() 194 | 195 | if args.ema: 196 | with ema.average_parameters(): 197 | print("serializing ema model") 198 | score_model_serialized = serialize_score_model( 199 | diff_model.score_net, norm_summary, train_args, score_model_params 200 | ) 201 | else: 202 | print("serializing non-ema model") 203 | score_model_serialized = serialize_score_model( 204 | diff_model.score_net, norm_summary, train_args, score_model_params 205 | ) 206 | 207 | # write the serialized model to local output_path 208 | with open(output_path + ".pkl", "wb") as f: 209 | f.write(score_model_serialized) 210 | 211 | return diff_model.score_net 212 | 213 | 214 | def do_args(mock_args=False): 215 | parser = argparse.ArgumentParser() 216 | parser.add_argument("--data_path", type=str, default=None) 217 | parser.add_argument("--exp_name", type=str, default="ddpm") 218 | parser.add_argument("--run_name", type=str, default=str(int(utc_epoch_now()))) 219 | parser.add_argument("--model_dir", type=str, default="./") 
220 | 221 | parser.add_argument("--diff_type", type=str, default="ddpm") 222 | parser.add_argument("--score_model", type=str, default="non_conv_unet") 223 | parser.add_argument("--scalar_cond_fields", type=list, default=["logp"]) 224 | parser.add_argument("--dim_per_cond", type=int, default=16) 225 | parser.add_argument("--dropout", type=float, default=0.0) 226 | parser.add_argument("--time_dim", type=int, default=16) 227 | parser.add_argument("--timesteps", type=int, default=1000) 228 | 229 | parser.add_argument("--beta_start", type=float, default=1e-4) 230 | parser.add_argument("--beta_end", type=float, default=0.02) 231 | 232 | parser.add_argument("--bias", type=bool, default=False) 233 | parser.add_argument("--use_weight_norm", type=bool, default=True) 234 | 235 | parser.add_argument("--device", type=str, default="cuda:0") 236 | parser.add_argument("--num_epochs", type=int, default=50) 237 | parser.add_argument("--skip_train", type=bool, default=False) 238 | parser.add_argument("--save_every_n_epochs", type=int, default=1) 239 | parser.add_argument("--batch_size", type=int, default=896) 240 | parser.add_argument("--gamma", type=float, default=0.9) 241 | parser.add_argument("--weight_decay", type=float, default=1e-4) 242 | parser.add_argument("--lr", type=float, default=2e-4) 243 | parser.add_argument("--schedule", type=str, default="linear") 244 | parser.add_argument("--ema", type=bool, default=True) 245 | parser.add_argument("--ema_const", type=float, default=0.996) 246 | if mock_args: 247 | args = parser.parse_args(args=[]) 248 | else: 249 | args = parser.parse_args() 250 | return args 251 | -------------------------------------------------------------------------------- /coatiLDM/models/coati/tokenizers/trie.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from collections import OrderedDict 3 | 4 | 5 | class Trie: 6 | """ 7 | Trie in Python. Creates a Trie out of a list of words. 
The trie is used to split on `added_tokens` in one pass 8 | Loose reference https://en.wikipedia.org/wiki/Trie 9 | """ 10 | 11 | def __init__(self): 12 | self.data = {} 13 | 14 | def add(self, word: str): 15 | """ 16 | Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation. 17 | The special key `""` is used to represent termination. 18 | This function is idempotent, adding twice the same word will leave the trie unchanged 19 | Example: 20 | ```python 21 | >>> trie = Trie() 22 | >>> trie.add("Hello 友達") 23 | >>> trie.data 24 | {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}} 25 | >>> trie.add("Hello") 26 | >>> trie.data 27 | {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}} 28 | ``` 29 | """ 30 | if not word: 31 | # Prevent empty string 32 | return 33 | ref = self.data 34 | for char in word: 35 | ref[char] = char in ref and ref[char] or {} 36 | ref = ref[char] 37 | ref[""] = 1 38 | 39 | def split(self, text: str) -> List[str]: 40 | """ 41 | Will look for the words added to the trie within `text`. Output is the original string splitted along the 42 | boundaries of the words found. 43 | This trie will match the longest possible word first ! 44 | Example: 45 | ```python 46 | >>> trie = Trie() 47 | >>> trie.split("[CLS] This is a extra_id_100") 48 | ["[CLS] This is a extra_id_100"] 49 | >>> trie.add("[CLS]") 50 | >>> trie.add("extra_id_1") 51 | >>> trie.add("extra_id_100") 52 | >>> trie.split("[CLS] This is a extra_id_100") 53 | ["[CLS]", " This is a ", "extra_id_100"] 54 | ``` 55 | """ 56 | # indexes are counted left of the chars index. 57 | # "hello", index 0, is left of h, index 1 is between h and e. 58 | # index 5 is right of the "o". 59 | 60 | # States are going to capture every possible start (indexes as above) 61 | # as keys, and have as values, a pointer to the position in the trie 62 | # where we're at. This is a partial match for now. 
63 | # This enables to keep track of multiple matches while we're iterating 64 | # the string 65 | # If the trie contains, "blowing", and "lower" and we encounter the 66 | # string "blower", we need to split into ["b", "lower"]. 67 | # This is where we need to keep track of multiple possible starts. 68 | states = OrderedDict() 69 | 70 | # This will contain every indices where we need 71 | # to cut. 72 | # We force to cut at offset 0 and len(text) (added later) 73 | offsets = [0] 74 | 75 | # This is used by the lookahead which needs to skip over 76 | # some text where the full match exceeded the place in the initial 77 | # for loop 78 | skip = 0 79 | # Main loop, Giving this algorithm O(n) complexity 80 | for current, current_char in enumerate(text): 81 | if skip and current < skip: 82 | # Prevents the lookahead for matching twice 83 | # like extra_id_100 and id_100 84 | continue 85 | 86 | # This will track every state 87 | # that stop matching, we need to stop tracking them. 88 | # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then 89 | # fail on "b", we need to remove 0 from the valid states. 90 | to_remove = set() 91 | # Whenever we found a match, we need to drop everything 92 | # this is a greedy algorithm, it will match on the first found token 93 | reset = False 94 | 95 | # In this case, we already have partial matches (But unfinished) 96 | for start, trie_pointer in states.items(): 97 | if "" in trie_pointer: 98 | # This is a final match, we need to reset and 99 | # store the results in `offsets`. 
100 | 101 | # Lookahead to match longest first 102 | # Important in case of extra_id_1 vs extra_id_100 103 | # Here we are also actively looking for other earlier partial 104 | # matches 105 | # "[CLS]", "L", we need to match CLS even if L is special 106 | for lookstart, looktrie_pointer in states.items(): 107 | if lookstart > start: 108 | # This partial match is later, we can stop looking 109 | break 110 | elif lookstart < start: 111 | # This partial match is earlier, the trie pointer 112 | # was already updated, so index is + 1 113 | lookahead_index = current + 1 114 | end = current + 1 115 | else: 116 | # Here lookstart == start and 117 | # looktrie_pointer == trie_pointer 118 | # It wasn't updated yet so indices are current ones 119 | lookahead_index = current 120 | end = current 121 | next_char = ( 122 | text[lookahead_index] 123 | if lookahead_index < len(text) 124 | else None 125 | ) 126 | if "" in looktrie_pointer: 127 | start = lookstart 128 | end = lookahead_index 129 | skip = lookahead_index 130 | 131 | while next_char in looktrie_pointer: 132 | looktrie_pointer = looktrie_pointer[next_char] 133 | lookahead_index += 1 134 | if "" in looktrie_pointer: 135 | start = lookstart 136 | end = lookahead_index 137 | skip = lookahead_index 138 | 139 | if lookahead_index == len(text): 140 | # End of string 141 | break 142 | next_char = text[lookahead_index] 143 | # End lookahead 144 | 145 | # Storing and resetting 146 | offsets.append(start) 147 | offsets.append(end) 148 | reset = True 149 | break 150 | elif current_char in trie_pointer: 151 | # The current character being looked at has a match within the trie 152 | # update the pointer (it will be stored back into states later). 153 | trie_pointer = trie_pointer[current_char] 154 | 155 | # Storing back the new pointer into the states. 156 | # Partial matches got longer by one. 
157 | states[start] = trie_pointer 158 | else: 159 | # The new character has not match in the trie, we need 160 | # to stop keeping track of this partial match. 161 | # We can't do it directly within the loop because of how 162 | # python iteration works 163 | to_remove.add(start) 164 | 165 | # Either clearing the full start (we found a real match) 166 | # Or clearing only the partial matches that didn't work. 167 | if reset: 168 | states = {} 169 | else: 170 | for start in to_remove: 171 | del states[start] 172 | 173 | # If this character is a starting character within the trie 174 | # start keeping track of this partial match. 175 | if current >= skip and current_char in self.data: 176 | states[current] = self.data[current_char] 177 | 178 | # We have a cut at the end with states. 179 | for start, trie_pointer in states.items(): 180 | if "" in trie_pointer: 181 | # This is a final match, we need to reset and 182 | # store the results in `offsets`. 183 | end = len(text) 184 | offsets.append(start) 185 | offsets.append(end) 186 | # Longest cut is always the one with lower start so the first 187 | # item so we need to break. 188 | break 189 | 190 | return self.cut_text(text, offsets) 191 | 192 | def cut_text(self, text, offsets): 193 | # We have all the offsets now, we just need to do the actual splitting. 194 | # We need to eventually add the first part of the string and the eventual 195 | # last part. 196 | offsets.append(len(text)) 197 | tokens = [] 198 | start = 0 199 | for end in offsets: 200 | if start > end: 201 | raise Exception( 202 | "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it" 203 | " anyway." 
204 | ) 205 | continue 206 | elif start == end: 207 | # This might happen if there's a match at index 0 208 | # we're also preventing zero-width cuts in case of two 209 | # consecutive matches 210 | continue 211 | tokens.append(text[start:end]) 212 | start = end 213 | 214 | return tokens 215 | -------------------------------------------------------------------------------- /coatiLDM/models/diffusion_models/flow_matching.py: -------------------------------------------------------------------------------- 1 | # 2 | # Defines 3 types of flow matching and an adapter to create a 3 | # conditional vector field 4 | # shamelessly borrowed from this excellent notebook: 5 | # https://colab.research.google.com/github/gle-bellier/flow-matching/blob/main/Flow_Matching.ipynb#scrollTo=l_NPXHeSNg8Y 6 | # 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from zuko.utils import odeint 12 | 13 | 14 | class CondOTFlowMatching: 15 | def __init__(self, sig_min: float = 0.001) -> None: 16 | super().__init__() 17 | self.sig_min = sig_min 18 | self.eps = 1e-5 19 | 20 | def psi_t( 21 | self, x: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor 22 | ) -> torch.Tensor: 23 | """Conditional Flow""" 24 | return (1 - (1 - self.sig_min) * t) * x + t * x_1 25 | 26 | def loss(self, v_t: nn.Module, x_1: torch.Tensor, cs) -> torch.Tensor: 27 | """Compute loss""" 28 | # t ~ Unif([0, 1]) 29 | t = ( 30 | torch.rand(1, device=x_1.device) 31 | + torch.arange(len(x_1), device=x_1.device) / len(x_1) 32 | ) % (1 - self.eps) 33 | t = t[:, None].expand(x_1.shape) 34 | # x ~ p_t(x_0) 35 | x_0 = torch.randn_like(x_1) 36 | v_psi = v_t(t[:, 0], self.psi_t(x_0, x_1, t), cs) 37 | d_psi = x_1 - (1 - self.sig_min) * x_0 38 | return torch.mean((v_psi - d_psi) ** 2) 39 | 40 | 41 | class OTFlowMatching: 42 | def __init__(self, sig_min: float = 0.001) -> None: 43 | super().__init__() 44 | self.sig_min = sig_min 45 | self.eps = 1e-5 46 | 47 | def psi_t( 48 | self, x: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor 49 
| ) -> torch.Tensor: 50 | """Conditional Flow""" 51 | return (1 - (1 - self.sig_min) * t) * x + t * x_1 52 | 53 | def loss(self, v_t: nn.Module, x_1: torch.Tensor) -> torch.Tensor: 54 | """Compute loss""" 55 | # t ~ Unif([0, 1]) 56 | t = ( 57 | torch.rand(1, device=x_1.device) 58 | + torch.arange(len(x_1), device=x_1.device) / len(x_1) 59 | ) % (1 - self.eps) 60 | t = t[:, None].expand(x_1.shape) 61 | # x ~ p_t(x_0) 62 | x_0 = torch.randn_like(x_1) 63 | v_psi = v_t(t[:, 0], self.psi_t(x_0, x_1, t)) 64 | d_psi = x_1 - (1 - self.sig_min) * x_0 65 | return torch.mean((v_psi - d_psi) ** 2) 66 | 67 | 68 | class VEDiffusionFlowMatching: 69 | def __init__(self) -> None: 70 | super().__init__() 71 | self.sigma_min = 0.01 72 | self.sigma_max = 2.0 73 | self.eps = 1e-5 74 | 75 | def sigma_t(self, t: torch.Tensor) -> torch.Tensor: 76 | return self.sigma_min * (self.sigma_max / self.sigma_min) ** t 77 | 78 | def dsigma_dt(self, t: torch.Tensor) -> torch.Tensor: 79 | return self.sigma_t(t) * torch.log( 80 | torch.tensor(self.sigma_max / self.sigma_min) 81 | ) 82 | 83 | def u_t(self, t: torch.Tensor, x: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: 84 | return -(self.dsigma_dt(1.0 - t) / self.sigma_t(1.0 - t)) * (x - x_1) 85 | 86 | def loss(self, v_t: nn.Module, x_1: torch.Tensor) -> torch.Tensor: 87 | """Compute loss""" 88 | # t ~ Unif([0, 1]) 89 | t = ( 90 | torch.rand(1, device=x_1.device) 91 | + torch.arange(len(x_1), device=x_1.device) / len(x_1) 92 | ) % (1 - self.eps) 93 | t = t[:, None].expand(x_1.shape) 94 | # x ~ p_t(x|x_1) 95 | x = x_1 + self.sigma_t(1.0 - t) * torch.randn_like(x_1) 96 | return torch.mean((v_t(t[:, 0], x) - self.u_t(t, x, x_1)) ** 2) 97 | 98 | 99 | class VPDiffusionFlowMatching: 100 | def __init__(self) -> None: 101 | super().__init__() 102 | self.beta_min = 0.1 103 | self.beta_max = 20.0 104 | self.eps = 1e-5 105 | 106 | def T(self, s: torch.Tensor) -> torch.Tensor: 107 | return self.beta_min * s + 0.5 * (s**2) * (self.beta_max - self.beta_min) 
108 | 109 | def beta(self, t: torch.Tensor) -> torch.Tensor: 110 | return self.beta_min + t * (self.beta_max - self.beta_min) 111 | 112 | def alpha(self, t: torch.Tensor) -> torch.Tensor: 113 | return torch.exp(-0.5 * self.T(t)) 114 | 115 | def mu_t(self, t: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: 116 | return self.alpha(1.0 - t) * x_1 117 | 118 | def sigma_t(self, t: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: 119 | return torch.sqrt(1.0 - self.alpha(1.0 - t) ** 2) 120 | 121 | def u_t(self, t: torch.Tensor, x: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: 122 | num = torch.exp(-self.T(1.0 - t)) * x - torch.exp(-0.5 * self.T(1.0 - t)) * x_1 123 | denum = 1.0 - torch.exp(-self.T(1.0 - t)) 124 | return -0.5 * self.beta(1.0 - t) * (num / denum) 125 | 126 | def loss(self, v_t: nn.Module, x_1: torch.Tensor) -> torch.Tensor: 127 | """Compute loss""" 128 | # t ~ Unif([0, 1]) 129 | t = ( 130 | torch.rand(1, device=x_1.device) 131 | + torch.arange(len(x_1), device=x_1.device) / len(x_1) 132 | ) % (1 - self.eps) 133 | t = t[:, None].expand(x_1.shape) 134 | # x ~ p_t(x|x_1) 135 | x = self.mu_t(t, x_1) + self.sigma_t(t, x_1) * torch.randn_like(x_1) 136 | return torch.mean((v_t(t[:, 0], x) - self.u_t(t, x, x_1)) ** 2) 137 | 138 | 139 | class CondVF(nn.Module): 140 | """ 141 | conditional vector field... 
142 | """ 143 | 144 | def __init__(self, net: nn.Module, n_steps: int = 100) -> None: 145 | super().__init__() 146 | self.net = net 147 | 148 | def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 149 | return self.net(t, x) 150 | 151 | def wrapper(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 152 | """ 153 | Args t: a zero-D tensor of time in [0,1] 154 | """ 155 | t = t * torch.ones(len(x), device=x.device) 156 | return self(t, x) 157 | 158 | def decode_t0_t1(self, x_0, t0, t1): 159 | return odeint(self.wrapper, x_0, t0, t1, self.parameters()) 160 | 161 | def encode(self, x_1: torch.Tensor) -> torch.Tensor: 162 | return odeint(self.wrapper, x_1, 1.0, 0.0, self.parameters()) 163 | 164 | def decode(self, x_0: torch.Tensor) -> torch.Tensor: 165 | return odeint(self.wrapper, x_0, 0.0, 1.0, self.parameters()) 166 | 167 | 168 | class CondCondVF(nn.Module): 169 | """ 170 | conditional vector field... 171 | """ 172 | 173 | def __init__(self, net: nn.Module, n_steps: int = 100) -> None: 174 | super().__init__() 175 | self.net = net 176 | 177 | def forward(self, t: torch.Tensor, x: torch.Tensor, cs) -> torch.Tensor: 178 | return self.net(t, x, cs) 179 | 180 | def wrapper(self, t: torch.Tensor, x: torch.Tensor, cs) -> torch.Tensor: 181 | """ 182 | Args t: a zero-D tensor of time in [0,1] 183 | """ 184 | t = t * torch.ones(len(x), device=x.device) 185 | return self(t, x, cs) 186 | 187 | def decode_t0_t1(self, x_0, t0, t1, cs): 188 | return odeint( 189 | lambda T, X: self.wrapper(T, X, cs), x_0, t0, t1, self.parameters() 190 | ) 191 | 192 | def encode(self, x_1: torch.Tensor, cs) -> torch.Tensor: 193 | return odeint( 194 | lambda T, X: self.wrapper(T, X, cs), x_1, 1.0, 0.0, self.parameters() 195 | ) 196 | 197 | def decode(self, x_0: torch.Tensor, cs) -> torch.Tensor: 198 | return odeint( 199 | lambda T, X: self.wrapper(T, X, cs), x_0, 0.0, 1.0, self.parameters() 200 | ) 201 | 202 | 203 | class ScoreNetCondVF(nn.Module): 204 | """ 205 | conditional 
vector field... 206 | """ 207 | 208 | def __init__(self, net: nn.Module, n_steps: int = 100) -> None: 209 | super().__init__() 210 | self.score_net = net 211 | 212 | def forward(self, t: torch.Tensor, x: torch.Tensor, cs) -> torch.Tensor: 213 | return self.score_net(x, t, cs) 214 | 215 | def wrapper(self, t: torch.Tensor, x: torch.Tensor, cs) -> torch.Tensor: 216 | """ 217 | Args t: a zero-D tensor of time in [0,1] 218 | """ 219 | t = t * torch.ones(len(x), device=x.device) 220 | return self(t, x, cs) 221 | 222 | def decode_t0_t1(self, x_0, t0, t1, cs): 223 | return odeint( 224 | lambda T, X: self.wrapper(T, X, cs), x_0, t0, t1, self.parameters() 225 | ) 226 | 227 | def encode(self, x_1: torch.Tensor, cs) -> torch.Tensor: 228 | return odeint( 229 | lambda T, X: self.wrapper(T, X, cs), x_1, 1.0, 0.0, self.parameters() 230 | ) 231 | 232 | def decode(self, x_0: torch.Tensor, cs) -> torch.Tensor: 233 | return odeint( 234 | lambda T, X: self.wrapper(T, X, cs), x_0, 0.0, 1.0, self.parameters() 235 | ) 236 | 237 | 238 | _one_third = 1.0 / 3 239 | _two_thirds = 2.0 / 3 240 | 241 | 242 | def euler_step(func, x0, t0, dt=1.0 / 1000.0, f0=None): 243 | x1 = func(t0, x0) 244 | return x0 + dt * (x1) 245 | 246 | 247 | def rk4_step(func, x0, t0, dt=1.0 / 1000.0, f0=None): 248 | """Smaller error with slightly more compute.""" 249 | k1 = f0 250 | if k1 is None: 251 | k1 = func(t0, x0) 252 | k2 = func(t0 + dt * _one_third, x0 + dt * k1 * _one_third) 253 | k3 = func(t0 + dt * _two_thirds, x0 + dt * (k2 - k1 * _one_third)) 254 | k4 = func(t0 + dt, x0 + dt * (k1 - k2 + k3)) 255 | return x0 + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125 256 | 257 | 258 | def ode_int_explicit( 259 | func, x0, t0=0.0, t1=1.0, nT=1000, differentiable=False, rule=euler_step 260 | ): 261 | """ 262 | Solves 263 | d(X(t))/dt = func(t, x) 264 | """ 265 | x = x0.clone() 266 | dt = (t1 - t0) / nT 267 | if differentiable: 268 | with torch.grad_enable(): 269 | for i in range(nT): 270 | t = t0 + i * dt 271 | x = rule(func, 
x, t * torch.ones(x.shape[0], device=x.device), dt=dt) 272 | 273 | else: 274 | with torch.no_grad(): 275 | for i in range(nT): 276 | t = t0 + i * dt 277 | x = rule(func, x, t * torch.ones(x.shape[0], device=x.device), dt=dt) 278 | return x 279 | 280 | 281 | class OT_cond_flow_matching(nn.Module): 282 | """ 283 | For use with allegro_vector_field 284 | """ 285 | 286 | def __init__(self, score_net): 287 | super().__init__() 288 | self.score_net = score_net 289 | self.eps = 1e-5 290 | self.sig_min = 1e-5 291 | # self.sig_min = 0.001 292 | 293 | def forward(self, x1, x0, cond=None): 294 | """ 295 | Returns the loss of this type of flow matching. 296 | Requires the prior x0 samples (same shape as x1) 297 | """ 298 | t = ( 299 | torch.rand(1, device=x1.device) 300 | + torch.arange(x1.shape[0], device=x1.device) / x1.shape[0] 301 | ) % (1 - self.eps) 302 | psi_t = ( 303 | t.unsqueeze(-1) * x1 + (1.0 - (1.0 - self.sig_min) * t.unsqueeze(-1)) * x0 304 | ) 305 | dpsi_dt = x1 - (1.0 - self.sig_min) * x0 306 | return torch.pow((self.score_net(psi_t, t, cond=cond) - dpsi_dt), 2.0).mean() 307 | 308 | def sample( 309 | self, 310 | x0, 311 | t0=0.0, 312 | t1=1.0, 313 | nT=1000, 314 | rule=euler_step, 315 | differentiable=False, 316 | cond=None, 317 | ): 318 | return ode_int_explicit( 319 | lambda t, x: self.score_net(x, t, cond=cond), 320 | x0, 321 | t0=t0, 322 | t1=t1, 323 | nT=nT, 324 | rule=rule, 325 | differentiable=differentiable, 326 | ) 327 | -------------------------------------------------------------------------------- /coatiLDM/common/ema.py: -------------------------------------------------------------------------------- 1 | # Borrowed from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | from typing import Iterable, Optional 7 | import weakref 8 | import copy 9 | import contextlib 10 | 11 | import torch 12 | 13 | 14 | # Partially based on: 15 | # 
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 16 | class ExponentialMovingAverage: 17 | """ 18 | Maintains (exponential) moving average of a set of parameters. 19 | 20 | Args: 21 | parameters: Iterable of `torch.nn.Parameter` (typically from 22 | `model.parameters()`). 23 | Note that EMA is computed on *all* provided parameters, 24 | regardless of whether or not they have `requires_grad = True`; 25 | this allows a single EMA object to be consistantly used even 26 | if which parameters are trainable changes step to step. 27 | 28 | If you want to some parameters in the EMA, do not pass them 29 | to the object in the first place. For example: 30 | 31 | ExponentialMovingAverage( 32 | parameters=[p for p in model.parameters() if p.requires_grad], 33 | decay=0.9 34 | ) 35 | 36 | will ignore parameters that do not require grad. 37 | 38 | decay: The exponential decay. 39 | 40 | use_num_updates: Whether to use number of updates when computing 41 | averages. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | parameters: Iterable[torch.nn.Parameter], 47 | decay: float, 48 | use_num_updates: bool = True, 49 | ): 50 | if decay < 0.0 or decay > 1.0: 51 | raise ValueError("Decay must be between 0 and 1") 52 | self.decay = decay 53 | self.num_updates = 0 if use_num_updates else None 54 | parameters = list(parameters) 55 | self.shadow_params = [p.clone().detach() for p in parameters] 56 | self.collected_params = None 57 | # By maintaining only a weakref to each parameter, 58 | # we maintain the old GC behaviour of ExponentialMovingAverage: 59 | # if the model goes out of scope but the ExponentialMovingAverage 60 | # is kept, no references to the model or its parameters will be 61 | # maintained, and the model will be cleaned up. 
62 | self._params_refs = [weakref.ref(p) for p in parameters] 63 | 64 | def _get_parameters( 65 | self, parameters: Optional[Iterable[torch.nn.Parameter]] 66 | ) -> Iterable[torch.nn.Parameter]: 67 | if parameters is None: 68 | parameters = [p() for p in self._params_refs] 69 | if any(p is None for p in parameters): 70 | raise ValueError( 71 | "(One of) the parameters with which this " 72 | "ExponentialMovingAverage " 73 | "was initialized no longer exists (was garbage collected);" 74 | " please either provide `parameters` explicitly or keep " 75 | "the model to which they belong from being garbage " 76 | "collected." 77 | ) 78 | return parameters 79 | else: 80 | parameters = list(parameters) 81 | if len(parameters) != len(self.shadow_params): 82 | raise ValueError( 83 | "Number of parameters passed as argument is different " 84 | "from number of shadow parameters maintained by this " 85 | "ExponentialMovingAverage" 86 | ) 87 | return parameters 88 | 89 | def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: 90 | """ 91 | Update currently maintained parameters. 92 | 93 | Call this every time the parameters are updated, such as the result of 94 | the `optimizer.step()` call. 95 | 96 | Args: 97 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 98 | parameters used to initialize this object. If `None`, the 99 | parameters with which this `ExponentialMovingAverage` was 100 | initialized will be used. 
101 | """ 102 | parameters = self._get_parameters(parameters) 103 | decay = self.decay 104 | if self.num_updates is not None: 105 | self.num_updates += 1 106 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 107 | one_minus_decay = 1.0 - decay 108 | with torch.no_grad(): 109 | for s_param, param in zip(self.shadow_params, parameters): 110 | tmp = s_param - param 111 | # tmp will be a new tensor so we can do in-place 112 | tmp.mul_(one_minus_decay) 113 | s_param.sub_(tmp) 114 | 115 | def copy_to( 116 | self, parameters: Optional[Iterable[torch.nn.Parameter]] = None 117 | ) -> None: 118 | """ 119 | Copy current averaged parameters into given collection of parameters. 120 | 121 | Args: 122 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 123 | updated with the stored moving averages. If `None`, the 124 | parameters with which this `ExponentialMovingAverage` was 125 | initialized will be used. 126 | """ 127 | parameters = self._get_parameters(parameters) 128 | for s_param, param in zip(self.shadow_params, parameters): 129 | param.data.copy_(s_param.data) 130 | 131 | def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: 132 | """ 133 | Save the current parameters for restoring later. 134 | 135 | Args: 136 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 137 | temporarily stored. If `None`, the parameters of with which this 138 | `ExponentialMovingAverage` was initialized will be used. 139 | """ 140 | parameters = self._get_parameters(parameters) 141 | self.collected_params = [param.clone() for param in parameters] 142 | 143 | def restore( 144 | self, parameters: Optional[Iterable[torch.nn.Parameter]] = None 145 | ) -> None: 146 | """ 147 | Restore the parameters stored with the `store` method. 148 | Useful to validate the model with EMA parameters without affecting the 149 | original optimization process. Store the parameters before the 150 | `copy_to` method. 
After validation (or model saving), use this to 151 | restore the former parameters. 152 | 153 | Args: 154 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 155 | updated with the stored parameters. If `None`, the 156 | parameters with which this `ExponentialMovingAverage` was 157 | initialized will be used. 158 | """ 159 | if self.collected_params is None: 160 | raise RuntimeError( 161 | "This ExponentialMovingAverage has no `store()`ed weights " 162 | "to `restore()`" 163 | ) 164 | parameters = self._get_parameters(parameters) 165 | for c_param, param in zip(self.collected_params, parameters): 166 | param.data.copy_(c_param.data) 167 | 168 | @contextlib.contextmanager 169 | def average_parameters( 170 | self, parameters: Optional[Iterable[torch.nn.Parameter]] = None 171 | ): 172 | r""" 173 | Context manager for validation/inference with averaged parameters. 174 | 175 | Equivalent to: 176 | 177 | ema.store() 178 | ema.copy_to() 179 | try: 180 | ... 181 | finally: 182 | ema.restore() 183 | 184 | Args: 185 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 186 | updated with the stored parameters. If `None`, the 187 | parameters with which this `ExponentialMovingAverage` was 188 | initialized will be used. 189 | """ 190 | parameters = self._get_parameters(parameters) 191 | self.store(parameters) 192 | self.copy_to(parameters) 193 | try: 194 | yield 195 | finally: 196 | self.restore(parameters) 197 | 198 | def to(self, device=None, dtype=None) -> None: 199 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
200 | 201 | Args: 202 | device: like `device` argument to `torch.Tensor.to` 203 | """ 204 | # .to() on the tensors handles None correctly 205 | self.shadow_params = [ 206 | ( 207 | p.to(device=device, dtype=dtype) 208 | if p.is_floating_point() 209 | else p.to(device=device) 210 | ) 211 | for p in self.shadow_params 212 | ] 213 | if self.collected_params is not None: 214 | self.collected_params = [ 215 | ( 216 | p.to(device=device, dtype=dtype) 217 | if p.is_floating_point() 218 | else p.to(device=device) 219 | ) 220 | for p in self.collected_params 221 | ] 222 | return 223 | 224 | def state_dict(self) -> dict: 225 | r"""Returns the state of the ExponentialMovingAverage as a dict.""" 226 | # Following PyTorch conventions, references to tensors are returned: 227 | # "returns a reference to the state and not its copy!" - 228 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict 229 | return { 230 | "decay": self.decay, 231 | "num_updates": self.num_updates, 232 | "shadow_params": self.shadow_params, 233 | "collected_params": self.collected_params, 234 | } 235 | 236 | def load_state_dict(self, state_dict: dict) -> None: 237 | r"""Loads the ExponentialMovingAverage state. 238 | 239 | Args: 240 | state_dict (dict): EMA state. Should be an object returned 241 | from a call to :meth:`state_dict`. 
242 | """ 243 | # deepcopy, to be consistent with module API 244 | state_dict = copy.deepcopy(state_dict) 245 | self.decay = state_dict["decay"] 246 | if self.decay < 0.0 or self.decay > 1.0: 247 | raise ValueError("Decay must be between 0 and 1") 248 | self.num_updates = state_dict["num_updates"] 249 | assert self.num_updates is None or isinstance( 250 | self.num_updates, int 251 | ), "Invalid num_updates" 252 | 253 | self.shadow_params = state_dict["shadow_params"] 254 | assert isinstance(self.shadow_params, list), "shadow_params must be a list" 255 | assert all( 256 | isinstance(p, torch.Tensor) for p in self.shadow_params 257 | ), "shadow_params must all be Tensors" 258 | 259 | self.collected_params = state_dict["collected_params"] 260 | if self.collected_params is not None: 261 | assert isinstance( 262 | self.collected_params, list 263 | ), "collected_params must be a list" 264 | assert all( 265 | isinstance(p, torch.Tensor) for p in self.collected_params 266 | ), "collected_params must all be Tensors" 267 | assert len(self.collected_params) == len( 268 | self.shadow_params 269 | ), "collected_params and shadow_params had different lengths" 270 | 271 | if len(self.shadow_params) == len(self._params_refs): 272 | # Consistant with torch.optim.Optimizer, cast things to consistant 273 | # device and dtype with the parameters 274 | params = [p() for p in self._params_refs] 275 | # If parameters have been garbage collected, just load the state 276 | # we were given without change. 
277 | if not any(p is None for p in params): 278 | # ^ parameter references are still good 279 | for i, p in enumerate(params): 280 | self.shadow_params[i] = self.shadow_params[i].to( 281 | device=p.device, dtype=p.dtype 282 | ) 283 | if self.collected_params is not None: 284 | self.collected_params[i] = self.collected_params[i].to( 285 | device=p.device, dtype=p.dtype 286 | ) 287 | else: 288 | raise ValueError( 289 | "Tried to `load_state_dict()` with the wrong number of " 290 | "parameters in the saved state." 291 | ) 292 | -------------------------------------------------------------------------------- /coatiLDM/models/coati/basic_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Causal and non-causal transformer blocks. 3 | """ 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | 12 | class NewGELU(nn.Module): 13 | """ 14 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 15 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 16 | """ 17 | 18 | def forward(self, x): 19 | return ( 20 | 0.5 21 | * x 22 | * ( 23 | 1.0 24 | + torch.tanh( 25 | math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) 26 | ) 27 | ) 28 | ) 29 | 30 | 31 | class RotaryEmbedding(torch.nn.Module): 32 | def __init__( 33 | self, 34 | n_seq=256, 35 | n_embd: int = 128, 36 | n_tok: int = 512, 37 | n_head=8, 38 | norm_embed=False, 39 | device=torch.device("cpu"), 40 | dtype=torch.float, 41 | base=10000, 42 | ): 43 | """ 44 | Eq. (34) of https://arxiv.org/pdf/2104.09864.pdf 45 | also inspired by https://blog.eleuther.ai/rotary-embeddings/ 46 | The rotation is done after the hidden dimension is split into heads. 47 | so, the cached sin/cos tensors operate on a space (n_embd // n_head) 48 | 49 | Args: 50 | n_seq: Maximum sequence dimension. 
51 | n_embd: embedding dimension (pre head split) 52 | n_tok: size of tokenspace. 53 | n_head: number of attention heads. 54 | """ 55 | super().__init__() 56 | assert n_embd % (2 * n_head) == 0 57 | inv_freq = 1.0 / ( 58 | base 59 | ** ( 60 | torch.arange(0, (n_embd // n_head), 2, device=device).float() 61 | / (n_embd // n_head) 62 | ) 63 | ) 64 | t = torch.arange(n_seq, device=device).type_as(inv_freq) 65 | freqs = torch.einsum("i,j->ij", t, inv_freq) 66 | emb = torch.cat((freqs, freqs), dim=-1) # (nseq X n_embd//n_head) 67 | self.cos_cached = emb.cos() 68 | self.sin_cached = emb.sin() 69 | self.n_head = n_head 70 | self.n_seq = n_seq 71 | self.n_embd = n_embd 72 | if norm_embed: 73 | raise Exception("Depreciate soon.") 74 | self.tok_emb = nn.Sequential( 75 | nn.Embedding(n_tok, n_embd, device=device, dtype=dtype), 76 | nn.LayerNorm(n_embd), 77 | ) 78 | else: 79 | self.tok_emb = nn.Embedding(n_tok, n_embd, device=device, dtype=dtype) 80 | 81 | def forward(self, idx): 82 | return self.tok_emb(idx) 83 | 84 | def rotate(self, x): 85 | """ 86 | Rotate along the embedding dimension. 87 | """ 88 | return torch.cat([-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]], -1) 89 | 90 | def rotary_embed(self, q, k): 91 | """ 92 | Args: 93 | q: A query (batch, n_head, seq, n_embd//n_head) 94 | k: A key. (batch, n_head, seq, n_embd//n_head) 95 | Returns: 96 | q,k (with the multiplicative rotary embedding applied.) 97 | """ 98 | seq_len = q.shape[2] 99 | cos = self.cos_cached[None, None, :seq_len, :].to(q.device) 100 | sin = self.sin_cached[None, None, :seq_len, :].to(q.device) 101 | return (q * cos) + (self.rotate(q) * sin), (k * cos) + (self.rotate(k) * sin) 102 | 103 | 104 | class RotarySelfAttention(nn.Module): 105 | """ 106 | A self attention block with rotary relative position encoding. 
107 | (and causality) 108 | """ 109 | 110 | def __init__(self, config): 111 | super().__init__() 112 | assert config.n_embd % config.n_head == 0 113 | # key, query, value projections for all heads, but in a batch 114 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.biases) 115 | # output projection 116 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.biases) 117 | # causal mask to ensure that attention is only applied to the left in the input sequence 118 | self.register_buffer( 119 | "bias", 120 | torch.tril(torch.ones(config.n_seq, config.n_seq)).view( 121 | 1, 1, config.n_seq, config.n_seq 122 | ), 123 | ) 124 | self.n_head = config.n_head 125 | self.n_embd = config.n_embd 126 | 127 | def forward(self, x, rotary_embedding: RotaryEmbedding): 128 | B, T, C = ( 129 | x.size() 130 | ) # batch size, sequence length, embedding dimensionality (n_embd) 131 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 132 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 133 | k = k.view(B, T, self.n_head, C // self.n_head).transpose( 134 | 1, 2 135 | ) # (B, nh, T, hs) 136 | q = q.view(B, T, self.n_head, C // self.n_head).transpose( 137 | 1, 2 138 | ) # (B, nh, T, hs) 139 | v = v.view(B, T, self.n_head, C // self.n_head).transpose( 140 | 1, 2 141 | ) # (B, nh, T, hs) 142 | q, k = rotary_embedding.rotary_embed(q, k) 143 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 144 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 145 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 146 | att = F.softmax(att, dim=-1) 147 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 148 | y = ( 149 | y.transpose(1, 2).contiguous().view(B, T, C) 150 | ) # re-assemble all head outputs side by side 151 | # output projection 152 | y = self.c_proj(y) 153 | return y 154 | 155 | 156 | class RotaryBlock(nn.Module): 157 | """A 
causal, rotary Self-Attention Block.""" 158 | 159 | def __init__(self, config): 160 | super().__init__() 161 | self.ln_1 = nn.LayerNorm(config.n_embd) 162 | self.attn = RotarySelfAttention(config) 163 | self.ln_2 = nn.LayerNorm(config.n_embd) 164 | self.mlpf = nn.Sequential( 165 | nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.biases), 166 | NewGELU(), 167 | nn.Linear(4 * config.n_embd, config.n_embd, bias=config.biases), 168 | ) 169 | 170 | def forward(self, x, rotary_embedding: RotaryEmbedding): 171 | x = x + self.attn(self.ln_1(x), rotary_embedding) 172 | x = x + self.mlpf(self.ln_2(x)) 173 | return x 174 | 175 | 176 | class CausalSelfAttention(nn.Module): 177 | """ 178 | A vanilla multi-head masked self-attention layer with a projection at the end. 179 | It is possible to use torch.nn.MultiheadAttention here but I am including an 180 | explicit implementation here to show that there is nothing too scary here. 181 | """ 182 | 183 | def __init__(self, config): 184 | super().__init__() 185 | assert config.n_embd % config.n_head == 0 186 | # key, query, value projections for all heads, but in a batch 187 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 188 | # output projection 189 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 190 | # causal mask to ensure that attention is only applied to the left in the input sequence 191 | self.register_buffer( 192 | "bias", 193 | torch.tril(torch.ones(config.n_seq, config.n_seq)).view( 194 | 1, 1, config.n_seq, config.n_seq 195 | ), 196 | ) 197 | self.n_head = config.n_head 198 | self.n_embd = config.n_embd 199 | 200 | def forward(self, x): 201 | B, T, C = ( 202 | x.size() 203 | ) # batch size, sequence length, embedding dimensionality (n_embd) 204 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 205 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 206 | k = k.view(B, T, self.n_head, C // self.n_head).transpose( 207 | 1, 2 208 | ) # (B, nh, T, 
hs) 209 | q = q.view(B, T, self.n_head, C // self.n_head).transpose( 210 | 1, 2 211 | ) # (B, nh, T, hs) 212 | v = v.view(B, T, self.n_head, C // self.n_head).transpose( 213 | 1, 2 214 | ) # (B, nh, T, hs) 215 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 216 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 217 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 218 | att = F.softmax(att, dim=-1) 219 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 220 | y = ( 221 | y.transpose(1, 2).contiguous().view(B, T, C) 222 | ) # re-assemble all head outputs side by side 223 | # output projection 224 | y = self.c_proj(y) 225 | return y 226 | 227 | 228 | class Block(nn.Module): 229 | """A causal Self-Attention Block.""" 230 | 231 | def __init__(self, config): 232 | super().__init__() 233 | self.ln_1 = nn.LayerNorm(config.n_embd) 234 | self.attn = CausalSelfAttention(config) 235 | self.ln_2 = nn.LayerNorm(config.n_embd) 236 | self.mlpf = nn.Sequential( 237 | nn.Linear(config.n_embd, 4 * config.n_embd), 238 | NewGELU(), 239 | nn.Linear(4 * config.n_embd, config.n_embd), 240 | ) 241 | 242 | def forward(self, x): 243 | x = x + self.attn(self.ln_1(x)) 244 | x = x + self.mlpf(self.ln_2(x)) 245 | return x 246 | 247 | 248 | class NonCausalSelfAttention(nn.Module): 249 | """ 250 | A vanilla multi-head masked self-attention layer with a projection at the end. 251 | It is possible to use torch.nn.MultiheadAttention here but I am including an 252 | explicit implementation here to show that there is nothing too scary here. 
253 | """ 254 | 255 | def __init__(self, config): 256 | super().__init__() 257 | assert config.n_embd % config.n_head == 0 258 | # key, query, value projections for all heads, but in a batch 259 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 260 | # output projection 261 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 262 | # causal mask to ensure that attention is only applied to the left in the input sequence 263 | # self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 264 | # .view(1, 1, config.block_size, config.block_size)) 265 | self.n_head = config.n_head 266 | self.n_embd = config.n_embd 267 | 268 | def forward(self, x): 269 | B, T, C = ( 270 | x.size() 271 | ) # batch size, sequence length, embedding dimensionality (n_embd) 272 | 273 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 274 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 275 | k = k.view(B, T, self.n_head, C // self.n_head).transpose( 276 | 1, 2 277 | ) # (B, nh, T, hs) 278 | q = q.view(B, T, self.n_head, C // self.n_head).transpose( 279 | 1, 2 280 | ) # (B, nh, T, hs) 281 | v = v.view(B, T, self.n_head, C // self.n_head).transpose( 282 | 1, 2 283 | ) # (B, nh, T, hs) 284 | 285 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 286 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 287 | # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) # This is why it's non-causal. 
288 | att = F.softmax(att, dim=-1) 289 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 290 | y = ( 291 | y.transpose(1, 2).contiguous().view(B, T, C) 292 | ) # re-assemble all head outputs side by side 293 | 294 | # output projection 295 | y = self.c_proj(y) 296 | return y 297 | 298 | 299 | class NonCausalBlock(nn.Module): 300 | """A _n-causal_ Self-Attention Block.""" 301 | 302 | def __init__(self, config): 303 | super().__init__() 304 | self.ln_1 = nn.LayerNorm(config.n_embd) 305 | self.attn = NonCausalSelfAttention(config) 306 | self.ln_2 = nn.LayerNorm(config.n_embd) 307 | self.mlpf = nn.Sequential( 308 | nn.Linear(config.n_embd, 4 * config.n_embd), 309 | NewGELU(), 310 | nn.Linear(4 * config.n_embd, config.n_embd), 311 | ) 312 | 313 | def forward(self, x): 314 | x = x + self.attn(self.ln_1(x)) 315 | x = x + self.mlpf(self.ln_2(x)) 316 | return x 317 | -------------------------------------------------------------------------------- /coatiLDM/models/diffusion_models/ddim_sample_routines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from coatiLDM.models.diffusion_models.particle_guidance import ( 3 | similarity_guidance_gradient, 4 | ) 5 | from coatiLDM.data.transforms import embed_scalar 6 | import numpy as np 7 | 8 | 9 | def ddim_basic_nearby( 10 | score_net, 11 | x_start, 12 | cond=None, 13 | pg_weight=0.0, 14 | embed_dim=None, 15 | eta=1.0, 16 | T_start=200, 17 | skip=1, 18 | ): 19 | 20 | assert T_start % skip == 0 21 | with torch.no_grad(): 22 | batch_size = x_start.shape[0] 23 | 24 | device = next(score_net.parameters()).device 25 | if not cond is None: 26 | if cond is None and score_net.cond_dim > 0: 27 | raise Exception("Give me a condition.") 28 | else: 29 | assert cond.shape[0] == batch_size 30 | else: 31 | # Conditions are normally distributed random variables. 
32 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 33 | 34 | if embed_dim: 35 | cond = embed_scalar(cond, embedding_dim=embed_dim) 36 | 37 | T_init = torch.ones(batch_size, dtype=torch.long, device=device) * T_start 38 | noise = torch.randn((batch_size, score_net.x_dim), device=device) 39 | 40 | x_t = ( 41 | score_net.scheduler.bar_alpha(T_init).sqrt() * x_start 42 | + (1.0 - score_net.scheduler.bar_alpha(T_init)).sqrt() * noise 43 | ) 44 | 45 | # DDIM setup force this 46 | 47 | seq = range(0, T_start, skip) 48 | seq_next = [-1] + list(seq[:-1]) 49 | 50 | for T_, T_NEXT_ in zip(reversed(seq), reversed(seq_next)): 51 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 52 | T_NEXT = torch.ones(batch_size, dtype=torch.long, device=device) * T_NEXT_ 53 | barat = score_net.scheduler.bar_alpha(T) 54 | if T_NEXT_ == -1: 55 | barat_next = torch.ones((batch_size, 1), device=device) 56 | else: 57 | barat_next = score_net.scheduler.bar_alpha(T_NEXT) 58 | extracted_noise = score_net(x_t, t=T.float(), cond=cond) 59 | if pg_weight > 0: 60 | extracted_noise = extracted_noise + ( 61 | similarity_guidance_gradient(x_t) * pg_weight 62 | ) 63 | x_t = (x_t - extracted_noise * (1 - barat).sqrt()) / barat.sqrt() 64 | c1 = ( 65 | eta * ((1 - barat / barat_next) * (1 - barat_next) / (1 - barat)).sqrt() 66 | ) 67 | c2 = ((1 - barat_next) - c1**2).sqrt() 68 | x_t = ( 69 | barat_next.sqrt() * x_t 70 | + c1 * torch.randn_like(x_t) 71 | + c2 * extracted_noise 72 | ) 73 | 74 | return x_t.detach() 75 | 76 | 77 | def ddim_basic_sample( 78 | score_net, 79 | cond=None, 80 | batch_size=4, 81 | pg_weight=0.0, 82 | embed_dim=None, 83 | eta=1.0, 84 | ddim_steps=1000, 85 | ): 86 | with torch.no_grad(): 87 | device = next(score_net.parameters()).device 88 | if not cond is None: 89 | if cond is None and score_net.cond_dim > 0: 90 | raise Exception("Give me a condition.") 91 | else: 92 | assert cond.shape[0] == batch_size 93 | else: 94 | # Conditions are normally 
distributed random variables. 95 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 96 | 97 | if embed_dim: 98 | cond = embed_scalar(cond, embedding_dim=embed_dim) 99 | # torch.manual_seed(0) 100 | x_t = torch.randn((batch_size, score_net.x_dim), device=device) 101 | 102 | # DDIM setup 103 | skip = score_net.scheduler.timesteps // ddim_steps 104 | assert score_net.scheduler.timesteps % skip == 0 105 | seq = range(0, score_net.scheduler.timesteps, skip) 106 | seq_next = [-1] + list(seq[:-1]) 107 | 108 | for T_, T_NEXT_ in zip(reversed(seq), reversed(seq_next)): 109 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 110 | T_NEXT = torch.ones(batch_size, dtype=torch.long, device=device) * T_NEXT_ 111 | barat = score_net.scheduler.bar_alpha(T) 112 | if T_NEXT_ == -1: 113 | barat_next = torch.ones((batch_size, 1), device=device) 114 | else: 115 | barat_next = score_net.scheduler.bar_alpha(T_NEXT) 116 | extracted_noise = score_net(x_t, t=T.float(), cond=cond) 117 | if pg_weight > 0: 118 | extracted_noise = extracted_noise + ( 119 | similarity_guidance_gradient(x_t) * pg_weight 120 | ) 121 | x_t = (x_t - extracted_noise * (1 - barat).sqrt()) / barat.sqrt() 122 | c1 = ( 123 | eta * ((1 - barat / barat_next) * (1 - barat_next) / (1 - barat)).sqrt() 124 | ) 125 | c2 = ((1 - barat_next) - c1**2).sqrt() 126 | x_t = ( 127 | barat_next.sqrt() * x_t 128 | + c1 * torch.randn_like(x_t) 129 | + c2 * extracted_noise 130 | ) 131 | 132 | return x_t.detach() 133 | 134 | 135 | def ddim_cfg_nearby( 136 | uncond_score_net, 137 | cond_score_net, 138 | x_start, 139 | cond, 140 | pg_weight=0.0, 141 | cfg_weight=0.2, 142 | eta=1.0, 143 | T_start=200, 144 | skip=1, 145 | ): 146 | 147 | batch_size = x_start.size(0) 148 | try: 149 | assert uncond_score_net.cond_dim == 0 150 | except: 151 | ValueError("Unconditional score net is first argument, conditional second") 152 | 153 | try: 154 | assert uncond_score_net.scheduler.is_same(cond_score_net.scheduler) 155 | 
except: 156 | ValueError("Score nets must have the same noise schedule for sampling") 157 | 158 | scheduler = uncond_score_net.scheduler 159 | 160 | assert T_start % skip == 0 161 | 162 | device = next(uncond_score_net.parameters()).device 163 | assert device == next(cond_score_net.parameters()).device 164 | 165 | # DDIM setup force this 166 | 167 | seq = range(0, T_start, skip) 168 | seq_next = [-1] + list(seq[:-1]) 169 | 170 | with torch.no_grad(): 171 | T_init = torch.ones(batch_size, dtype=torch.long, device=device) * T_start 172 | noise = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 173 | 174 | x_t = ( 175 | scheduler.bar_alpha(T_init).sqrt() * x_start 176 | + (1.0 - scheduler.bar_alpha(T_init)).sqrt() * noise 177 | ) 178 | 179 | for T_, T_NEXT_ in zip(reversed(seq), reversed(seq_next)): 180 | 181 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 182 | T_NEXT = torch.ones(batch_size, dtype=torch.long, device=device) * T_NEXT_ 183 | barat = scheduler.bar_alpha(T) 184 | if T_NEXT_ == -1: 185 | barat_next = torch.ones((batch_size, 1), device=device) 186 | else: 187 | barat_next = scheduler.bar_alpha(T_NEXT) 188 | extracted_noise_cond = cond_score_net(x_t, t=T.float(), cond=cond) 189 | extracted_noise_uncond = uncond_score_net(x_t, t=T.float(), cond=None) 190 | extracted_noise = ((1 + cfg_weight) * extracted_noise_cond) - ( 191 | cfg_weight * extracted_noise_uncond 192 | ) 193 | if pg_weight > 0: 194 | extracted_noise = extracted_noise + ( 195 | similarity_guidance_gradient(x_t) * pg_weight 196 | ) 197 | x_t = (x_t - extracted_noise * (1 - barat).sqrt()) / barat.sqrt() 198 | c1 = ( 199 | eta * ((1 - barat / barat_next) * (1 - barat_next) / (1 - barat)).sqrt() 200 | ) 201 | c2 = ((1 - barat_next) - c1**2).sqrt() 202 | x_t = ( 203 | barat_next.sqrt() * x_t 204 | + c1 * torch.randn_like(x_t) 205 | + c2 * extracted_noise 206 | ) 207 | return x_t.detach() 208 | 209 | 210 | def ddim_sample_classifier_free_guidance( 211 | 
uncond_score_net, 212 | cond_score_net, 213 | cond, 214 | batch_size=4, 215 | pg_weight=0.0, 216 | cfg_weight=0.2, 217 | eta=1.0, 218 | ddim_steps=1000, 219 | ): 220 | try: 221 | assert uncond_score_net.cond_dim == 0 222 | except: 223 | ValueError("Unconditional score net is first argument, conditional second") 224 | 225 | try: 226 | assert uncond_score_net.scheduler.is_same(cond_score_net.scheduler) 227 | except: 228 | ValueError("Score nets must have the same noise schedule for sampling") 229 | 230 | scheduler = uncond_score_net.scheduler 231 | 232 | with torch.no_grad(): 233 | device = next(uncond_score_net.parameters()).device 234 | assert device == next(cond_score_net.parameters()).device 235 | 236 | x_t = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 237 | 238 | # DDIM setup 239 | skip = scheduler.timesteps // ddim_steps 240 | assert scheduler.timesteps % skip == 0 241 | seq = range(0, scheduler.timesteps, skip) 242 | seq_next = [-1] + list(seq[:-1]) 243 | for T_, T_NEXT_ in zip(reversed(seq), reversed(seq_next)): 244 | 245 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 246 | T_NEXT = torch.ones(batch_size, dtype=torch.long, device=device) * T_NEXT_ 247 | barat = scheduler.bar_alpha(T) 248 | if T_NEXT_ == -1: 249 | barat_next = torch.ones((batch_size, 1), device=device) 250 | else: 251 | barat_next = scheduler.bar_alpha(T_NEXT) 252 | extracted_noise_cond = cond_score_net(x_t, t=T.float(), cond=cond) 253 | extracted_noise_uncond = uncond_score_net(x_t, t=T.float(), cond=None) 254 | extracted_noise = ((1 + cfg_weight) * extracted_noise_cond) - ( 255 | cfg_weight * extracted_noise_uncond 256 | ) 257 | if pg_weight > 0: 258 | extracted_noise = extracted_noise + ( 259 | similarity_guidance_gradient(x_t) * pg_weight 260 | ) 261 | x_t = (x_t - extracted_noise * (1 - barat).sqrt()) / barat.sqrt() 262 | c1 = ( 263 | eta * ((1 - barat / barat_next) * (1 - barat_next) / (1 - barat)).sqrt() 264 | ) 265 | c2 = ((1 - barat_next) - 
c1**2).sqrt() 266 | x_t = ( 267 | barat_next.sqrt() * x_t 268 | + c1 * torch.randn_like(x_t) 269 | + c2 * extracted_noise 270 | ) 271 | return x_t.detach() 272 | 273 | 274 | def ddim_sample_classifier_guidance( 275 | score_net, 276 | batch_size, 277 | cg_weight, 278 | cg_due, 279 | cg_targets, 280 | cond=None, 281 | pg_weight=0.0, 282 | eta=1.0, 283 | ddim_steps=1000, 284 | ): 285 | try: 286 | score_net.scheduler.is_same(cg_due.scheduler) 287 | except: 288 | raise ValueError( 289 | f"classifier must share noise schedule with score net. different betas detected" 290 | ) 291 | 292 | device = next(score_net.parameters()).device 293 | if not cond is None: 294 | if cond is None and score_net.cond_dim > 0: 295 | raise Exception("Give me a condition.") 296 | else: 297 | assert cond.shape[0] == batch_size 298 | else: 299 | # Conditions are normally distributed random variables. 300 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 301 | 302 | x_t = torch.randn((batch_size, score_net.x_dim), device=device, requires_grad=True) 303 | 304 | # DDIM setup 305 | skip = score_net.scheduler.timesteps // ddim_steps 306 | assert score_net.scheduler.timesteps % skip == 0 307 | seq = range(0, score_net.scheduler.timesteps, skip) 308 | seq_next = [-1] + list(seq[:-1]) 309 | for T_, T_NEXT_ in zip(reversed(seq), reversed(seq_next)): 310 | 311 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 312 | T_NEXT = torch.ones(batch_size, dtype=torch.long, device=device) * T_NEXT_ 313 | barat = score_net.scheduler.bar_alpha(T) 314 | if T_NEXT_ == -1: 315 | barat_next = torch.ones((batch_size, 1), device=device, requires_grad=True) 316 | else: 317 | barat_next = score_net.scheduler.bar_alpha(T_NEXT) 318 | with torch.no_grad(): 319 | extracted_noise = score_net(x_t, t=T.float(), cond=None) 320 | cg_model_out = cg_due(x_t, T) 321 | if cg_weight > 0: 322 | pred_loss = torch.pow(cg_model_out.mean - cg_targets, 2.0).sum() 323 | G = torch.autograd.grad(pred_loss, 
x_t)[0].detach() 324 | extracted_noise = extracted_noise + (cg_weight * G) 325 | if pg_weight > 0: 326 | with torch.no_grad(): 327 | extracted_noise = extracted_noise + ( 328 | similarity_guidance_gradient(x_t) * pg_weight 329 | ) 330 | 331 | x_t = (x_t - extracted_noise * (1 - barat).sqrt()) / barat.sqrt() 332 | c1 = eta * ((1 - barat / barat_next) * (1 - barat_next) / (1 - barat)).sqrt() 333 | c2 = ((1 - barat_next) - c1**2).sqrt() 334 | x_t = ( 335 | barat_next.sqrt() * x_t + c1 * torch.randn_like(x_t) + c2 * extracted_noise 336 | ) 337 | return x_t.detach() 338 | -------------------------------------------------------------------------------- /coatiLDM/common/fd.py: -------------------------------------------------------------------------------- 1 | from scipy import linalg 2 | from coatiLDM.models.diffusion_models import ddpm_sample_routines 3 | from coatiLDM.models.diffusion_models.dflow import dflow, dflow_multi 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 9 | """Numpy implementation of the Frechet Distance. 10 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 11 | and X_2 ~ N(mu_2, C_2) is 12 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 13 | 14 | Stable version by Dougal J. Sutherland. 15 | 16 | Params: 17 | -- mu1 : Numpy array containing the activations of a layer of the 18 | inception net (like returned by the function 'get_predictions') 19 | for generated samples. 20 | -- mu2 : The sample mean over activations, precalculated on an 21 | representative data set. 22 | -- sigma1: The covariance matrix over activations for generated samples. 23 | -- sigma2: The covariance matrix over activations, precalculated on an 24 | representative data set. 25 | 26 | Returns: 27 | -- : The Frechet Distance. 
28 | """ 29 | 30 | mu1 = np.atleast_1d(mu1) 31 | mu2 = np.atleast_1d(mu2) 32 | 33 | sigma1 = np.atleast_2d(sigma1) 34 | sigma2 = np.atleast_2d(sigma2) 35 | 36 | assert ( 37 | mu1.shape == mu2.shape 38 | ), "Training and test mean vectors have different lengths" 39 | assert ( 40 | sigma1.shape == sigma2.shape 41 | ), "Training and test covariances have different dimensions" 42 | 43 | diff = mu1 - mu2 44 | 45 | # Product might be almost singular 46 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 47 | if not np.isfinite(covmean).all() or not np.allclose( 48 | np.diagonal(covmean).imag, 0, atol=1e-3 49 | ): 50 | msg = ( 51 | "fd calculation produces singular product; " 52 | "adding %s to diagonal of cov estimates" 53 | ) % eps 54 | print(msg) 55 | offset = np.eye(sigma1.shape[0]) * eps 56 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 57 | 58 | # Numerical error might give slight imaginary component 59 | if np.iscomplexobj(covmean): 60 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 61 | m = np.max(np.abs(covmean.imag)) 62 | raise ValueError("Imaginary component {}".format(m)) 63 | covmean = covmean.real 64 | 65 | tr_covmean = np.trace(covmean) 66 | 67 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 68 | 69 | 70 | def calc_fd(s1: np.ndarray, s2: np.ndarray): 71 | """ 72 | Automatically checks the convergence of the FID 73 | WRT samples for you... by calculating it over 4 samples. 74 | 75 | Args: 76 | s1: np samples in rows. 
77 | s2: same 78 | """ 79 | print(f"S1: {s1.shape}") 80 | print(f"S2: {s2.shape}") 81 | sample_shape = np.min([s1.shape[0], s2.shape[0]]) 82 | n_samples = sample_shape // 4 83 | mu1 = s1[:n_samples].mean(0) 84 | mu2 = s2[:n_samples].mean(0) 85 | sigma1 = np.cov(s1[:n_samples], rowvar=False) 86 | sigma2 = np.cov(s2[:n_samples], rowvar=False) 87 | FID1 = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) 88 | n_samples = sample_shape // 2 89 | mu1 = s1[:n_samples].mean(0) 90 | mu2 = s2[:n_samples].mean(0) 91 | sigma1 = np.cov(s1[:n_samples], rowvar=False) 92 | sigma2 = np.cov(s2[:n_samples], rowvar=False) 93 | FID2 = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) 94 | n_samples = 3 * sample_shape // 4 95 | mu1 = s1[:n_samples].mean(0) 96 | mu2 = s2[:n_samples].mean(0) 97 | sigma1 = np.cov(s1[:n_samples], rowvar=False) 98 | sigma2 = np.cov(s2[:n_samples], rowvar=False) 99 | FID3 = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) 100 | n_samples = sample_shape 101 | mu1 = s1[:n_samples].mean(0) 102 | mu2 = s2[:n_samples].mean(0) 103 | sigma1 = np.cov(s1[:n_samples], rowvar=False) 104 | sigma2 = np.cov(s2[:n_samples], rowvar=False) 105 | FID4 = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) 106 | print( 107 | f"FID at 25\%{FID1:.3f} 50\%{FID2:.3f} 75\%{FID3:.3f} 100\%{FID4:.3f} of {n_samples} samples" 108 | ) 109 | return FID4 110 | 111 | 112 | def run_fd( 113 | target_embs, score_model, n_samples=20_000, samples_per_batch=256, pg_weight=0.0 114 | ): 115 | print( 116 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 117 | ) 118 | samples = torch.zeros(n_samples, score_model.x_dim) 119 | idx = 0 120 | while idx < n_samples: 121 | batch_size = min([samples_per_batch, n_samples - idx]) 122 | sample_batch = ( 123 | ddpm_sample_routines.ddpm_basic_sample( 124 | score_model, cond=None, batch_size=batch_size, pg_weight=pg_weight 125 | ) 126 | .detach() 127 | .cpu() 128 | ) 129 | samples[idx : idx + batch_size] = 
sample_batch 130 | idx = idx + samples_per_batch 131 | fd = calc_fd(target_embs, samples.numpy()) 132 | return fd 133 | 134 | 135 | def run_fd_cond( 136 | target_embs, 137 | score_model, 138 | cond_set, 139 | n_samples=20_000, 140 | samples_per_batch=256, 141 | pg_weight=0.0, 142 | ): 143 | print( 144 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 145 | ) 146 | samples = torch.zeros(n_samples, score_model.x_dim) 147 | idx = 0 148 | while idx < n_samples: 149 | batch_size = min([samples_per_batch, n_samples - idx]) 150 | sample_batch = ( 151 | ddpm_sample_routines.ddpm_basic_sample( 152 | score_model, 153 | cond_set[idx : idx + batch_size], 154 | batch_size=batch_size, 155 | pg_weight=pg_weight, 156 | ) 157 | .detach() 158 | .cpu() 159 | ) 160 | samples[idx : idx + batch_size] = sample_batch 161 | idx = idx + samples_per_batch 162 | fd = calc_fd(target_embs, samples.numpy()) 163 | return fd 164 | 165 | 166 | def run_fd_cg( 167 | target_embs, 168 | score_model, 169 | cond_set, 170 | cg_regressor, 171 | cg_weight, 172 | n_samples=20_000, 173 | samples_per_batch=256, 174 | pg_weight=0.0, 175 | ): 176 | print( 177 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 178 | ) 179 | samples = torch.zeros(n_samples, score_model.x_dim) 180 | idx = 0 181 | while idx < n_samples: 182 | batch_size = min([samples_per_batch, n_samples - idx]) 183 | sample_batch = ( 184 | ddpm_sample_routines.ddpm_sample_classifier_guidance( 185 | score_net=score_model, 186 | batch_size=batch_size, 187 | cg_weight=cg_weight, 188 | cg_due=cg_regressor, 189 | cg_targets=cond_set[idx : idx + batch_size], 190 | pg_weight=pg_weight, 191 | ) 192 | .detach() 193 | .cpu() 194 | ) 195 | samples[idx : idx + batch_size] = sample_batch 196 | idx = idx + samples_per_batch 197 | fd = calc_fd(target_embs, samples.numpy()) 198 | return fd 199 | 200 | 201 | def run_fd_flow( 202 | target_embs, flow_model, cond_set, n_samples=20_000, 
samples_per_batch=256 203 | ): 204 | print( 205 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 206 | ) 207 | samples = torch.zeros(n_samples, flow_model.score_net.x_dim) 208 | idx = 0 209 | while idx < n_samples: 210 | batch_size = min([samples_per_batch, n_samples - idx]) 211 | with torch.no_grad(): 212 | x_0 = torch.randn( 213 | batch_size, 512, device=next(flow_model.score_net.parameters()).device 214 | ) 215 | if cond_set is None: 216 | sample_batch = flow_model.decode(x_0, cs=None).detach().cpu() 217 | else: 218 | sample_batch = ( 219 | flow_model.decode(x_0, cs=cond_set[idx : idx + batch_size]) 220 | .detach() 221 | .cpu() 222 | ) 223 | samples[idx : idx + batch_size] = sample_batch 224 | idx = idx + samples_per_batch 225 | fd = calc_fd(target_embs, samples.numpy()) 226 | return fd 227 | 228 | 229 | from coatiLDM.models.score_models.flow_wrapper import ODEWrapper 230 | 231 | 232 | def run_fd_dflow( 233 | target_embs, 234 | flow_net, 235 | cond_set, 236 | cond_regressor, 237 | ode_steps=200, 238 | opt_steps=2, 239 | n_samples=20_000, 240 | samples_per_batch=1000, 241 | ): 242 | print( 243 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 244 | ) 245 | assert ( 246 | next(flow_net.parameters()).device == next(cond_regressor.parameters()).device 247 | ) 248 | assert isinstance(flow_net, ODEWrapper) 249 | samples = torch.zeros(n_samples, flow_net.score_net.x_dim) 250 | idx = 0 251 | while idx < n_samples: 252 | batch_size = min([samples_per_batch, n_samples - idx]) 253 | 254 | x_0 = torch.randn(batch_size, 512, device=next(flow_net.parameters()).device) 255 | sample_batch = dflow( 256 | x_0, 257 | cond_set[idx : idx + batch_size], 258 | flow_net, 259 | cond_regressor, 260 | learning_rate=1.0, 261 | decode_steps=ode_steps, 262 | opt_steps=opt_steps, 263 | device=next(flow_net.parameters()).device, 264 | ).cpu()[-1] 265 | 266 | samples[idx : idx + batch_size] = sample_batch 267 | 
idx = idx + samples_per_batch 268 | fd = calc_fd(target_embs, samples.numpy()) 269 | return fd 270 | 271 | 272 | def run_fd_dflow_multi( 273 | target_embs, 274 | flow_net, 275 | cond_sets, 276 | cond_regressors, 277 | ode_steps=200, 278 | opt_steps=2, 279 | n_samples=20_000, 280 | samples_per_batch=1000, 281 | ): 282 | print( 283 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 284 | ) 285 | assert ( 286 | next(flow_net.parameters()).device 287 | == next(cond_regressors[0].parameters()).device 288 | ) 289 | assert isinstance(flow_net, ODEWrapper) 290 | device = next(flow_net.parameters()).device 291 | samples = torch.zeros(n_samples, flow_net.score_net.x_dim) 292 | idx = 0 293 | while idx < n_samples: 294 | batch_size = min([samples_per_batch, n_samples - idx]) 295 | 296 | x_0 = torch.randn(batch_size, 512, device=device) 297 | sample_batch = dflow_multi( 298 | x_0, 299 | [cond_set[idx : idx + batch_size] for cond_set in cond_sets], 300 | flow_net, 301 | cond_regressors, 302 | learning_rate=1.0, 303 | decode_steps=ode_steps, 304 | opt_steps=opt_steps, 305 | device=next(flow_net.parameters()).device, 306 | ).cpu()[-1] 307 | 308 | samples[idx : idx + batch_size] = sample_batch 309 | idx = idx + samples_per_batch 310 | fd = calc_fd(target_embs, samples.numpy()) 311 | return fd 312 | 313 | 314 | def run_fd_cg_multi( 315 | target_embs, 316 | score_model, 317 | cond_sets, 318 | cg_regressors, 319 | cg_weights, 320 | n_samples=20_000, 321 | samples_per_batch=256, 322 | pg_weight=0.0, 323 | ): 324 | print( 325 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 326 | ) 327 | samples = torch.zeros(n_samples, score_model.x_dim) 328 | idx = 0 329 | while idx < n_samples: 330 | batch_size = min([samples_per_batch, n_samples - idx]) 331 | cond_set_subsets = [x[idx : idx + batch_size] for x in cond_sets] 332 | sample_batch = ( 333 | ddpm_sample_routines.ddpm_sample_multi_classifier_guidance( 334 | 
score_net=score_model, 335 | batch_size=batch_size, 336 | cg_weights=cg_weights, 337 | cg_dues=cg_regressors, 338 | cg_targets=cond_set_subsets, 339 | pg_weight=pg_weight, 340 | ) 341 | .detach() 342 | .cpu() 343 | ) 344 | samples[idx : idx + batch_size] = sample_batch 345 | idx = idx + samples_per_batch 346 | torch.cuda.empty_cache() 347 | fd = calc_fd(target_embs, samples.numpy()) 348 | return fd 349 | 350 | 351 | def run_fd_cfg( 352 | target_embs, 353 | uncond_score_model, 354 | cond_score_model, 355 | cond_set, 356 | cfg_weight, 357 | n_samples=20_000, 358 | samples_per_batch=256, 359 | pg_weight=0.0, 360 | ): 361 | print( 362 | f"Running FID with {n_samples} samples against {target_embs.shape[0]} target embeddings" 363 | ) 364 | samples = torch.zeros(n_samples, uncond_score_model.x_dim) 365 | idx = 0 366 | while idx < n_samples: 367 | batch_size = min([samples_per_batch, n_samples - idx]) 368 | sample_batch = ( 369 | ddpm_sample_routines.ddpm_sample_classifier_free_guidance( 370 | uncond_score_net=uncond_score_model, 371 | cond_score_net=cond_score_model, 372 | cond=cond_set[idx : idx + batch_size], 373 | batch_size=batch_size, 374 | pg_weight=pg_weight, 375 | cfg_weight=cfg_weight, 376 | ) 377 | .detach() 378 | .cpu() 379 | ) 380 | samples[idx : idx + batch_size] = sample_batch 381 | idx = idx + samples_per_batch 382 | fd = calc_fd(target_embs, samples.numpy()) 383 | return fd 384 | -------------------------------------------------------------------------------- /coatiLDM/models/diffusion_models/ddpm_sample_routines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from coatiLDM.models.diffusion_models.particle_guidance import ( 4 | similarity_guidance_gradient, 5 | cosine_guidance_gradient, 6 | cosine_guidance_updated, 7 | low_memory_cosine_guidance_gradient, 8 | ) 9 | from coatiLDM.data.transforms import embed_scalar 10 | from tqdm.auto import tqdm 11 | import numpy as np 12 | 
from torchdiffeq import odeint 13 | 14 | 15 | def ddpm_basic_sample( 16 | score_net, 17 | cond=None, 18 | batch_size=4, 19 | pg_weight=0.0, 20 | embed_dim=None, 21 | pg_gradient_type="euclidean", 22 | fixed_pg_ratio=False, 23 | low_mem=False, 24 | ): 25 | with torch.no_grad(): 26 | device = next(score_net.parameters()).device 27 | if not cond is None: 28 | if cond is None and score_net.cond_dim > 0: 29 | raise Exception("Give me a condition.") 30 | else: 31 | assert cond.shape[0] == batch_size 32 | else: 33 | # Conditions are normally distributed random variables. 34 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 35 | 36 | if embed_dim: 37 | cond = embed_scalar(cond, embedding_dim=embed_dim) 38 | 39 | guidance_ratio = 0.1 * pg_weight 40 | 41 | x_t = torch.randn((batch_size, score_net.x_dim), device=device) 42 | for T_ in reversed(range(score_net.scheduler.timesteps)): 43 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 44 | if T_ > 0: 45 | z = torch.randn((batch_size, score_net.x_dim), device=device) 46 | else: 47 | z = torch.zeros((batch_size, score_net.x_dim), device=device) 48 | extracted_noise = score_net(x_t, t=T.float(), cond=cond) 49 | if pg_weight != 0: 50 | if pg_gradient_type in ("cos", "cosine"): 51 | if low_mem: 52 | guidance = low_memory_cosine_guidance_gradient(x_t) 53 | else: 54 | guidance = cosine_guidance_updated(x_t) 55 | else: 56 | guidance = similarity_guidance_gradient(x_t) 57 | if fixed_pg_ratio: 58 | pg_multiplier = ( 59 | guidance_ratio * extracted_noise.abs().mean() 60 | ) / guidance.abs().mean() 61 | else: 62 | pg_multiplier = pg_weight 63 | extracted_noise = extracted_noise + (guidance * pg_multiplier) 64 | at = score_net.scheduler.alpha(T) 65 | barat = score_net.scheduler.bar_alpha(T) 66 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 67 | x_t = ( 68 | torch.pow(score_net.scheduler.alpha(T), -0.5) 69 | * (x_t - noise_factor * extracted_noise) 70 | + score_net.scheduler.beta(T).sqrt() * z 71 | 
) 72 | return x_t.detach() 73 | 74 | 75 | def ddpm_sample_classifier_free_guidance( 76 | uncond_score_net, 77 | cond_score_net, 78 | cond, 79 | batch_size=4, 80 | pg_weight=0.0, 81 | cfg_weight=0.2, 82 | pg_gradient_type="euclidean", 83 | fixed_pg_ratio="False", 84 | ): 85 | try: 86 | assert uncond_score_net.cond_dim == 0 87 | except: 88 | ValueError("Unconditional score net is first argument, conditional second") 89 | 90 | try: 91 | assert uncond_score_net.scheduler.is_same(cond_score_net.scheduler) 92 | except: 93 | ValueError("Score nets must have the same noise schedule for sampling") 94 | 95 | scheduler = uncond_score_net.scheduler 96 | 97 | guidance_ratio = 0.1 * pg_weight 98 | 99 | with torch.no_grad(): 100 | device = next(uncond_score_net.parameters()).device 101 | assert device == next(cond_score_net.parameters()).device 102 | 103 | x_t = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 104 | for T_ in reversed(range(scheduler.timesteps)): 105 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 106 | if T_ > 0: 107 | z = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 108 | else: 109 | z = torch.zeros((batch_size, uncond_score_net.x_dim), device=device) 110 | extracted_noise_cond = cond_score_net(x_t, t=T.float(), cond=cond) 111 | extracted_noise_uncond = uncond_score_net(x_t, t=T.float(), cond=None) 112 | # linear combination 113 | extracted_noise = ((1 + cfg_weight) * extracted_noise_cond) - ( 114 | cfg_weight * extracted_noise_uncond 115 | ) 116 | if pg_weight != 0 and T_ > 0: 117 | if pg_gradient_type in ("cos", "cosine"): 118 | guidance = cosine_guidance_updated(x_t) 119 | else: 120 | guidance = similarity_guidance_gradient(x_t) 121 | if fixed_pg_ratio: 122 | pg_multiplier = ( 123 | guidance_ratio * extracted_noise.abs().mean() 124 | ) / guidance.abs().mean() 125 | else: 126 | pg_multiplier = pg_weight 127 | extracted_noise = (guidance * pg_multiplier) + extracted_noise 128 | 129 | at = 
scheduler.alpha(T) 130 | barat = scheduler.bar_alpha(T) 131 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 132 | x_t = ( 133 | torch.pow(scheduler.alpha(T), -0.5) 134 | * (x_t - noise_factor * extracted_noise) 135 | + scheduler.beta(T).sqrt() * z 136 | ) 137 | return x_t.detach() 138 | 139 | 140 | def ddpm_sample_classifier_guidance( 141 | score_net, 142 | batch_size, 143 | cg_weight, 144 | cg_due, 145 | cg_targets, 146 | cond=None, 147 | pg_weight=0.0, 148 | pg_gradient_type="euclidean", 149 | fixed_pg_ratio=False, 150 | start_at=1000, 151 | low_mem=False, 152 | ): 153 | 154 | try: 155 | score_net.scheduler.is_same(cg_due.scheduler) 156 | except: 157 | raise ValueError( 158 | f"classifier must share noise schedule with score net. different betas detected" 159 | ) 160 | # cg_due.validate_self(self) 161 | 162 | device = next(score_net.parameters()).device 163 | if not cond is None: 164 | if cond is None and score_net.cond_dim > 0: 165 | raise Exception("Give me a condition.") 166 | else: 167 | assert cond.shape[0] == batch_size 168 | else: 169 | # Conditions are normally distributed random variables. 
170 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 171 | 172 | guidance_ratio = 0.1 * pg_weight 173 | 174 | x_t = torch.randn((batch_size, score_net.x_dim), device=device, requires_grad=True) 175 | for T_ in reversed(range(score_net.scheduler.timesteps)): 176 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 177 | if T_ > 0: 178 | z = torch.randn((batch_size, score_net.x_dim), device=device) 179 | else: 180 | z = torch.zeros((batch_size, score_net.x_dim), device=device) 181 | with torch.no_grad(): 182 | extracted_noise = score_net(x_t, t=T.float(), cond=cond) 183 | cg_model_out = cg_due(x_t, T) 184 | 185 | if T_ > 0: 186 | if cg_weight > 0 and T_ < start_at: 187 | pred_loss = torch.pow(cg_model_out.mean - cg_targets, 2.0).sum() 188 | G = torch.autograd.grad(pred_loss, x_t)[0].detach() 189 | extracted_noise = extracted_noise + (cg_weight * G) 190 | if pg_weight > 0 and T_ < start_at: 191 | if pg_gradient_type in ("cos", "cosine"): 192 | if low_mem: 193 | guidance = low_memory_cosine_guidance_gradient(x_t) 194 | else: 195 | guidance = cosine_guidance_updated(x_t) 196 | else: 197 | guidance = similarity_guidance_gradient(x_t) 198 | if fixed_pg_ratio: 199 | pg_multiplier = ( 200 | guidance_ratio * extracted_noise.abs().mean() 201 | ) / guidance.abs().mean() 202 | else: 203 | pg_multiplier = pg_weight 204 | extracted_noise = (guidance * pg_multiplier) + extracted_noise 205 | 206 | at = score_net.scheduler.alpha(T) 207 | barat = score_net.scheduler.bar_alpha(T) 208 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 209 | x_t = ( 210 | torch.pow(score_net.scheduler.alpha(T), -0.5) 211 | * (x_t - noise_factor * extracted_noise) 212 | + score_net.scheduler.beta(T).sqrt() * z 213 | ) 214 | return x_t.detach() 215 | 216 | 217 | def ddpm_sample_multi_classifier_guidance( 218 | score_net, 219 | batch_size, 220 | cg_weights, 221 | cg_dues, 222 | cg_targets, 223 | cond=None, 224 | pg_weight=0.0, 225 | pg_gradient_type="euclidean", 226 | 
fixed_pg_ratio=False, 227 | start_at=1000, 228 | ): 229 | 230 | try: 231 | for cg_due in cg_dues: 232 | score_net.scheduler.is_same(cg_due.scheduler) 233 | except: 234 | raise ValueError( 235 | f"classifier must share noise schedule with score net. different betas detected" 236 | ) 237 | # cg_due.validate_self(self) 238 | 239 | device = next(score_net.parameters()).device 240 | if not cond is None: 241 | if cond is None and score_net.cond_dim > 0: 242 | raise Exception("Give me a condition.") 243 | else: 244 | assert cond.shape[0] == batch_size 245 | else: 246 | # Conditions are normally distributed random variables. 247 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 248 | 249 | x_t = torch.randn((batch_size, score_net.x_dim), device=device, requires_grad=True) 250 | for T_ in reversed(range(score_net.scheduler.timesteps)): 251 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 252 | if T_ > 0: 253 | z = torch.randn((batch_size, score_net.x_dim), device=device) 254 | else: 255 | z = torch.zeros((batch_size, score_net.x_dim), device=device) 256 | with torch.no_grad(): 257 | extracted_noise = score_net(x_t, t=T.float(), cond=cond) 258 | preds = [cg_due(x_t, T) for cg_due in cg_dues] 259 | 260 | if T_ > 0: 261 | cg_term = 0.0 262 | for cg_weight, cg_model_out, cg_target in zip( 263 | cg_weights, preds, cg_targets 264 | ): 265 | if cg_weight > 0 and T_ < start_at: 266 | pred_loss = torch.pow(cg_model_out.mean - cg_target, 2.0).sum() 267 | G = torch.autograd.grad(pred_loss, x_t)[0].detach() 268 | cg_term += cg_weight * G 269 | 270 | extracted_noise = extracted_noise + cg_term 271 | if pg_weight > 0: 272 | if pg_gradient_type in ("cos", "cosine"): 273 | guidance = cosine_guidance_updated(x_t) 274 | else: 275 | guidance = similarity_guidance_gradient(x_t) 276 | guidance_ratio = 0.1 * pg_weight 277 | if fixed_pg_ratio: 278 | pg_multiplier = ( 279 | guidance_ratio * extracted_noise.abs().mean() 280 | ) / guidance.abs().mean() 281 | 
else: 282 | pg_multiplier = pg_weight 283 | guidance = similarity_guidance_gradient(x_t) 284 | extracted_noise = (guidance * pg_multiplier) + extracted_noise 285 | 286 | at = score_net.scheduler.alpha(T) 287 | barat = score_net.scheduler.bar_alpha(T) 288 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 289 | x_t = ( 290 | torch.pow(score_net.scheduler.alpha(T), -0.5) 291 | * (x_t - noise_factor * extracted_noise) 292 | + score_net.scheduler.beta(T).sqrt() * z 293 | ) 294 | return x_t.detach() 295 | 296 | 297 | def ddpm_basic_nearby( 298 | score_net, 299 | emb_batch, 300 | T_start, 301 | cond=None, 302 | pg_weight=0.0, 303 | embed_dim=None, 304 | pg_gradient_type="euclidean", 305 | fixed_pg_ratio=False, 306 | low_mem=False, 307 | ): 308 | 309 | batch_size = emb_batch.shape[0] 310 | with torch.no_grad(): 311 | device = next(score_net.parameters()).device 312 | if not cond is None: 313 | if cond is None and score_net.cond_dim > 0: 314 | raise Exception("Give me a condition.") 315 | else: 316 | assert cond.shape[0] == batch_size 317 | else: 318 | # Conditions are normally distributed random variables. 
319 | cond = torch.randn((batch_size, score_net.cond_dim), device=device) 320 | 321 | if embed_dim: 322 | cond = embed_scalar(cond, embedding_dim=embed_dim) 323 | T_init = torch.ones(batch_size, dtype=torch.long, device=device) * T_start 324 | noise = torch.randn((batch_size, score_net.x_dim), device=device) 325 | x_t = ( 326 | score_net.scheduler.bar_alpha(T_init).sqrt() * emb_batch 327 | + (1.0 - score_net.scheduler.bar_alpha(T_init)).sqrt() * noise 328 | ) 329 | for T_ in reversed(range(T_start)): 330 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 331 | if T_ > 0: 332 | z = torch.randn((batch_size, score_net.x_dim), device=device) 333 | else: 334 | z = torch.zeros((batch_size, score_net.x_dim), device=device) 335 | extracted_noise = score_net(x_t, t=T.float(), cond=cond) 336 | if pg_weight > 0 and T_ > 0: 337 | if pg_gradient_type in ("cos", "cosine"): 338 | if low_mem: 339 | guidance = low_memory_cosine_guidance_gradient(x_t) 340 | else: 341 | guidance = cosine_guidance_updated(x_t) 342 | else: 343 | guidance = similarity_guidance_gradient(x_t) 344 | guidance_ratio = 0.1 * pg_weight 345 | if fixed_pg_ratio: 346 | pg_multiplier = ( 347 | guidance_ratio * extracted_noise.abs().mean() 348 | ) / guidance.abs().mean() 349 | else: 350 | pg_multiplier = pg_weight 351 | extracted_noise = extracted_noise + (guidance * pg_multiplier) 352 | at = score_net.scheduler.alpha(T) 353 | barat = score_net.scheduler.bar_alpha(T) 354 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 355 | x_t = ( 356 | torch.pow(score_net.scheduler.alpha(T), -0.5) 357 | * (x_t - noise_factor * extracted_noise) 358 | + score_net.scheduler.beta(T).sqrt() * z 359 | ) 360 | return x_t.detach() 361 | 362 | 363 | def ddpm_cg_nearby( 364 | uncond_score_net, 365 | emb_batch, 366 | T_start, 367 | cg_due, 368 | targets, 369 | cg_weight=100.0, 370 | pg_weight=0.0, 371 | ): 372 | 373 | batch_size = emb_batch.shape[0] 374 | device = next(uncond_score_net.parameters()).device 375 | 376 | 
T_init = torch.ones(batch_size, dtype=torch.long, device=device) * T_start 377 | noise = torch.randn( 378 | (batch_size, uncond_score_net.x_dim), device=device, requires_grad=True 379 | ) 380 | x_t = ( 381 | uncond_score_net.scheduler.bar_alpha(T_init).sqrt() * emb_batch 382 | + (1.0 - uncond_score_net.scheduler.bar_alpha(T_init)).sqrt() * noise 383 | ) 384 | for T_ in reversed(range(T_start)): 385 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 386 | if T_ > 0: 387 | z = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 388 | else: 389 | z = torch.zeros((batch_size, uncond_score_net.x_dim), device=device) 390 | with torch.no_grad(): 391 | extracted_noise = uncond_score_net(x_t, t=T.float(), cond=None) 392 | cg_model_out = cg_due(x_t, T) 393 | if T_ > 0: 394 | if cg_weight > 0: 395 | pred_loss = torch.pow(cg_model_out.mean - targets, 2.0).sum() 396 | G = torch.autograd.grad(pred_loss, x_t)[0].detach() 397 | extracted_noise = extracted_noise + (cg_weight * G) 398 | if pg_weight > 0: 399 | with torch.no_grad(): 400 | guidance = similarity_guidance_gradient(x_t) 401 | extracted_noise = (guidance * pg_weight) + extracted_noise 402 | if pg_weight > 0 and T_ > 0: 403 | extracted_noise = extracted_noise + ( 404 | similarity_guidance_gradient(x_t) * pg_weight 405 | ) 406 | at = uncond_score_net.scheduler.alpha(T) 407 | barat = uncond_score_net.scheduler.bar_alpha(T) 408 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 409 | x_t = ( 410 | torch.pow(uncond_score_net.scheduler.alpha(T), -0.5) 411 | * (x_t - noise_factor * extracted_noise) 412 | + uncond_score_net.scheduler.beta(T).sqrt() * z 413 | ) 414 | return x_t.detach() 415 | 416 | 417 | def ddpm_multi_cg_nearby( 418 | uncond_score_net, emb_batch, T_start, cg_dues, cg_targets, cg_weights, pg_weight=0.0 419 | ): 420 | 421 | batch_size = emb_batch.shape[0] 422 | device = next(uncond_score_net.parameters()).device 423 | 424 | try: 425 | for cg_due in cg_dues: 426 | 
uncond_score_net.scheduler.is_same(cg_due.scheduler) 427 | except: 428 | raise ValueError( 429 | f"classifier must share noise schedule with score net. different betas detected" 430 | ) 431 | 432 | T_init = torch.ones(batch_size, dtype=torch.long, device=device) * T_start 433 | noise = torch.randn( 434 | (batch_size, uncond_score_net.x_dim), device=device, requires_grad=True 435 | ) 436 | x_t = ( 437 | uncond_score_net.scheduler.bar_alpha(T_init).sqrt() * emb_batch 438 | + (1.0 - uncond_score_net.scheduler.bar_alpha(T_init)).sqrt() * noise 439 | ) 440 | for T_ in reversed(range(T_start)): 441 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 442 | if T_ > 0: 443 | z = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 444 | else: 445 | z = torch.zeros((batch_size, uncond_score_net.x_dim), device=device) 446 | with torch.no_grad(): 447 | extracted_noise = uncond_score_net(x_t, t=T.float(), cond=None) 448 | preds = [cg_due(x_t, T) for cg_due in cg_dues] 449 | if T_ > 0: 450 | cg_term = 0.0 451 | for cg_weight, cg_model_out, cg_target in zip( 452 | cg_weights, preds, cg_targets 453 | ): 454 | if cg_weight > 0: 455 | pred_loss = torch.pow(cg_model_out.mean - cg_target, 2.0).sum() 456 | G = torch.autograd.grad(pred_loss, x_t)[0].detach() 457 | cg_term += cg_weight * G 458 | extracted_noise = extracted_noise + cg_term 459 | if pg_weight > 0: 460 | with torch.no_grad(): 461 | guidance = similarity_guidance_gradient(x_t) 462 | extracted_noise = (guidance * pg_weight) + extracted_noise 463 | 464 | at = uncond_score_net.scheduler.alpha(T) 465 | barat = uncond_score_net.scheduler.bar_alpha(T) 466 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 467 | x_t = ( 468 | torch.pow(uncond_score_net.scheduler.alpha(T), -0.5) 469 | * (x_t - noise_factor * extracted_noise) 470 | + uncond_score_net.scheduler.beta(T).sqrt() * z 471 | ) 472 | return x_t.detach() 473 | 474 | 475 | def ddpm_cfg_nearby( 476 | uncond_score_net, 477 | cond_score_net, 478 | 
emb_batch, 479 | T_start, 480 | cond, 481 | batch_size=4, 482 | pg_weight=0.0, 483 | cfg_weight=0.2, 484 | ): 485 | try: 486 | assert uncond_score_net.cond_dim == 0 487 | except: 488 | ValueError("Unconditional score net is first argument, conditional second") 489 | 490 | try: 491 | assert uncond_score_net.scheduler.is_same(cond_score_net.scheduler) 492 | except: 493 | ValueError("Score nets must have the same noise schedule for sampling") 494 | 495 | scheduler = uncond_score_net.scheduler 496 | 497 | with torch.no_grad(): 498 | device = next(uncond_score_net.parameters()).device 499 | assert device == next(cond_score_net.parameters()).device 500 | 501 | T_init = torch.ones(batch_size, dtype=torch.long, device=device) * T_start 502 | noise = torch.randn( 503 | (batch_size, uncond_score_net.x_dim), device=device, requires_grad=True 504 | ) 505 | x_t = ( 506 | uncond_score_net.scheduler.bar_alpha(T_init).sqrt() * emb_batch 507 | + (1.0 - uncond_score_net.scheduler.bar_alpha(T_init)).sqrt() * noise 508 | ) 509 | for T_ in reversed(range(T_start)): 510 | T = torch.ones(batch_size, dtype=torch.long, device=device) * T_ 511 | if T_ > 0: 512 | z = torch.randn((batch_size, uncond_score_net.x_dim), device=device) 513 | else: 514 | z = torch.zeros((batch_size, uncond_score_net.x_dim), device=device) 515 | extracted_noise_cond = cond_score_net(x_t, t=T.float(), cond=cond) 516 | extracted_noise_uncond = uncond_score_net(x_t, t=T.float(), cond=None) 517 | # linear combination 518 | extracted_noise = ((1 + cfg_weight) * extracted_noise_cond) - ( 519 | cfg_weight * extracted_noise_uncond 520 | ) 521 | if pg_weight > 0 and T_ > 0: 522 | extracted_noise = extracted_noise + ( 523 | similarity_guidance_gradient(x_t) * pg_weight 524 | ) 525 | at = scheduler.alpha(T) 526 | barat = scheduler.bar_alpha(T) 527 | noise_factor = (1.0 - at) / ((1.0 - barat).sqrt()) 528 | x_t = ( 529 | torch.pow(scheduler.alpha(T), -0.5) 530 | * (x_t - noise_factor * extracted_noise) 531 | + 
scheduler.beta(T).sqrt() * z 532 | ) 533 | return x_t.detach() 534 | --------------------------------------------------------------------------------