├── src
│   └── byprot
│       ├── models
│       │   ├── seq2seq
│       │   │   ├── __init__.py
│       │   │   ├── modules
│       │   │   │   ├── __init__.py
│       │   │   │   ├── ffn.py
│       │   │   │   ├── utils.py
│       │   │   │   └── embedding.py
│       │   │   ├── transformer_encoder.py
│       │   │   ├── transformer_decoder.py
│       │   │   └── transformer.py
│       │   ├── fixedbb
│       │   │   ├── pifold
│       │   │   │   ├── __init__.py
│       │   │   │   └── pifold.py
│       │   │   ├── __init__.py
│       │   │   ├── lm_design
│       │   │   │   ├── modules
│       │   │   │   │   └── gvp_transformer_encoder.py
│       │   │   │   ├── esm_adapter_pifold.py
│       │   │   │   ├── esm2_adapter_gvptrans.py
│       │   │   │   └── esm_adapter.py
│       │   │   ├── protein_mpnn_cmlm
│       │   │   │   └── protein_mpnn.py
│       │   │   └── generator.py
│       │   └── __init__.py
│       ├── tasks
│       │   └── fixedbb
│       │       └── __init__.py
│       ├── __init__.py
│       ├── datamodules
│       │   ├── datasets
│       │   │   ├── __init__.py
│       │   │   ├── vocab.py
│       │   │   └── parallel_dataset.py
│       │   ├── __init__.py
│       │   ├── cath_datamodule.py
│       │   └── multichain_datamodule.py
│       ├── utils
│       │   ├── registry.py
│       │   ├── strategies.py
│       │   ├── lr_scheduler.py
│       │   ├── optim.py
│       │   ├── config.py
│       │   └── io.py
│       ├── modules
│       │   ├── __init__.py
│       │   ├── metrics.py
│       │   └── cross_entropy.py
│       ├── testing_pipeline.py
│       └── training_pipeline.py
├── configs
│   ├── trainer
│   │   ├── ddp_fp16.yaml
│   │   ├── ddp.yaml
│   │   └── default.yaml
│   ├── logger
│   │   ├── csv.yaml
│   │   ├── many_loggers.yaml
│   │   ├── tensorboard.yaml
│   │   └── wandb.yaml
│   ├── paths
│   │   ├── evaluation.yaml
│   │   └── default.yaml
│   ├── test.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── experiment
│   │   ├── base.yaml
│   │   ├── fixedbb_multichain
│   │   │   ├── pifold.yaml
│   │   │   ├── protein_mpnn_cmlm.yaml
│   │   │   ├── lm_design_esm1b_650m.yaml
│   │   │   ├── lm_design_esm2_650m.yaml
│   │   │   └── lm_design_esm2_3b.yaml
│   │   └── fixedbb
│   │       ├── pifold.yaml
│   │       ├── protein_mpnn_cmlm.yaml
│   │       ├── lm_design_esm1b_650m.yaml
│   │       ├── lm_design_esm2_35m.yaml
│   │       ├── lm_design_esm2_8m.yaml
│   │       ├── lm_design_esm2_150m.yaml
│   │       ├── lm_design_esm2_650m.yaml
│   │       ├── lm_design_esm2_3b.yaml
│   │       ├── lm_design_esm2_150m_gvptrans_cath4.3.yaml
│   │       └── lm_design_esm2_650m_gvptrans_cath4.3.yaml
│   ├── datamodule
│   │   ├── multichain.yaml
│   │   ├── cath_4.3.yaml
│   │   ├── cath_4.2.yaml
│   │   └── cath_4.3_TS50.yaml
│   ├── callbacks
│   │   ├── default.yaml
│   │   └── fixedbb.yaml
│   └── config.yaml
├── assets
│   └── lm_design.png
├── .gitmodules
├── scripts
│   ├── download_multichain.sh
│   ├── download_cath.sh
│   ├── design_pdb.sh
│   └── design_pdb.py
├── examples
│   ├── pmpnn_compatible
│   │   ├── helper_scripts
│   │   │   ├── parse_multiple_chains.sh
│   │   │   ├── make_bias_AA.py
│   │   │   ├── assign_fixed_chains.py
│   │   │   ├── make_bias_per_res_dict.py
│   │   │   ├── other_tools
│   │   │   │   ├── make_omit_AA.py
│   │   │   │   └── make_pssm_dict.py
│   │   │   ├── make_fixed_positions_dict.py
│   │   │   ├── make_tied_positions_dict.py
│   │   │   ├── make_pos_neg_tied_positions_dict.py
│   │   │   └── parse_multiple_chains.py
│   │   └── design_pdb.sh
│   └── inspect_data_and_model.ipynb
├── install.sh
├── setup.py
├── setup.cfg
├── test.py
├── requirements.txt
├── train.py
├── .gitignore
├── .pre-commit-config.yaml
└── env.yml

/src/byprot/models/seq2seq/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/byprot/tasks/fixedbb/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/byprot/models/seq2seq/modules/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/configs/trainer/ddp_fp16.yaml:
--------------------------------------------------------------------------------
defaults: 2 | - ddp.yaml 3 | 4 | precision: 16 5 | -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/pifold/__init__.py: -------------------------------------------------------------------------------- 1 | from .pifold import PiFold, PiFoldConfig -------------------------------------------------------------------------------- /assets/lm_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BytedProtein/ByProt/HEAD/assets/lm_design.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "vendor/esm"] 2 | path = vendor/esm 3 | url = https://github.com/facebookresearch/esm.git 4 | -------------------------------------------------------------------------------- /src/byprot/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import byprot.datamodules 3 | import byprot.models 4 | import byprot.tasks 5 | import byprot.utils -------------------------------------------------------------------------------- /src/byprot/datamodules/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .datapipe import * 2 | 3 | from .data_utils import Alphabet, DataProcessor -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | gpus: auto 5 | strategy: ddp_sharded_fbo 6 | precision: 32 7 | 8 | # sync_batchnorm: True 9 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 
6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /scripts/download_multichain.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | wget -r -nd -np https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz -P data/multichain 3 | cd data/multichain 4 | tar -xzvf pdb_2021aug02_sample.tar.gz -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | # - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/paths/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/evaluations/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/evaluations/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: null 7 | version: ${name} 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/parse_multiple_chains.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem=32g 3 | #SBATCH -c 2 4 | #SBATCH --output=parse_multiple_chains.out 5 | 6 | source activate mlfold 7 | python parse_multiple_chains.py --input_path='../PDB_complexes/pdbs/' --output_path='../PDB_complexes/parsed_pdbs.jsonl' 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 'auto' 4 | 5 | min_epochs: 1 6 | max_epochs: 10 7 | enable_progress_bar: true 8 | log_every_n_steps: 10 9 | 10 | # number of validation steps to execute at the beginning of the training 11 | # num_sanity_val_steps: 0 12 | 13 | # ckpt path 14 | resume_from_checkpoint: null 15 | -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default evaluation configuration 4 | defaults: 5 | - _self_ 6 | - config 7 | 8 | experiment_path: ??? # experiment folder containing checkpoints and configs (.hydra) 9 | ckpt_path: ??? 
# passing checkpoint path is necessary 10 | data_split: test # train, valid, test 11 | mode: [test, predict] # test, or predict -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir} 11 | # sweep: 12 | # dir: ${paths.log_dir}/multiruns 13 | # subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | pip install torch==1.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 2 | 3 | pip install -e . 4 | pip install -e vendor/esm 5 | 6 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.0+cu113.html 7 | pip install torch_geometric biotite 8 | pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html 9 | -------------------------------------------------------------------------------- /scripts/download_cath.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | wget -r -nd -np http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/ -P data/cath_4.2 3 | 4 | 5 | mkdir -p data/cath_4.3 6 | wget -r -nd -np https://dl.fbaipublicfiles.com/fair-esm/data/cath4.3_topologysplit_202206/chain_set.jsonl -P data/cath_4.3 7 | wget -r -nd -np https://dl.fbaipublicfiles.com/fair-esm/data/cath4.3_topologysplit_202206/split.jsonl -P data/cath_4.3 -------------------------------------------------------------------------------- /src/byprot/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from omegaconf import DictConfig 3 | import os 4 | import glob 5 | 6 | from byprot.utils import import_modules 7 | 8 | DATAMODULE_REGISTRY = {} 9 | 10 | 11 | def register_datamodule(name): 12 | def decorator(cls): 13 | DATAMODULE_REGISTRY[name] = cls 14 | return cls 15 | return decorator 16 | 17 | 18 | import_modules(os.path.dirname(__file__), "byprot.datamodules") 19 | -------------------------------------------------------------------------------- /configs/experiment/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | # name of the run determines folder name in logs 13 | name: "fixedbb/protein_mpnn-cath_esm" 14 | 15 | 16 | -------------------------------------------------------------------------------- /configs/datamodule/multichain.yaml: -------------------------------------------------------------------------------- 1 | _target_: multichain 2 | 3 | # data_dir: ${data_dir} # data_dir is specified in config.yaml 4 | data_dir: '${paths.data_dir}/multichain/pdb_2021aug02' 5 | max_length: 1000 6 | atoms: ['N', 'CA', 'C', 'O'] 7 | 8 | # alphabet related 9 | alphabet: 10 | name: esm 11 | featurizer: 
multichain 12 | 13 | # dataloader related 14 | max_tokens: 10000 15 | sort: true 16 | num_workers: 8 17 | pin_memory: true 18 | debug: ${train.debug} -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: "template-tests" 6 | # name: ${name} 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment! 10 | # entity: "" # set to name of your wandb team 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | setup( 5 | name="ByProt", 6 | version="1.0.0", 7 | description="A pytorch library for swift protein design research and developing.", 8 | author="ByteDance Research", 9 | author_email="zhengzaixiang@bytedance.com", 10 | # url="https://github.com/bytedance/ByProt", 11 | install_requires=open("requirements.txt").readlines(), 12 | package_dir={"": "src"}, 13 | packages=find_packages("src") 14 | ) 15 | -------------------------------------------------------------------------------- /configs/datamodule/cath_4.3.yaml: -------------------------------------------------------------------------------- 1 | _target_: cath 2 | 3 | # data_dir: ${data_dir} # data_dir is specified in config.yaml 4 | data_dir: '${paths.data_dir}/cath_4.3' 5 | chain_set_jsonl: 'chain_set.jsonl' 6 | chain_set_splits_json: 'chain_set_splits.json' 7 | max_length: 500 # 393 8 | atoms: ['N', 'CA', 'C', 'O'] 9 | 10 | # alphabet related 11 | alphabet: 12 | name: esm 13 | featurizer: cath 14 | 15 | # dataloader related 16 | max_tokens: 6000 17 | sort: true 18 | num_workers: 8 19 | pin_memory: true 20 | -------------------------------------------------------------------------------- /src/byprot/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from omegaconf import DictConfig 3 | import os 4 | import glob 5 | 6 | from byprot.utils import import_modules 7 | 8 | MODEL_REGISTRY = {} 9 | 10 | 11 | def register_model(name): 12 | def decorator(cls): 13 | MODEL_REGISTRY[name] = cls 14 | return cls 15 | return decorator 16 | 17 | 18 | 19 | # automatically import any Python files in the models/ directory 20 | import_modules(os.path.dirname(__file__), "byprot.models", excludes=['protein_structure_prediction']) 21 | -------------------------------------------------------------------------------- /configs/datamodule/cath_4.2.yaml: -------------------------------------------------------------------------------- 1 | _target_: cath 2 | 3 | # data_dir: ${data_dir} # data_dir is specified in config.yaml 4 | data_dir: '${paths.data_dir}/cath_4.2' 5 | # data_dir: '/root/neurips19-graph-protein-design/data/cath' 6 | chain_set_jsonl: 'chain_set.jsonl' 7 | chain_set_splits_json: 'chain_set_splits.json' 8 | max_length: 500 # 393 9 | atoms: ['N', 'CA', 'C', 'O'] 10 | 11 | # alphabet related 12 | alphabet: 13 | name: esm 14 | featurizer: cath 15 | 16 | # dataloader related 17 | max_tokens: 6000 18 | sort: true 19 | num_workers: 8 20 | pin_memory: true 21 | 
-------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in entry_point program e.g., `train.py`. 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory, which is also 10 | # the path to output directory, created dynamically by hydra 11 | # path generation pattern is specified in `configs/hydra/default.yaml` 12 | # use it to store all files generated during the run, like ckpts and metrics 13 | log_dir: ${paths.root_dir}/logs/${name} 14 | 15 | ckpt_dir: ${paths.log_dir}/checkpoints 16 | 17 | -------------------------------------------------------------------------------- /configs/datamodule/cath_4.3_TS50.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - cath_4.3.yaml 3 | 4 | _target_: TS50 5 | 6 | # data_dir: ${data_dir} # data_dir is specified in config.yaml 7 | data_dir: '${paths.data_dir}/cath_4.3' 8 | chain_set_jsonl: 'chain_set.jsonl' 9 | chain_set_splits_json: 'chain_set_splits.json' 10 | max_length: 500 # 393 11 | atoms: ['N', 'CA', 'C', 'O'] 12 | 13 | # alphabet related 14 | proteinseq_toks: ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] 15 | prepend_toks: ["", ""] 16 | append_toks: [] 17 | prepend_bos: false 18 | append_eos: false 19 | 20 | # dataloader related 21 | max_tokens: 6000 22 | sort: true 23 | num_workers: 8 24 | pin_memory: true 25 | -------------------------------------------------------------------------------- /scripts/design_pdb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DIR="$(dirname "$0")" 4 | model_path="/root/research/projects/ByProt/run/logs/fixedbb/cath_4.2/lm_design_esm2_650m" 5 | # model_path="/root/research/projects/ByProt/run/logs/fixedbb/cath_4.3/lm_design_esm2_650m_gvptrans" 6 | # model_path="/root/research/projects/ByProt_public/logs/fixedbb_multichain/lm_design_esm2_650m" 7 | 8 | temperature=0.1 9 | pdb_dir="/root/research/projects/ByProt/data/pdb_samples" 10 | out_dir="$pdb_dir/lm_design_fasta" 11 | 12 | python $DIR/design_pdb.py \ 13 | --experiment_path $model_path --ckpt "best.ckpt" \ 14 | --pdb_dir $pdb_dir --out_dir $out_dir \ 15 | --seed 42 \ 16 | --num_seqs 1 \ 17 | --temperature $temperature \ 18 | --max_iter 5 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 99 3 | profile = black 4 | filter_files = True 5 | 6 | 7 | [flake8] 8 | max_line_length = 99 9 | show_source = True 10 | format = pylint 11 | ignore = 12 | F401 # Module imported but unused 13 | W504 # Line break occurred after a binary operator 14 | F841 # Local variable name is assigned to but never used 15 | E501 # Line too long 16 | exclude = 17 | .git 18 | __pycache__ 19 | data/* 20 | tests/* 21 | notebooks/* 22 | logs/* 23 | 24 | 25 | [tool:pytest] 26 | testpaths = tests/ 27 | log_cli = True 28 | markers = 29 | slow 30 | addopts = 31 | --durations=0 32 | --strict-markers 33 | --doctest-modules 34 | filterwarnings = 35 | ignore::DeprecationWarning 36 | ignore::UserWarning 37 | 
-------------------------------------------------------------------------------- /src/byprot/utils/registry.py: -------------------------------------------------------------------------------- 1 | from byprot.datamodules import DATAMODULE_REGISTRY 2 | from byprot.models import MODEL_REGISTRY 3 | from byprot.tasks import TASK_REGISTRY 4 | 5 | registry_dict = dict( 6 | datamodule=DATAMODULE_REGISTRY, 7 | task=TASK_REGISTRY, 8 | model=MODEL_REGISTRY 9 | ) 10 | 11 | def get_module(group_name, module_name): 12 | group = registry_dict.get(group_name, None) 13 | if group is None: 14 | raise KeyError(f'{group_name} is not a valid registry group {registry_dict.keys()}.') 15 | 16 | return group.get(module_name) 17 | 18 | def get_registered_modules(group_name): 19 | group = registry_dict.get(group_name) 20 | if group is not None: 21 | return group.keys() 22 | else: 23 | raise KeyError(f'{group_name} is not a valid registry group {registry_dict.keys()}.') 24 | 25 | __all__ = [ 26 | 'get_module', 27 | 'get_registered_modules' 28 | ] -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/make_bias_AA.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | 5 | import numpy as np 6 | import json 7 | 8 | bias_list = [float(item) for item in args.bias_list.split()] 9 | AA_list = [str(item) for item in args.AA_list.split()] 10 | 11 | my_dict = dict(zip(AA_list, bias_list)) 12 | 13 | with open(args.output_path, 'w') as f: 14 | f.write(json.dumps(my_dict) + '\n') 15 | 16 | 17 | if __name__ == "__main__": 18 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 20 | argparser.add_argument("--AA_list", type=str, default='', help="List of AAs to be biased") 21 | argparser.add_argument("--bias_list", type=str, default='', help="AA bias strengths") 22 | 23 | args = argparser.parse_args() 24 | main(args) 25 | 26 | #e.g. 
output 27 | #{"A": -0.01, "G": 0.02} 28 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!python 2 | 3 | import pyrootutils 4 | 5 | root = pyrootutils.setup_root( 6 | search_from=__file__, 7 | indicator=[".git", "pyproject.toml"], 8 | pythonpath=True, 9 | # load environment variables from `.env` file if it exists 10 | # recursively searches for `.env` in all folders starting from work dir 11 | dotenv=True, 12 | ) 13 | 14 | 15 | import dotenv 16 | import hydra 17 | from omegaconf import DictConfig 18 | 19 | 20 | @hydra.main(config_path=f"{root}/configs", config_name="test.yaml") 21 | def main(config: DictConfig): 22 | 23 | # Imports can be nested inside @hydra.main to optimize tab completion 24 | # https://github.com/facebookresearch/hydra/issues/934 25 | from byprot import utils 26 | from byprot.testing_pipeline import test 27 | 28 | # resolve user provided config 29 | config = utils.resolve_experiment_config(config) 30 | # Applies optional utilities 31 | config = utils.extras(config) 32 | 33 | # Evaluate model 34 | return test(config) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch==1.12.0 3 | # torchvision>=0.11.0 4 | pytorch-lightning==1.7.3 5 | torchmetrics>=0.9.3 6 | torchtext 7 | torchdata 8 | 9 | # --------- hydra --------- # 10 | hydra-core==1.2.0 11 | hydra-colorlog==1.2.0 12 | hydra-optuna-sweeper==1.2.0 13 | 14 | # --------- loggers --------- # 15 | # wandb 16 | # neptune-client 17 | # mlflow 18 | # comet-ml 19 | tensorboard 20 | 21 | # --------- linters --------- # 22 | pyrootutils # standardizing the project root setup 23 | pre-commit # hooks for applying linters on commit 24 | black # code formatting 25 | isort # import sorting 26 | flake8 # code analysis 27 | nbstripout # remove output from jupyter notebooks 28 | 29 | # --------- others --------- # 30 | python-dotenv # loading env variables from .env file 31 | rich # beautiful text formatting in terminal 32 | pytest # tests 33 | sh # for running bash commands in some tests 34 | pudb # debugger 35 | 36 | # --------- project related --------- # 37 | biopython==1.79 38 | einops 39 | debugpy 40 | matplotlib 41 | pandas 42 | seaborn 43 | opt_einsum 44 | sympy 45 | e3nn 46 | fairscale 47 | -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/__init__.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | try: 3 | import esm 4 | ESM_INSTALLED = True 5 | except: 6 | ESM_INSTALLED = False 7 | 8 | from byprot.utils.config import compose_config, merge_config 9 | 10 | import torch 11 | from torch import nn 12 | import numpy as np 13 | 14 | class FixedBackboneDesignEncoderDecoder(nn.Module): 15 | _default_cfg = {} 16 | 17 | def __init__(self, cfg) -> None: 18 | super().__init__() 19 | self._update_cfg(cfg) 20 | 21 | def _update_cfg(self, cfg): 22 | self.cfg = OmegaConf.merge(self._default_cfg, cfg) 23 | 24 | @classmethod 25 | def from_config(cls, cfg): 26 | raise NotImplementedError 27 | 28 | def forward_encoder(self, batch): 29 | raise NotImplementedError 30 | 31 | def forward_decoder(self, prev_decoder_out, encoder_out): 32 | raise NotImplementedError 
33 | 34 | def initialize_output_tokens(self, batch, encoder_out): 35 | raise NotImplementedError 36 | 37 | def forward(self, coords, coord_mask, tokens, token_padding_mask=None, **kwargs): 38 | raise NotImplementedError 39 | 40 | def sample(self, coords, coord_mask, tokens=None, token_padding_mask=None, **kwargs): 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_summary: 2 | _target_: pytorch_lightning.callbacks.RichModelSummary 3 | max_depth: -1 4 | 5 | # rich_progress_bar: 6 | # _target_: src.utils.callbacks.BetterRichProgressBar 7 | # leave: false 8 | 9 | model_checkpoint: 10 | _target_: byprot.utils.callbacks.ModelCheckpoint 11 | monitor: ${train.monitor} # name of the logged metric which determines when model is improving 12 | mode: ${train.mode} # "max" means higher metric value is better, can be also "min" 13 | save_top_k: 1 # save k best models (determined by above metric) 14 | save_last: True # additionaly always save model from last epoch 15 | verbose: True 16 | dirpath: "checkpoints" 17 | filename: "step_{global_step}-${train.monitor}_{${train.monitor}:.2f}" 18 | auto_insert_metric_name: False 19 | # every_n_train_steps: 10 20 | 21 | early_stopping: 22 | _target_: pytorch_lightning.callbacks.EarlyStopping 23 | monitor: ${train.monitor} # name of the logged metric which determines when model is improving 24 | mode: ${train.mode} # "max" means higher metric value is better, can be also "min" 25 | patience: ${train.patience} # how many validation epochs of not improving until training stops 26 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 27 | check_on_train_epoch_end: false -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/lm_design/modules/gvp_transformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import esm 4 | 5 | class GVPTransformerEncoderWrapper(nn.Module): 6 | def __init__(self, alphabet, freeze=True): 7 | super().__init__() 8 | _model, _alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() 9 | self.encoder = _model.encoder 10 | if freeze: 11 | for param in self.encoder.parameters(): 12 | param.requires_grad_(False) 13 | 14 | self.embed_dim = self.encoder.embed_tokens.embedding_dim 15 | self.out_proj = nn.Linear(self.embed_dim, len(alphabet)) 16 | 17 | def forward(self, batch, **kwargs): 18 | return_all_hiddens = False 19 | padding_mask = torch.isnan(batch['coords'][:, :, 0, 0]) 20 | coords = batch['coords'][:, :, :3, :] 21 | confidence = torch.ones(batch['coords'].shape[0:2]).to(coords.device) 22 | encoder_out = self.encoder(coords, padding_mask, confidence, 23 | return_all_hiddens=return_all_hiddens) 24 | # encoder_out['encoder_out'][0] = torch.transpose(encoder_out['encoder_out'][0], 0, 1) 25 | encoder_out['feats'] = encoder_out['encoder_out'][0].transpose(0, 1) 26 | logits = self.out_proj(encoder_out['feats']) 27 | return logits, encoder_out 28 | 29 | -------------------------------------------------------------------------------- /src/byprot/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | _registry = {} 7 | class _Criterion(nn.Module): 8 
| def __init__(self, cfg) -> None: 9 | super().__init__() 10 | self.cfg = cfg 11 | 12 | self.criterions = {} 13 | self.weights = {} 14 | 15 | self._build() 16 | 17 | def _build(self): 18 | for name, cfg in self.cfg.items(): 19 | _target_ = cfg.pop('_target_') 20 | weight = cfg.pop('weight', 1.0) 21 | self.criterions[name] = _instantiate(_target_, cfg=cfg, registry=_registry) 22 | self.weights[name] = weight 23 | 24 | def forward(self, model_outs, targets): 25 | """ 26 | 27 | Args: 28 | model_outs (dict): dict of loss_name: model_out 29 | targets (_type_): _description_ 30 | """ 31 | logging_outs = {} 32 | total_loss = 0. 33 | 34 | for name, model_out in model_outs.items(): 35 | if name in self.criterions: 36 | loss, logging_out = self.criterions[name](model_out, targets[name]) 37 | 38 | total_loss += self.weights[name] * loss 39 | logging_out = {f'{name}/{key}': val for key, val in logging_out.items()} 40 | 41 | logging_outs.update(logging_out) 42 | 43 | return total_loss, logging_outs -------------------------------------------------------------------------------- /configs/experiment/fixedbb_multichain/pifold.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath 8 | # - /task: fixedbb.yaml 9 | - /callbacks: fixedbb.yaml 10 | - /trainer: default.yaml 11 | 12 | # - /model: protein_mpnn.yaml 13 | 14 | 15 | # name of the run determines folder name in logs 16 | name: "fixedbb/cath/protein_mpnn_cmlm" 17 | 18 | # model 19 | model: 20 | _target_: pifold 21 | 22 | # datamodule 23 | datamodule: 24 | to_pifold_format: true 25 | 26 | # task 27 | task: 28 | _target_: protein/fixedbb 29 | noise: full_mask # enable cmlm training with uniform random masking 30 | 31 | criterion: 32 | _target_: src.criterions.cross_entropy.Coord2SeqCrossEntropyLoss 33 | label_smoothing: 0.0 34 | ignore_index: 0 35 | 36 | optimizer: 37 | type: adamw 38 | lr: ${train.lr} 39 | betas: 40 | - 0.9 41 | - 0.98 42 | weight_decay: 0.0001 43 | 44 | lr_scheduler: 45 | type: noam 46 | lr: ${train.lr} 47 | warmup_steps: 4000 48 | model_size: 128 49 | warmup_init_lr: 1e-07 50 | 51 | generator: 52 | max_iter: 1 53 | strategy: null 54 | 55 | # training related 56 | train: 57 | seed: 42 58 | lr: 3e-3 59 | monitor: "val/acc_median" 60 | mode: "max" 61 | 62 | trainer: 63 | min_epochs: 10 64 | max_epochs: 10000 65 | gradient_clip_val: 0.0 66 | # val_check_interval: 10 67 | num_sanity_val_steps: 1 68 | reload_dataloaders_every_n_epochs: 1 69 | replace_sampler_ddp: false 70 | max_steps: 200_000 71 | -------------------------------------------------------------------------------- /src/byprot/utils/strategies.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | 3 | from pytorch_lightning.strategies import StrategyRegistry 4 | from pytorch_lightning.strategies.sharded import DDPShardedStrategy 5 | from torch.optim import Optimizer 6 | 7 | from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE 8 | 9 | if _FAIRSCALE_AVAILABLE: 10 | from fairscale.optim import OSS 11 | 12 | class DDPShardedFBOStrategy(DDPShardedStrategy): 13 | strategy_name = 'ddp_sharded_fbo' 14 | 15 | def __init__(self, force_broadcast_object=True, **kwargs) -> None: 16 | super().__init__(**kwargs) 17 | self.force_broadcast_object = force_broadcast_object 18 | 19 | def _wrap_optimizers(self, optimizers: 
List[Optimizer]) -> List["OSS"]: 20 | oos_optimizers = super()._wrap_optimizers(optimizers) 21 | if self.force_broadcast_object: 22 | for oos_optimizer in oos_optimizers: 23 | oos_optimizer.force_broadcast_object = True 24 | return oos_optimizers 25 | 26 | @classmethod 27 | def register_strategies(cls, strategy_registry: Dict) -> None: 28 | strategy_registry.register( 29 | cls.strategy_name, 30 | cls, 31 | description="DDP Shared Strategy with force_broadcast_object enabled", 32 | ) 33 | 34 | StrategyRegistry.register( 35 | "ddp_sharded_fbo", 36 | DDPShardedFBOStrategy, 37 | description="DDP Shared Strategy with force_broadcast_object enabled", 38 | force_broadcast_object=True, 39 | ) -------------------------------------------------------------------------------- /configs/experiment/fixedbb/pifold.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath 8 | # - /task: fixedbb.yaml 9 | - /callbacks: fixedbb.yaml 10 | - /trainer: default.yaml 11 | 12 | # - /model: protein_mpnn.yaml 13 | 14 | 15 | # name of the run determines folder name in logs 16 | name: "fixedbb/cath/pifold" 17 | 18 | # model 19 | model: 20 | _target_: pifold 21 | 22 | # datamodule 23 | datamodule: 24 | alphabet: 25 | name: protein_mpnn 26 | featurizer: cath 27 | featurizer_cfg: 28 | to_pifold_format: true 29 | 30 | # task 31 | task: 32 | _target_: fixedbb/cmlm 33 | alphabet: ${datamodule.alphabet} 34 | learning: 35 | noise: full_mask # 36 | criterion: 37 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 38 | label_smoothing: 0.0 39 | ignore_index: 0 40 | optimizer: 41 | type: adamw 42 | lr: ${train.lr} 43 | betas: 44 | - 0.9 45 | - 0.98 46 | weight_decay: 0.0001 47 | lr_scheduler: 48 | type: noam 49 | lr: ${train.lr} 50 | warmup_steps: 4000 51 | model_size: 128 52 | warmup_init_lr: 1e-07 53 | generator: 54 | max_iter: 1 55 | strategy: "mask_predict" 56 | 57 | # training related 58 | train: 59 | seed: 42 60 | lr: 3e-3 61 | monitor: "val/acc_median" 62 | mode: "max" 63 | 64 | trainer: 65 | min_epochs: 10 66 | max_epochs: 10000 67 | gradient_clip_val: 0.0 68 | # val_check_interval: 10 69 | num_sanity_val_steps: 1 70 | reload_dataloaders_every_n_epochs: 1 71 | replace_sampler_ddp: false 72 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/experiment/fixedbb/protein_mpnn_cmlm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # name of the run determines folder name in logs 12 | name: "fixedbb/cath_4.2/protein_mpnn_cmlm" 13 | 14 | datamodule: 15 | alphabet: 16 | name: protein_mpnn 17 | featurizer: cath 18 | 19 | # model 20 | model: 21 | _target_: protein_mpnn_cmlm 22 | d_model: 128 23 | n_enc_layers: 3 24 | n_dec_layers: 3 25 | 26 | # task 27 | task: 28 | _target_: fixedbb/cmlm 29 | alphabet: ${datamodule.alphabet} 30 | learning: 31 | noise: random_mask # enable cmlm training with uniform random masking 32 | criterion: 33 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 34 | label_smoothing: 0.0 35 | ignore_index: 0 36 | optimizer: 37 | type: adamw 38 | lr: ${train.lr} 39 | betas: 40 | - 0.9 41 | - 0.98 42 | 
weight_decay: 0.0001 43 | lr_scheduler: 44 | type: noam 45 | lr: ${train.lr} 46 | warmup_steps: 4000 47 | model_size: 128 48 | warmup_init_lr: 1e-07 49 | generator: 50 | max_iter: 1 51 | strategy: "mask_predict" 52 | 53 | # training related 54 | train: 55 | seed: 42 56 | lr: 3e-3 57 | monitor: "val/acc_median" 58 | mode: "max" 59 | 60 | trainer: 61 | min_epochs: 10 62 | max_epochs: 10000 63 | gradient_clip_val: 0.0 64 | # val_check_interval: 10 65 | num_sanity_val_steps: 1 66 | reload_dataloaders_every_n_epochs: 1 67 | replace_sampler_ddp: false 68 | max_steps: 200_000 69 | -------------------------------------------------------------------------------- /configs/callbacks/fixedbb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | model_checkpoint: 5 | _target_: byprot.utils.callbacks.ModelCheckpoint 6 | monitor: "val/acc" # name of the logged metric which determines when model is improving 7 | mode: "max" # "max" means higher metric value is better, can be also "min" 8 | save_top_k: 1 # save k best models (determined by above metric) 9 | save_last: True # additionaly always save model from last epoch 10 | verbose: True 11 | dirpath: ${paths.ckpt_dir} 12 | filename: "step_{global_step}-ppl_{val/ppl:.2f}-acc_{val/acc:.2f}" 13 | auto_insert_metric_name: False 14 | # every_n_train_steps: 10 15 | 16 | early_stopping: 17 | _target_: pytorch_lightning.callbacks.EarlyStopping 18 | monitor: "val/acc" # name of the logged metric which determines when model is improving 19 | mode: "max" # "max" means higher metric value is better, can be also "min" 20 | patience: 30 # how many validation epochs of not improving until training stops 21 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 22 | check_on_train_epoch_end: false 23 | 24 | # CheckpointEveryNSteps: 25 | # _target_: src.utils.callbacks.CheckpointEveryNSteps 26 | # save_step_frequency: 5 27 | 28 | # val_every_Nsteps: 29 | # _target_: src.utils.callbacks.ValEveryNSteps 30 | # every_n_step: 10 31 | 32 | # model_summary: 33 | # _target_: pytorch_lightning.callbacks.RichModelSummary 34 | # max_depth: -1 35 | 36 | # rich_progress_bar: 37 | # _target_: src.utils.callbacks.BetterRichProgressBar 38 | # leave: true 39 | 40 | # lr_monitor: 41 | # _target_: pytorch_lightning.callbacks.LearningRateMonitor 42 | # logging_interval: 'step' -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/assign_fixed_chains.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import json 5 | 6 | with open(args.input_path, 'r') as json_file: 7 | json_list = list(json_file) 8 | 9 | global_designed_chain_list = [] 10 | if args.chain_list != '': 11 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 12 | my_dict = {} 13 | for json_str in json_list: 14 | result = json.loads(json_str) 15 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] #['A','B', 'C',...] 16 | if len(global_designed_chain_list) > 0: 17 | designed_chain_list = global_designed_chain_list 18 | else: 19 | #manually specify, e.g. 
20 | designed_chain_list = ["A"] 21 | fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list] #fix/do not redesign these chains 22 | my_dict[result['name']]= (designed_chain_list, fixed_chain_list) 23 | 24 | with open(args.output_path, 'w') as f: 25 | f.write(json.dumps(my_dict) + '\n') 26 | 27 | 28 | if __name__ == "__main__": 29 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 31 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 32 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be designed") 33 | 34 | args = argparser.parse_args() 35 | main(args) 36 | 37 | # Output looks like this: 38 | # {"5TTA": [["A"], ["B"]], "3LIS": [["A"], ["B"]]} 39 | 40 | -------------------------------------------------------------------------------- /configs/experiment/fixedbb_multichain/protein_mpnn_cmlm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: multichain 8 | - /callbacks: fixedbb.yaml 9 | - /trainer: default.yaml 10 | 11 | # - /model: protein_mpnn.yaml 12 | 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb_multichain/protein_mpnn_cmlm" 16 | 17 | # datamodule 18 | datamodule: 19 | alphabet: 20 | name: protein_mpnn 21 | featurizer: multichain 22 | 23 | # model 24 | model: 25 | _target_: protein_mpnn_cmlm 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: false # ${datamodule.use_esm_alphabet} 29 | 30 | 31 | # task 32 | task: 33 | _target_: fixedbb/cmlm 34 | alphabet: ${datamodule.alphabet} 35 | learning: 36 | noise: random_mask # enable cmlm training with uniform random masking 37 | criterion: 38 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 39 | label_smoothing: 0.0 40 | ignore_index: 0 41 | optimizer: 42 | type: adamw 43 | lr: ${train.lr} 44 | betas: 45 | - 0.9 46 | - 0.98 47 | weight_decay: 0.0001 48 | lr_scheduler: 49 | type: noam 50 | lr: ${train.lr} 51 | warmup_steps: 4000 52 | model_size: 128 53 | warmup_init_lr: 1e-07 54 | generator: 55 | max_iter: 1 56 | strategy: 'mask_predict' 57 | 58 | # training related 59 | train: 60 | seed: 42 61 | lr: 3e-3 62 | monitor: "val/acc_median" 63 | mode: "max" 64 | 65 | trainer: 66 | min_epochs: 10 67 | max_epochs: 10000 68 | gradient_clip_val: 0.0 69 | # val_check_interval: 10 70 | num_sanity_val_steps: 1 71 | reload_dataloaders_every_n_epochs: 1 72 | replace_sampler_ddp: false 73 | max_steps: 200_000 74 | -------------------------------------------------------------------------------- /configs/experiment/fixedbb_multichain/lm_design_esm1b_650m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: multichain 8 | - /callbacks: fixedbb.yaml 9 | - /trainer: default.yaml 10 | 11 | # - /model: protein_mpnn.yaml 12 | 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb_multichain/lm_design_esm2_3b" 16 | 17 | # datamodule 18 | datamodule: 19 | alphabet: 20 | name: esm 21 | featurizer: multichain 22 | 23 | # model 24 | model: 25 | _target_: esm_adapter 26 | encoder: 27 
| d_model: 128 28 | n_enc_layers: 3 29 | n_dec_layers: 3 30 | use_esm_alphabet: true 31 | adapter_layer_indices: [32, ] 32 | 33 | # task 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 0 43 | optimizer: 44 | type: adamw 45 | lr: ${train.lr} 46 | betas: 47 | - 0.9 48 | - 0.98 49 | weight_decay: 0.0001 50 | lr_scheduler: 51 | type: noam 52 | lr: ${train.lr} 53 | warmup_steps: 4000 54 | model_size: 128 55 | warmup_init_lr: 1e-07 56 | generator: 57 | max_iter: 1 58 | strategy: 'mask_predict' 59 | 60 | # training related 61 | train: 62 | seed: 42 63 | lr: 3e-3 64 | monitor: "val/acc_median" 65 | mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 76 | -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm1b_650m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb/cath_4.2/lm_design_esm1b_650m" 16 | 17 | datamodule: 18 | alphabet: 19 | name: esm 20 | featurizer: cath 21 | 22 | model: 23 | _target_: esm_adapter 24 | encoder: 25 | d_model: 128 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: true 29 | adapter_layer_indices: [32, ] 30 | 31 | task: 32 | _target_: fixedbb/cmlm 33 | alphabet: ${datamodule.alphabet} 34 | learning: 35 | noise: random_mask # enable cmlm training with uniform random masking 36 | criterion: 37 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 38 | label_smoothing: 0.0 39 | ignore_index: 1 40 | optimizer: 41 | type: adamw 42 | _partial_: true 43 | lr: ${train.lr} 44 | betas: 45 | - 0.9 46 | - 0.98 47 | weight_decay: 0.0001 48 | lr_scheduler: 49 | type: noam 50 | warmup_steps: 4000 51 | model_size: 128 52 | lr: ${train.lr} 53 | warmup_init_lr: 1e-07 54 | generator: 55 | max_iter: 5 56 | strategy: 'denoise' 57 | 58 | train: 59 | seed: 42 60 | lr: 0.001 61 | monitor: "val/acc_median" 62 | mode: "max" 63 | 64 | trainer: 65 | min_epochs: 10 66 | max_epochs: 10000 67 | gradient_clip_val: 0.0 68 | # val_check_interval: 10 69 | num_sanity_val_steps: 1 70 | reload_dataloaders_every_n_epochs: 1 71 | replace_sampler_ddp: false 72 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/experiment/fixedbb_multichain/lm_design_esm2_650m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: multichain 8 | - /callbacks: fixedbb.yaml 9 | - /trainer: default.yaml 10 | 11 | # - /model: protein_mpnn.yaml 12 | 13 | 14 | # name of the run determines 
folder name in logs 15 | name: "fixedbb_multichain/lm_design_esm1b_650m" 16 | 17 | # datamodule 18 | datamodule: 19 | alphabet: 20 | name: esm 21 | featurizer: multichain 22 | 23 | # model 24 | model: 25 | _target_: esm2_adapter 26 | encoder: 27 | d_model: 128 28 | n_enc_layers: 3 29 | n_dec_layers: 3 30 | use_esm_alphabet: true 31 | 32 | name: esm2_t33_650M_UR50D 33 | adapter_layer_indices: [-1, ] 34 | separate_loss: true 35 | 36 | # task 37 | task: 38 | _target_: fixedbb/cmlm 39 | alphabet: ${datamodule.alphabet} 40 | learning: 41 | noise: random_mask # enable cmlm training with uniform random masking 42 | criterion: 43 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 44 | label_smoothing: 0.0 45 | ignore_index: 0 46 | optimizer: 47 | type: adamw 48 | lr: ${train.lr} 49 | betas: 50 | - 0.9 51 | - 0.98 52 | weight_decay: 0.0001 53 | lr_scheduler: 54 | type: noam 55 | lr: ${train.lr} 56 | warmup_steps: 4000 57 | model_size: 128 58 | warmup_init_lr: 1e-07 59 | generator: 60 | max_iter: 1 61 | strategy: 'mask_predict' 62 | 63 | # training related 64 | train: 65 | seed: 42 66 | lr: 3e-3 67 | monitor: "val/acc_median" 68 | mode: "max" 69 | 70 | trainer: 71 | min_epochs: 10 72 | max_epochs: 10000 73 | gradient_clip_val: 0.0 74 | # val_check_interval: 10 75 | num_sanity_val_steps: 1 76 | reload_dataloaders_every_n_epochs: 1 77 | replace_sampler_ddp: false 78 | max_steps: 200_000 79 | -------------------------------------------------------------------------------- /configs/experiment/fixedbb_multichain/lm_design_esm2_3b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: multichain 8 | - /callbacks: fixedbb.yaml 9 | - /trainer: default.yaml 10 | 11 | # - /model: protein_mpnn.yaml 12 | 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb_multichain/lm_design_esm2_3b" 16 | 17 | # datamodule 18 | datamodule: 19 | alphabet: 20 | name: esm 21 | featurizer: multichain 22 | 23 | # model 24 | model: 25 | _target_: esm2_adapter 26 | encoder: 27 | d_model: 128 28 | n_enc_layers: 3 29 | n_dec_layers: 3 30 | use_esm_alphabet: true 31 | 32 | name: esm2_t36_3B_UR50D 33 | dropout: 0.3 34 | adapter_layer_indices: [-1, ] 35 | separate_loss: true 36 | 37 | # task 38 | task: 39 | _target_: fixedbb/cmlm 40 | alphabet: ${datamodule.alphabet} 41 | learning: 42 | noise: random_mask # enable cmlm training with uniform random masking 43 | criterion: 44 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 45 | label_smoothing: 0.0 46 | ignore_index: 0 47 | optimizer: 48 | type: adamw 49 | lr: ${train.lr} 50 | betas: 51 | - 0.9 52 | - 0.98 53 | weight_decay: 0.0001 54 | lr_scheduler: 55 | type: noam 56 | lr: ${train.lr} 57 | warmup_steps: 4000 58 | model_size: 128 59 | warmup_init_lr: 1e-07 60 | generator: 61 | max_iter: 1 62 | strategy: 'mask_predict' 63 | 64 | # training related 65 | train: 66 | seed: 42 67 | lr: 3e-3 68 | monitor: "val/acc_median" 69 | mode: "max" 70 | 71 | trainer: 72 | min_epochs: 10 73 | max_epochs: 10000 74 | gradient_clip_val: 0.0 75 | # val_check_interval: 10 76 | num_sanity_val_steps: 1 77 | reload_dataloaders_every_n_epochs: 1 78 | replace_sampler_ddp: false 79 | max_steps: 200_000 80 | -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_35m.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb/cath_4.2/lm_design_esm1b_650m" 16 | 17 | datamodule: 18 | alphabet: 19 | name: esm 20 | featurizer: cath 21 | 22 | model: 23 | _target_: esm2_adapter 24 | encoder: 25 | d_model: 128 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: true 29 | 30 | name: esm2_t12_35M_UR50D 31 | adapter_layer_indices: [-1, ] 32 | separate_loss: true 33 | 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 1 43 | optimizer: 44 | type: adamw 45 | _partial_: true 46 | lr: ${train.lr} 47 | betas: 48 | - 0.9 49 | - 0.98 50 | weight_decay: 0.0001 51 | lr_scheduler: 52 | type: noam 53 | warmup_steps: 4000 54 | model_size: 128 55 | lr: ${train.lr} 56 | warmup_init_lr: 1e-07 57 | generator: 58 | max_iter: 5 59 | strategy: 'denoise' 60 | 61 | train: 62 | seed: 42 63 | lr: 0.001 64 | monitor: "val/acc_median" 65 | mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_8m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb/cath_4.2/lm_design_esm1b_650m" 16 | 17 | datamodule: 18 | alphabet: 19 | name: esm 20 | featurizer: cath 21 | 22 | model: 23 | _target_: esm2_adapter 24 | encoder: 25 | d_model: 128 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: true 29 | 30 | name: esm2_t6_8M_UR50D 31 | adapter_layer_indices: [-1, ] 32 | separate_loss: true 33 | 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 1 43 | optimizer: 44 | type: adamw 45 | _partial_: true 46 | lr: ${train.lr} 47 | betas: 48 | - 0.9 49 | - 0.98 50 | weight_decay: 0.0001 51 | lr_scheduler: 52 | type: noam 53 | warmup_steps: 4000 54 | model_size: 128 55 | lr: ${train.lr} 56 | warmup_init_lr: 1e-07 57 | generator: 58 | max_iter: 5 59 | strategy: 'denoise' 60 | 61 | train: 62 | seed: 42 63 | lr: 0.001 64 | monitor: "val/acc_median" 65 
| mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_150m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb/cath_4.2/lm_design_esm1b_650m" 16 | 17 | datamodule: 18 | alphabet: 19 | name: esm 20 | featurizer: cath 21 | 22 | model: 23 | _target_: esm2_adapter 24 | encoder: 25 | d_model: 128 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: true 29 | 30 | name: esm2_t30_150M_UR50D 31 | adapter_layer_indices: [-1, ] 32 | separate_loss: true 33 | 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 1 43 | optimizer: 44 | type: adamw 45 | _partial_: true 46 | lr: ${train.lr} 47 | betas: 48 | - 0.9 49 | - 0.98 50 | weight_decay: 0.0001 51 | lr_scheduler: 52 | type: noam 53 | warmup_steps: 4000 54 | model_size: 128 55 | lr: ${train.lr} 56 | warmup_init_lr: 1e-07 57 | generator: 58 | max_iter: 5 59 | strategy: 'denoise' 60 | 61 | train: 62 | seed: 42 63 | lr: 0.001 64 | monitor: "val/acc_median" 65 | mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_650m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb/cath_4.2/lm_design_esm1b_650m" 16 | 17 | datamodule: 18 | alphabet: 19 | name: esm 20 | featurizer: cath 21 | 22 | model: 23 | _target_: esm2_adapter 24 | encoder: 25 | d_model: 128 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: true 29 | 30 | name: esm2_t33_650M_UR50D 31 | adapter_layer_indices: [-1, ] 32 | separate_loss: true 33 | 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 1 
43 | optimizer: 44 | type: adamw 45 | _partial_: true 46 | lr: ${train.lr} 47 | betas: 48 | - 0.9 49 | - 0.98 50 | weight_decay: 0.0001 51 | lr_scheduler: 52 | type: noam 53 | warmup_steps: 4000 54 | model_size: 128 55 | lr: ${train.lr} 56 | warmup_init_lr: 1e-07 57 | generator: 58 | max_iter: 5 59 | strategy: 'denoise' 60 | 61 | train: 62 | seed: 42 63 | lr: 0.001 64 | monitor: "val/acc_median" 65 | mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 -------------------------------------------------------------------------------- /src/byprot/models/seq2seq/modules/ffn.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | def get_activation_fn(activation: str) -> Callable: 7 | """Returns the activation function corresponding to `activation`""" 8 | 9 | if activation == "relu": 10 | return F.relu 11 | elif activation == "glu": 12 | return F.glu 13 | elif activation == "gelu": 14 | return F.gelu 15 | elif activation == "tanh": 16 | return torch.tanh 17 | elif activation == "linear": 18 | return lambda x: x 19 | elif activation == "swish": 20 | return torch.nn.SiLU 21 | else: 22 | raise RuntimeError("activation {} not supported".format(activation)) 23 | 24 | 25 | class FFN(nn.Module): 26 | """ Feed-forward neural network """ 27 | 28 | def __init__(self, 29 | d_model, 30 | d_inner=None, 31 | activation="gelu", 32 | dropout=0.0): 33 | super().__init__() 34 | d_inner = d_inner or d_model 35 | 36 | self.fc1 = nn.Linear(d_model, d_inner) 37 | self.fc2 = nn.Linear(d_inner, d_model) 38 | self.activation = get_activation_fn(activation) 39 | self.dropout = nn.Dropout(dropout) 40 | 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | self.fc1.reset_parameters() 45 | self.fc2.reset_parameters() 46 | 47 | def forward(self, x): 48 | """ 49 | Args: 50 | x: feature to perform ffn 51 | :math:`(*, D)`, where D is feature dimension 52 | 53 | Returns: 54 | - feed forward output 55 | :math:`(*, D)`, where D is feature dimension 56 | """ 57 | x = self.fc1(x) 58 | x = self.activation(x) 59 | x = self.dropout(x) 60 | x = self.fc2(x) 61 | return x -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_3b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.2 8 | - /callbacks: fixedbb 9 | - /trainer: default 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | name: "fixedbb/cath_4.2/lm_design_esm1b_650m" 16 | 17 | datamodule: 18 | alphabet: 19 | name: esm 20 | featurizer: cath 21 | 22 | model: 23 | _target_: esm2_adapter 24 | encoder: 25 | d_model: 128 26 | n_enc_layers: 3 27 | n_dec_layers: 3 28 | use_esm_alphabet: true 29 | 30 | name: esm2_t36_3B_UR50D 31 | dropout: 0.3 32 | adapter_layer_indices: [-1, ] 33 | separate_loss: true 34 | 35 | task: 36 | _target_: fixedbb/cmlm 37 | alphabet: ${datamodule.alphabet} 38 | 
learning: 39 | noise: random_mask # enable cmlm training with uniform random masking 40 | criterion: 41 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 42 | label_smoothing: 0.0 43 | ignore_index: 1 44 | optimizer: 45 | type: adamw 46 | _partial_: true 47 | lr: ${train.lr} 48 | betas: 49 | - 0.9 50 | - 0.98 51 | weight_decay: 0.0001 52 | lr_scheduler: 53 | type: noam 54 | warmup_steps: 4000 55 | model_size: 128 56 | lr: ${train.lr} 57 | warmup_init_lr: 1e-07 58 | generator: 59 | max_iter: 5 60 | strategy: 'denoise' 61 | 62 | train: 63 | seed: 42 64 | lr: 0.001 65 | monitor: "val/acc_median" 66 | mode: "max" 67 | 68 | trainer: 69 | min_epochs: 10 70 | max_epochs: 10000 71 | gradient_clip_val: 0.0 72 | # val_check_interval: 10 73 | num_sanity_val_steps: 1 74 | reload_dataloaders_every_n_epochs: 1 75 | replace_sampler_ddp: false 76 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - experiment: fixedbb/protein_mpnn_cmlm # specifies pipeline and model 7 | 8 | - callbacks: # pytorch-lightning callbacks 9 | - default 10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - paths: default 12 | - hydra: default 13 | 14 | # default name for the experiment, determines logging folder path 15 | # (you can overwrite this name in experiment configs) 16 | name: ??? 17 | 18 | train: 19 | # set False to skip model training 20 | train: True 21 | # evaluate on test set, using best model weights achieved during training 22 | # lightning chooses best weights based on the metric specified in checkpoint callback 23 | test: True 24 | 25 | debug: false 26 | 27 | force_restart: false # force to train from scratch 28 | 29 | # simply provide checkpoint path to resume training 30 | # it can be either an absolute path, 31 | # or an relative path which will then be inferred from 32 | # 1) current workding directory (cwd), or 33 | # 2) checkpoint directory (${paths.ckpt_dir}) 34 | ckpt_path: last.ckpt 35 | 36 | seed: 42 # seed for random number generators in pytorch, numpy and python.random 37 | 38 | lr: 1e-3 # learning rate 39 | monitor: ??? # name of the logged metric which determines when model is improving. Used by scheduler (plateau), checkpointer, and early stopping 40 | mode: ??? # "max" means higher metric value is better, can be also "min". 
Used by scheduler (plateau), checkpointer, and early stopping 41 | patience: 30 # how many validation epochs of not improving until training stops 42 | 43 | print_config: True # pretty print config at the start of the run using Rich library 44 | ignore_warnings: True # disable python warnings if they annoy you 45 | seed: 42 # seed for random number generators in pytorch, numpy and python.random -------------------------------------------------------------------------------- /src/byprot/models/seq2seq/modules/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | 4 | import torch 5 | import torch.functional as F 6 | import torch.nn as nn 7 | 8 | 9 | def _get_clones(module, N): 10 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 11 | 12 | 13 | class ResNorm(nn.Module): 14 | def __init__( 15 | self, 16 | net, 17 | dim, 18 | alpha=1.0, 19 | dropout=0.0, 20 | normalize_before=True, 21 | norm=None, 22 | fn=None 23 | ): 24 | super().__init__() 25 | 26 | self.net = net 27 | self.fn = partial( 28 | fn or (lambda net, *args, **kwargs: net(*args, **kwargs)), self.net 29 | ) 30 | self.normalize_before = normalize_before 31 | self.norm = norm or nn.LayerNorm(dim) 32 | self.dropout = nn.Dropout(dropout) 33 | self.alpha = alpha 34 | 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | self.net.reset_parameters() 39 | 40 | def forward(self, x, *args, **kwargs): 41 | identity = x 42 | 43 | if self.normalize_before: 44 | x = self.norm(x) 45 | 46 | x = self.fn(x, *args, **kwargs) 47 | x = [x] if not isinstance(x, tuple) else list(x) 48 | 49 | x[0] = self.alpha * identity + self.dropout(x[0]) 50 | 51 | if not self.normalize_before: 52 | x[0] = self.norm(x[0]) 53 | 54 | return x[0] if len(x) == 1 else tuple(x) 55 | 56 | def extra_repr(self): 57 | lines = f"(normalize_before): {self.normalize_before}" 58 | return lines 59 | 60 | 61 | def apply_weight_norm(net: nn.Module): 62 | for name, module in net.named_modules(): 63 | if isinstance(module, nn.Linear): 64 | nn.utils.weight_norm(module, name='weight') 65 | 66 | 67 | class RepeatedSequential(nn.Sequential): 68 | def forward(self, x, *args, **kwds): 69 | for m in self: 70 | x = m(x, *args, **kwds) 71 | return x 72 | -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_150m_gvptrans_cath4.3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.3 # note that gvp-transformer were trained on CATH 4.3, 8 | # so it could be at the risk of data leakage if evaluating on cath 4.2 9 | - /callbacks: fixedbb 10 | - /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | # name of the run determines folder name in logs 16 | name: "fixedbb/cath_4.3/lm_design_esm1b_150m_gvptransf" 17 | 18 | datamodule: 19 | alphabet: 20 | name: esm 21 | featurizer: cath 22 | featurizer_cfg: 23 | coord_nan_to_zero: false 24 | 25 | model: 26 | _target_: esm2_adapter_gvptrans 27 | encoder: 28 | d_model: 512 29 | 30 | name: esm2_t30_150M_UR50D 31 | adapter_layer_indices: [-1, ] 32 | separate_loss: true 33 | 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | 
learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 1 43 | optimizer: 44 | type: adamw 45 | _partial_: true 46 | lr: ${train.lr} 47 | betas: 48 | - 0.9 49 | - 0.98 50 | weight_decay: 0.0001 51 | lr_scheduler: 52 | type: noam 53 | warmup_steps: 4000 54 | model_size: 128 55 | lr: ${train.lr} 56 | warmup_init_lr: 1e-07 57 | generator: 58 | max_iter: 5 59 | strategy: 'denoise' 60 | 61 | train: 62 | seed: 42 63 | lr: 0.001 64 | monitor: "val/acc_median" 65 | mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 -------------------------------------------------------------------------------- /configs/experiment/fixedbb/lm_design_esm2_650m_gvptrans_cath4.3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - /datamodule: cath_4.3 # note that gvp-transformer were trained on CATH 4.3, 8 | # so it could be at the risk of data leakage if evaluating on cath 4.2 9 | - /callbacks: fixedbb 10 | - /trainer: default 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | # name of the run determines folder name in logs 16 | name: "fixedbb/cath_4.3/lm_design_esm1b_150m_gvptransf" 17 | 18 | datamodule: 19 | alphabet: 20 | name: esm 21 | featurizer: cath 22 | featurizer_cfg: 23 | coord_nan_to_zero: false 24 | 25 | model: 26 | _target_: esm2_adapter_gvptrans 27 | encoder: 28 | d_model: 512 29 | 30 | name: esm2_t33_650M_UR50D 31 | adapter_layer_indices: [-1, ] 32 | separate_loss: true 33 | 34 | task: 35 | _target_: fixedbb/cmlm 36 | alphabet: ${datamodule.alphabet} 37 | learning: 38 | noise: random_mask # enable cmlm training with uniform random masking 39 | criterion: 40 | _target_: byprot.modules.cross_entropy.Coord2SeqCrossEntropyLoss 41 | label_smoothing: 0.0 42 | ignore_index: 1 43 | optimizer: 44 | type: adamw 45 | _partial_: true 46 | lr: ${train.lr} 47 | betas: 48 | - 0.9 49 | - 0.98 50 | weight_decay: 0.0001 51 | lr_scheduler: 52 | type: noam 53 | warmup_steps: 4000 54 | model_size: 128 55 | lr: ${train.lr} 56 | warmup_init_lr: 1e-07 57 | generator: 58 | max_iter: 5 59 | strategy: 'denoise' 60 | 61 | train: 62 | seed: 42 63 | lr: 0.001 64 | monitor: "val/acc_median" 65 | mode: "max" 66 | 67 | trainer: 68 | min_epochs: 10 69 | max_epochs: 10000 70 | gradient_clip_val: 0.0 71 | # val_check_interval: 10 72 | num_sanity_val_steps: 1 73 | reload_dataloaders_every_n_epochs: 1 74 | replace_sampler_ddp: false 75 | max_steps: 200_000 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!python 2 | 3 | import pyrootutils 4 | 5 | # ------------------------------------------------------------------------------------ # 6 | # `pyrootutils.setup_root(...)` is recommended at the top of each start file 7 | # to make the environment more robust and consistent 8 | # 9 | # the line above searches for ".git" or "pyproject.toml" in present and parent dirs 10 | # 
to determine the project root dir 11 | # 12 | # adds root dir to the PYTHONPATH (if `pythonpath=True`) 13 | # so this file can be run from any place without installing project as a package 14 | # 15 | # sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" 16 | # this makes all paths relative to the project root 17 | # 18 | # additionally loads environment variables from ".env" file (if `dotenv=True`) 19 | # 20 | # you can get away without using `pyrootutils.setup_root(...)` if you: 21 | # 1. move this file to the project root dir or install project as a package 22 | # 2. modify paths in "configs/paths/default.yaml" to not use PROJECT_ROOT 23 | # 3. always run this file from the project root dir 24 | # 25 | # https://github.com/ashleve/pyrootutils 26 | # ------------------- 27 | 28 | root = pyrootutils.setup_root( 29 | search_from=__file__, 30 | indicator=[".git", "pyproject.toml"], 31 | pythonpath=True, 32 | # load environment variables from `.env` file if it exists 33 | # recursively searches for `.env` in all folders starting from work dir 34 | dotenv=True, 35 | ) 36 | 37 | 38 | import hydra 39 | from omegaconf import DictConfig 40 | 41 | 42 | @hydra.main(version_base='1.1', config_path=f"{root}/configs", config_name="config.yaml") 43 | def main(config: DictConfig): 44 | 45 | # Imports can be nested inside @hydra.main to optimize tab completion 46 | # https://github.com/facebookresearch/hydra/issues/934 47 | from byprot import utils 48 | from byprot.training_pipeline import train 49 | 50 | # Applies optional utilities 51 | config = utils.extras(config) 52 | 53 | # Train model 54 | return train(config) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/make_bias_per_res_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import glob 5 | import random 6 | import numpy as np 7 | import json 8 | 9 | mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' 10 | 11 | mpnn_alphabet_dict = {'A': 0,'C': 1,'D': 2,'E': 3,'F': 4,'G': 5,'H': 6,'I': 7,'K': 8,'L': 9,'M': 10,'N': 11,'P': 12,'Q': 13,'R': 14,'S': 15,'T': 16,'V': 17,'W': 18,'Y': 19,'X': 20} 12 | 13 | with open(args.input_path, 'r') as json_file: 14 | json_list = list(json_file) 15 | 16 | my_dict = {} 17 | for json_str in json_list: 18 | result = json.loads(json_str) 19 | all_chain_list = [item[-1:] for item in list(result) if item[:10]=='seq_chain_'] 20 | bias_by_res_dict = {} 21 | for chain in all_chain_list: 22 | chain_length = len(result[f'seq_chain_{chain}']) 23 | bias_per_residue = np.zeros([chain_length, 21]) 24 | 25 | 26 | if chain == 'A': 27 | residues = [0, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15] 28 | amino_acids = [5, 9] #[G, L] 29 | for res in residues: 30 | for aa in amino_acids: 31 | bias_per_residue[res, aa] = 100.5 32 | 33 | if chain == 'C': 34 | residues = [0, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15] 35 | amino_acids = range(21)[1:] #[G, L] 36 | for res in residues: 37 | for aa in amino_acids: 38 | bias_per_residue[res, aa] = -100.5 39 | 40 | bias_by_res_dict[chain] = bias_per_residue.tolist() 41 | my_dict[result['name']] = bias_by_res_dict 42 | 43 | with open(args.output_path, 'w') as f: 44 | f.write(json.dumps(my_dict) + '\n') 45 | 46 | 47 | if __name__ == "__main__": 48 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | 
argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 50 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 51 | 52 | args = argparser.parse_args() 53 | main(args) 54 | -------------------------------------------------------------------------------- /src/byprot/modules/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from functools import partial 5 | import numpy as np 6 | 7 | 8 | def luost_rmsd(res_list1: list, res_list2: list): 9 | res_short, res_long = (res_list1, res_list1) if len(res_list1) < len(res_list2) else (res_list2, res_list1) 10 | M, N = len(res_short), len(res_long) 11 | 12 | def d(i, j): 13 | coord_i = res_short[i] 14 | coord_j = res_long[j] 15 | return ((coord_i - coord_j) ** 2).sum() 16 | 17 | SD = np.full([M, N], np.inf) 18 | for i in range(M): 19 | j = N - (M - i) 20 | SD[i, j] = sum([d(i + k, j + k) for k in range(N - j)]) 21 | 22 | for j in range(N): 23 | SD[M - 1, j] = d(M - 1, j) 24 | 25 | for i in range(M - 2, -1, -1): 26 | for j in range((N - (M - i)) - 1, -1, -1): 27 | SD[i, j] = min( 28 | d(i, j) + SD[i + 1, j + 1], 29 | SD[i, j + 1] 30 | ) 31 | 32 | min_SD = SD[0, :N - M + 1].min() 33 | best_RMSD = np.sqrt(min_SD / M) 34 | return best_RMSD 35 | 36 | 37 | def rmsd(pred, target, mask=None): 38 | assert pred.shape == target.shape 39 | if mask is None: 40 | mask = torch.ones_like(pred, dtype=torch.bool) 41 | 42 | rmsd = [] 43 | for p, t, m in zip(pred, target, mask): 44 | rmsd.append(luost_rmsd(p[m], t[m])) 45 | return np.mean(rmsd) 46 | 47 | 48 | def accuracy(pred, target, mask=None, reduction='all'): 49 | assert pred.shape == target.shape 50 | if mask is None: 51 | mask = torch.ones_like(pred, dtype=torch.bool) 52 | 53 | return (pred[mask] == target[mask]).sum() / mask.sum() 54 | 55 | def accuracy_per_sample(pred, target, mask=None): 56 | assert pred.shape == target.shape 57 | bsz = target.shape[0] 58 | 59 | if mask is None: 60 | mask = torch.ones_like(pred, dtype=torch.bool) 61 | 62 | pred = pred.view(bsz, -1) 63 | target = target.view(bsz, -1) 64 | mask = mask.view(bsz, -1) 65 | 66 | return ((pred == target) * mask).sum(1) / mask.sum(1) 67 | 68 | 69 | from tmtools import tm_align 70 | 71 | def calc_tm_score(pos_1, pos_2, seq_1, seq_2): 72 | tm_results = tm_align(np.float64(pos_1), np.float64(pos_2), seq_1, seq_2) 73 | return tm_results.tm_norm_chain1, tm_results.tm_norm_chain2 74 | 75 | -------------------------------------------------------------------------------- /src/byprot/models/seq2seq/modules/embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Embedding(nn.Embedding): 8 | def __init__(self, vocab_size, d_model, padding_idx=None): 9 | super().__init__(vocab_size, d_model, padding_idx=padding_idx) 10 | self.vocab_size = vocab_size 11 | self.d_model = d_model 12 | self.padding_idx = padding_idx 13 | 14 | nn.init.normal_(self.weight, mean=0, std=self.d_model ** -0.5) 15 | nn.init.constant_(self.weight[padding_idx], 0) 16 | 17 | 18 | class PositionEmbedding(nn.Module): 19 | "Implement the PE function." 20 | 21 | def __init__(self, d_model, dropout=0, max_len=512): 22 | super().__init__() 23 | self.dropout = nn.Dropout(p=dropout) 24 | self.d_model = d_model 25 | self.scaling = self.d_model ** 0.5 26 | 27 | # Compute the positional encodings once in log space. 
28 | pe = torch.zeros(max_len, d_model) 29 | position = torch.arange(0, max_len).unsqueeze(1) 30 | div_term = torch.exp( 31 | torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) 32 | ) 33 | pe[:, 0::2] = torch.sin(position * div_term) 34 | pe[:, 1::2] = torch.cos(position * div_term) 35 | # [1, max_len, d_model] 36 | pe = pe.unsqueeze(0) 37 | self.register_buffer("pe", pe) 38 | 39 | def reset_parameters(self): 40 | pass 41 | 42 | def get_as(self, x): 43 | # x: [bsz, len, d_model] 44 | return self.pe[:, :x.size(1)].requires_grad_(False) 45 | 46 | def forward(self, x): 47 | x = x * self.scaling + self.get_as(x) 48 | return self.dropout(x) 49 | 50 | 51 | class LearnedPositionEmbedding(nn.Module): 52 | def __init__(self, d_model, dropout, max_len=512): 53 | super().__init__() 54 | self.dropout = nn.Dropout(p=dropout) 55 | self.pe = nn.Embedding(max_len, d_model) 56 | 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | self.pe.reset_parameters() 61 | 62 | def get_as(self, x): 63 | return self.pe(x) 64 | 65 | def forward(self, x): 66 | x = x + self.get_as(x) 67 | return self.dropout(x) 68 | 69 | 70 | registry = { 71 | 'default': PositionEmbedding, 72 | 'learned': LearnedPositionEmbedding 73 | } 74 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/other_tools/make_omit_AA.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import numpy as np 4 | import json 5 | import itertools 6 | 7 | #MODIFY this path 8 | with open('/home/justas/projects/lab_github/mpnn/data/pdbs.jsonl', 'r') as json_file: 9 | json_list = list(json_file) 10 | 11 | my_dict = {} 12 | for json_str in json_list: 13 | result = json.loads(json_str) 14 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 15 | fixed_position_dict = {} 16 | print(result['name']) 17 | if result['name'] == '5TTA': 18 | for chain in all_chain_list: 19 | if chain == 'A': 20 | fixed_position_dict[chain] = [ 21 | [[int(item) for item in list(itertools.chain(list(np.arange(1,4)), list(np.arange(7,10)), [22, 25, 33]))], 'GPL'], 22 | [[int(item) for item in list(itertools.chain([40, 41, 42, 43]))], 'WC'], 23 | [[int(item) for item in list(itertools.chain(list(np.arange(50,150))))], 'ACEFGHIKLMNRSTVWYX'], 24 | [[int(item) for item in list(itertools.chain(list(np.arange(160,200))))], 'FGHIKLPQDMNRSTVWYX']] 25 | else: 26 | fixed_position_dict[chain] = [] 27 | else: 28 | for chain in all_chain_list: 29 | fixed_position_dict[chain] = [] 30 | my_dict[result['name']] = fixed_position_dict 31 | 32 | #MODIFY this path 33 | with open('/home/justas/projects/lab_github/mpnn/data/omit_AA.jsonl', 'w') as f: 34 | f.write(json.dumps(my_dict) + '\n') 35 | 36 | 37 | print('Finished') 38 | #e.g. 
output 39 | #{"5TTA": {"A": [[[1, 2, 3, 7, 8, 9, 22, 25, 33], "GPL"], [[40, 41, 42, 43], "WC"], [[50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149], "ACEFGHIKLMNRSTVWYX"], [[160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199], "FGHIKLPQDMNRSTVWYX"]], "B": []}, "3LIS": {"A": [], "B": []}} 40 | -------------------------------------------------------------------------------- /src/byprot/testing_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from byprot.tasks import on_prediction_mode 4 | 5 | from torch import nn 6 | import hydra 7 | from omegaconf import DictConfig 8 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything 9 | from pytorch_lightning.loggers import LightningLoggerBase 10 | 11 | from byprot import utils 12 | 13 | log = utils.get_logger(__name__) 14 | 15 | 16 | def test(config: DictConfig) -> None: 17 | """Contains minimal example of the testing/prediction pipeline. Evaluates given checkpoint on a testset. 18 | 19 | Args: 20 | config (DictConfig): Configuration composed by Hydra. 21 | 22 | Returns: 23 | None 24 | """ 25 | 26 | # Set seed for random number generators in pytorch, numpy and python.random 27 | if config.get("seed"): 28 | seed_everything(config.seed, workers=True) 29 | 30 | # Convert relative ckpt path to absolute path if necessary 31 | if not os.path.isabs(config.ckpt_path): 32 | config.ckpt_path = utils.resolve_ckpt_path(ckpt_dir=config.paths.ckpt_dir, ckpt_path=config.ckpt_path) 33 | 34 | # loading pipeline 35 | datamodule, pl_module, logger, callbacks = utils.common_pipeline(config) 36 | 37 | # Init lightning trainer 38 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 39 | trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger, callbacks=callbacks) 40 | 41 | # Log hyperparameters 42 | if trainer.logger: 43 | trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path}) 44 | 45 | mode = config.mode 46 | 47 | # Start prediction 48 | log.info(f"Starting on mode='{mode}'!") 49 | 50 | # (1) Specify test dataset by configuring datamodule.test_split 51 | data_split = config.get('data_split') or config.datamodule.get('test_split', 'test') 52 | datamodule.hparams.test_split = data_split 53 | log.info(f"Loading test data from '{data_split}' dataset...") 54 | 55 | # Pytorch Lightning treat predict differently compared to what we commonly think of. 56 | # Must use this context manager and trainer.test to run prediction as expected. 
57 | with on_prediction_mode(pl_module, enable=mode == 'predict'): 58 | trainer.test(model=pl_module, datamodule=datamodule, ckpt_path=config.ckpt_path) 59 | 60 | log.info(f"Finished mode='{mode}' on '{data_split}' dataset.") 61 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/other_tools/make_pssm_dict.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import glob 5 | import random 6 | import numpy as np 7 | import json 8 | 9 | 10 | def softmax(x, T): 11 | return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True) 12 | 13 | def parse_pssm(path): 14 | data = pd.read_csv(path, skiprows=2) 15 | floats_list_list = [] 16 | for i in range(data.values.shape[0]): 17 | str1 = data.values[i][0][4:] 18 | floats_list = [] 19 | for item in str1.split(): 20 | floats_list.append(float(item)) 21 | floats_list_list.append(floats_list) 22 | np_lines = np.array(floats_list_list) 23 | return np_lines 24 | 25 | np_lines = parse_pssm('/home/swang523/RLcage/capsid/monomersfordesign/8-16-21/pssm_rainity_final_8-16-21_int/build_0.2089_0.98_0.4653_19_2.00_0.005745.pssm') 26 | 27 | mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' 28 | input_alphabet = 'ARNDCQEGHILKMFPSTWYV' 29 | 30 | permutation_matrix = np.zeros([20,21]) 31 | for i in range(20): 32 | letter1 = input_alphabet[i] 33 | for j in range(21): 34 | letter2 = mpnn_alphabet[j] 35 | if letter1 == letter2: 36 | permutation_matrix[i,j]=1. 37 | 38 | pssm_log_odds = np_lines[:,:20] @ permutation_matrix 39 | pssm_probs = np_lines[:,20:40] @ permutation_matrix 40 | 41 | X_mask = np.concatenate([np.zeros([1,20]), np.ones([1,1])], -1) 42 | 43 | def softmax(x, T): 44 | return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True) 45 | 46 | #Load parsed PDBs: 47 | with open('/home/justas/projects/cages/parsed/test.jsonl', 'r') as json_file: 48 | json_list = list(json_file) 49 | 50 | my_dict = {} 51 | for json_str in json_list: 52 | result = json.loads(json_str) 53 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 54 | pssm_dict = {} 55 | for chain in all_chain_list: 56 | pssm_dict[chain] = {} 57 | pssm_dict[chain]['pssm_coef'] = (np.ones(len(result['seq_chain_A']))).tolist() #a number between 0.0 and 1.0 specifying how much attention put to PSSM, can be adjusted later as a flag 58 | pssm_dict[chain]['pssm_bias'] = (softmax(pssm_log_odds-X_mask*1e8, 1.0)).tolist() #PSSM like, [length, 21] such that sum over the last dimension adds up to 1.0 59 | pssm_dict[chain]['pssm_log_odds'] = (pssm_log_odds).tolist() 60 | my_dict[result['name']] = pssm_dict 61 | 62 | #Write output to: 63 | with open('/home/justas/projects/lab_github/mpnn/data/pssm_dict.jsonl', 'w') as f: 64 | f.write(json.dumps(my_dict) + '\n') 65 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/design_pdb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function rdebug() { 4 | MASTER_ADDR="$(ss | grep 2222 | head -n 1 | tr -s " " | cut -d" " -f 6 | sed -e "s/\[\|\]:[0-9]\+$//g")" 5 | # pgrep ssh$ | xargs kill -9 6 | # echo $MASTER_ADDR 7 | ssh -R 5678:localhost:5678 -N -f root@${MASTER_ADDR} -p 9000 > /dev/null 2>&1 8 | 9 | python3 -m debugpy --listen localhost:5678 "$@" 10 | } 11 | 12 | 13 | DIR="$(dirname "$0")" 14 | cd $DIR 15 | # 
model_path="/root/research/projects/ByProt/run/logs/fixedbb/cath_4.2/lm_design_esm2_650m" 16 | # model_path="/root/research/projects/ByProt/run/logs/fixedbb/cath_4.3/lm_design_esm2_650m_gvptrans" 17 | model_path="/root/research/projects/others/ByProt_public/logs/fixedbb_multichain/lm_design_esm2_650m" 18 | 19 | 20 | folder_with_pdbs="./inputs/PDB_complexes/pdbs/" 21 | 22 | output_dir="./outputs/example_5_outputs" 23 | if [ ! -d $output_dir ] 24 | then 25 | mkdir -p $output_dir 26 | fi 27 | 28 | 29 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 30 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl" 31 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl" 32 | path_for_tied_positions=$output_dir"/tied_pdbs.jsonl" 33 | chains_to_design="A C" 34 | fixed_positions="9 10 11 12 13 14 15 16 17 18 19 20 21 22 23, 10 11 18 19 20 22" 35 | tied_positions="1 2 3 4 5 6 7 8, 1 2 3 4 5 6 7 8" #two list must match in length; residue 1 in chain A and C will be sampled togther; 36 | 37 | python ./helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 38 | 39 | python ./helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design" 40 | 41 | python ./helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$fixed_positions" 42 | 43 | python ./helper_scripts/make_tied_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_tied_positions --chain_list "$chains_to_design" --position_list "$tied_positions" 44 | 45 | rdebug ./design_pdb.py \ 46 | --experiment_path $model_path --ckpt "best.ckpt" \ 47 | --jsonl_path $path_for_parsed_chains \ 48 | --chain_id_jsonl $path_for_assigned_chains \ 49 | --fixed_positions_jsonl $path_for_fixed_positions \ 50 | --tied_positions_jsonl $path_for_tied_positions \ 51 | --out_dir $output_dir \ 52 | --seed 42 \ 53 | --num_seqs 2 \ 54 | --temperature 1.0 \ 55 | --max_iter 5 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Lightning-Hydra-Template 143 | configs/local/default.yaml 144 | configs/local/* 145 | !*/data 146 | /data/ 147 | logs/ 148 | wandb/ 149 | .env 150 | .autoenv 151 | workspace.ipynb 152 | run/logs 153 | 154 | # model weight 155 | *.ckpt 156 | 157 | # pdb 158 | *.pdb 159 | !examples/*.pdb -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/make_fixed_positions_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import glob 5 | import random 6 | import numpy as np 7 | import json 8 | import itertools 9 | 10 | with open(args.input_path, 'r') as json_file: 11 | json_list = list(json_file) 12 | 13 | fixed_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")] 14 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 15 | my_dict = {} 16 | 17 | if not args.specify_non_fixed: 18 | for json_str in json_list: 19 | result = json.loads(json_str) 20 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 21 | fixed_position_dict = {} 22 | for i, chain in enumerate(global_designed_chain_list): 23 | fixed_position_dict[chain] = fixed_list[i] 24 | for chain in all_chain_list: 25 | if chain not in global_designed_chain_list: 26 | fixed_position_dict[chain] = [] 27 | my_dict[result['name']] = fixed_position_dict 28 | else: 29 | for json_str in json_list: 30 
| result = json.loads(json_str) 31 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 32 | fixed_position_dict = {} 33 | for chain in all_chain_list: 34 | seq_length = len(result[f'seq_chain_{chain}']) 35 | all_residue_list = (np.arange(seq_length)+1).tolist() 36 | if chain not in global_designed_chain_list: 37 | fixed_position_dict[chain] = all_residue_list 38 | else: 39 | idx = np.argwhere(np.array(global_designed_chain_list) == chain)[0][0] 40 | fixed_position_dict[chain] = list(set(all_residue_list)-set(fixed_list[idx])) 41 | my_dict[result['name']] = fixed_position_dict 42 | 43 | with open(args.output_path, 'w') as f: 44 | f.write(json.dumps(my_dict) + '\n') 45 | 46 | #e.g. output 47 | #{"5TTA": {"A": [1, 2, 3, 7, 8, 9, 22, 25, 33], "B": []}, "3LIS": {"A": [], "B": []}} 48 | 49 | if __name__ == "__main__": 50 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 52 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 53 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed") 54 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain") 55 | argparser.add_argument("--specify_non_fixed", action="store_true", default=False, help="Allows specifying just residues that need to be designed (default: false)") 56 | 57 | args = argparser.parse_args() 58 | main(args) 59 | 60 | -------------------------------------------------------------------------------- /src/byprot/models/seq2seq/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | """ Transformer encoder """ 2 | 3 | import copy 4 | import torch 5 | from torch import nn, Tensor 6 | from torch.nn import functional as F 7 | 8 | from .modules.multihead_attention import MHA 9 | from .modules.ffn import FFN 10 | from .modules.utils import ResNorm, _get_clones 11 | 12 | 13 | class TransformerEncoderLayer(nn.Module): 14 | def __init__( 15 | self, 16 | d_model, 17 | n_heads, 18 | d_inner=2048, 19 | dropout=0.1, 20 | attn_dropout=0., 21 | normalize_before=False, 22 | ): 23 | super().__init__() 24 | 25 | self.self_attn = ResNorm( 26 | net=MHA(embed_dim=d_model, num_heads=n_heads, dropout=attn_dropout), 27 | fn=lambda net, x, *args, **kwargs: net(query=x, key=x, value=x, *args, **kwargs), 28 | dim=d_model, normalize_before=normalize_before, dropout=dropout 29 | ) 30 | 31 | self.ffn = ResNorm( 32 | net=FFN(d_model=d_model, d_inner=d_inner, dropout=dropout), 33 | dim=d_model, normalize_before=normalize_before, dropout=dropout 34 | ) 35 | 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | self.self_attn.reset_parameters() 40 | self.ffn.reset_parameters() 41 | 42 | def forward(self, x: Tensor, self_padding_mask: Tensor = None, attn_mask: Tensor = None): 43 | x, *others = self.self_attn(x, key_padding_mask=self_padding_mask, attn_mask=attn_mask) 44 | x = self.ffn(x) 45 | return x 46 | 47 | 48 | class TransformerEncoder(nn.Module): 49 | def __init__( 50 | self, 51 | n_layers, 52 | d_model, 53 | n_heads, 54 | d_inner=2048, 55 | dropout=0.1, 56 | attn_dropout=0., 57 | normalize_before=False, 58 | layer=None, 59 | ): 60 | super().__init__() 61 | 62 | if layer is None: 63 | layer = TransformerEncoderLayer( 64 | d_model, n_heads, d_inner, dropout, 
attn_dropout, normalize_before 65 | ) 66 | self.layers = _get_clones(layer, N=n_layers) 67 | 68 | self.norm = None 69 | if normalize_before: 70 | self.norm = nn.LayerNorm(d_model) 71 | 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | for layer in self.layers: layer.reset_parameters() 76 | 77 | def forward(self, x, padding_mask): 78 | out = x 79 | for layer in self.layers: 80 | out = layer(out, self_padding_mask=padding_mask) 81 | 82 | if self.norm is not None: 83 | out = self.norm(out) 84 | return out 85 | 86 | 87 | if __name__ == '__main__': 88 | 89 | B, L, D = 10, 7, 32 90 | encoder = TransformerEncoder(n_layers=6, d_model=D, n_heads=4, d_inner=2*D, normalize_before=True) 91 | 92 | x = torch.randn(B, L, D) 93 | 94 | key_padding_mask = ~(torch.arange(L)[None] < torch.randint(1, L, (B,))[:, None]) 95 | 96 | # attn_mask = torch.triu(torch.ones(10, 7, 7, dtype=torch.bool), 1) 97 | x = encoder(x, padding_mask=key_padding_mask) 98 | 99 | print(x) -------------------------------------------------------------------------------- /src/byprot/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import LambdaLR 3 | import torch 4 | 5 | 6 | def get_scheduler(cfg, optimizer): 7 | if cfg.type is None: 8 | return BlackHole() 9 | elif cfg.type == 'plateau': 10 | return ( 11 | torch.optim.lr_scheduler.ReduceLROnPlateau( 12 | optimizer, 13 | mode=cfg.mode, 14 | factor=cfg.factor, 15 | patience=cfg.patience, 16 | min_lr=cfg.min_lr, 17 | ), 18 | {'monitor': "val/loss", 'interval': 'epoch'} 19 | ) 20 | elif cfg.type == 'noam': 21 | return ( 22 | NoamScheduler( 23 | optimizer, 24 | lr=cfg.lr, 25 | warmup_steps=cfg.warmup_steps, 26 | model_size=cfg.model_size, 27 | warmup_init_lr=cfg.get('warmup_init_lr') 28 | ), 29 | {'frequency': 1, 'interval': 'step'} 30 | ) 31 | elif cfg.type == 'multistep': 32 | return torch.optim.lr_scheduler.MultiStepLR( 33 | optimizer, 34 | milestones=cfg.milestones, 35 | gamma=cfg.gamma, 36 | ) 37 | elif cfg.type == 'exp': 38 | return torch.optim.lr_scheduler.ExponentialLR( 39 | optimizer, 40 | gamma=cfg.gamma, 41 | ) 42 | elif cfg.type is None: 43 | return BlackHole() 44 | else: 45 | raise NotImplementedError('Scheduler not supported: %s' % cfg.type) 46 | 47 | 48 | class BlackHole(object): 49 | def __setattr__(self, name, value): 50 | pass 51 | 52 | def __call__(self, *args, **kwargs): 53 | return self 54 | 55 | def __getattr__(self, name): 56 | return self 57 | 58 | 59 | def inverse_sqrt_lr_schedule(step, warmup_steps, warmup_init_lr, lr_step, decay_step): 60 | if step == 0: 61 | step = 1 62 | if step < warmup_steps: 63 | return warmup_init_lr + lr_step * step 64 | else: 65 | return decay_step * step ** -0.5 66 | 67 | 68 | class InverseSqrtLRScheduler(LambdaLR): 69 | def __init__( 70 | self, 71 | optimizer: Optimizer, 72 | warmup_steps: int = 0, 73 | lr: float = 5e-04, 74 | warmup_init_lr: float = 1e-07, 75 | ) -> None: 76 | 77 | self.warmup_init_lr = warmup_init_lr 78 | self.warmup_steps = warmup_steps 79 | self.lr_step = (lr - warmup_init_lr) / warmup_steps 80 | self.decay_step = lr * warmup_steps ** 0.5 81 | 82 | def lr_lambda(step): 83 | return inverse_sqrt_lr_schedule( 84 | step, warmup_steps, warmup_init_lr, self.lr_step, self.decay_step 85 | ) / lr 86 | 87 | super().__init__(optimizer, lr_lambda=lr_lambda) 88 | 89 | 90 | def noam_lr_schedule(step, warmup_steps, factor, model_size): 91 | if step == 0: 92 | step = 1 93 | return factor 
* (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))) 94 | 95 | 96 | class NoamScheduler(LambdaLR): 97 | def __init__( 98 | self, 99 | optimizer: Optimizer, 100 | lr, 101 | warmup_init_lr, 102 | model_size: int = 128, 103 | warmup_steps: int = 0, 104 | factor: int = 2, 105 | ) -> None: 106 | 107 | # dummy_lr = self.base_lrs[0] 108 | def lr_lambda(step): 109 | return noam_lr_schedule(step, warmup_steps, factor, model_size) / lr 110 | 111 | super().__init__(optimizer, lr_lambda=lr_lambda) 112 | -------------------------------------------------------------------------------- /src/byprot/utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.adamw import adamw 3 | 4 | 5 | def get_optimizer(cfg, params): 6 | if cfg.type == 'adam': 7 | return torch.optim.Adam( 8 | params=params, 9 | lr=cfg.lr, 10 | weight_decay=cfg.weight_decay, 11 | betas=(cfg.beta1, cfg.beta2, ) 12 | ) 13 | elif cfg.type == 'adamw': 14 | return AdamW( 15 | params=params, 16 | lr=cfg.lr, 17 | weight_decay=cfg.weight_decay, 18 | betas=cfg.betas, 19 | ) 20 | else: 21 | raise NotImplementedError('Optimizer not supported: %s' % cfg.type) 22 | 23 | 24 | class AdamW(torch.optim.AdamW): 25 | @torch.no_grad() 26 | def step(self, closure=None): 27 | """Performs a single optimization step. 28 | 29 | Args: 30 | closure (callable, optional): A closure that reevaluates the model 31 | and returns the loss. 32 | """ 33 | self._cuda_graph_capture_health_check() 34 | 35 | loss = None 36 | if closure is not None: 37 | with torch.enable_grad(): 38 | loss = closure() 39 | 40 | for group in self.param_groups: 41 | params_with_grad = [] 42 | grads = [] 43 | exp_avgs = [] 44 | exp_avg_sqs = [] 45 | max_exp_avg_sqs = [] 46 | state_steps = [] 47 | amsgrad = group['amsgrad'] 48 | beta1, beta2 = group['betas'] 49 | 50 | for p in group['params']: 51 | if p.grad is None: 52 | continue 53 | params_with_grad.append(p) 54 | if p.grad.is_sparse: 55 | raise RuntimeError('AdamW does not support sparse gradients') 56 | grads.append(p.grad) 57 | 58 | state = self.state[p] 59 | 60 | # State initialization 61 | if len(state) == 0: 62 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ 63 | if self.defaults['capturable'] else torch.tensor(0.) 64 | # Exponential moving average of gradient values 65 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 66 | # Exponential moving average of squared gradient values 67 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 68 | if amsgrad: 69 | # Maintains max of all exp. moving avg. of sq. grad. 
values 70 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 71 | 72 | exp_avgs.append(state['exp_avg']) 73 | exp_avg_sqs.append(state['exp_avg_sq']) 74 | 75 | if amsgrad: 76 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 77 | 78 | state_steps.append(state['step'].cpu()) 79 | 80 | adamw(params_with_grad, 81 | grads, 82 | exp_avgs, 83 | exp_avg_sqs, 84 | max_exp_avg_sqs, 85 | state_steps, 86 | amsgrad=amsgrad, 87 | beta1=beta1, 88 | beta2=beta2, 89 | lr=group['lr'], 90 | weight_decay=group['weight_decay'], 91 | eps=group['eps'], 92 | maximize=group['maximize'], 93 | foreach=group['foreach'], 94 | capturable=group['capturable']) 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.7 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.3.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-toml 16 | - id: check-case-conflict 17 | - id: check-added-large-files 18 | args: [--maxkb=1000] 19 | - id: check-json 20 | - id: check-merge-conflict 21 | - id: check-shebang-scripts-are-executable 22 | - id: fix-byte-order-marker 23 | - id: fix-encoding-pragma 24 | args: [--remove] 25 | - id: mixed-line-ending 26 | args: [--fix=lf] 27 | 28 | # python code formatting 29 | - repo: https://github.com/psf/black 30 | rev: 22.6.0 31 | hooks: 32 | - id: black 33 | args: [--line-length, "79"] 34 | language_version: python3.7 35 | 36 | # python import sorting 37 | # - repo: https://github.com/PyCQA/isort 38 | # rev: 5.12.0 39 | # hooks: 40 | # - id: isort 41 | # args: 42 | # [ 43 | # "--line-length=79", 44 | # "--multi-line=3", 45 | # "--profile=black", 46 | # "--filter-files", 47 | # ] 48 | # language_version: python3.7 49 | 50 | # python upgrading syntax to newer version 51 | # - repo: https://github.com/asottile/pyupgrade 52 | # rev: v2.32.1 53 | # hooks: 54 | # - id: pyupgrade 55 | # args: [--py38-plus] 56 | 57 | # python docstring formatting 58 | - repo: https://github.com/myint/docformatter 59 | rev: v1.4 60 | hooks: 61 | - id: docformatter 62 | args: [--in-place, --wrap-summaries=79, --wrap-descriptions=79] 63 | 64 | # python check (PEP8), programming errors and code complexity 65 | - repo: https://github.com/PyCQA/flake8 66 | rev: 4.0.1 67 | hooks: 68 | - id: flake8 69 | args: 70 | [ 71 | "--extend-ignore", 72 | "E203,E402,E501,F401,F841", 73 | "--exclude", 74 | "logs/*,data/*", 75 | ] 76 | 77 | # python security linter 78 | - repo: https://github.com/PyCQA/bandit 79 | rev: "1.7.1" 80 | hooks: 81 | - id: bandit 82 | args: ["-s", "B101"] 83 | 84 | # yaml formatting 85 | - repo: https://github.com/pre-commit/mirrors-prettier 86 | rev: v2.7.1 87 | hooks: 88 | - id: prettier 89 | types: [yaml] 90 | 91 | # shell scripts linter 92 | - repo: https://github.com/shellcheck-py/shellcheck-py 93 | rev: v0.8.0.4 94 | hooks: 95 | - id: shellcheck 96 | 97 | # md formatting 98 | - repo: https://github.com/executablebooks/mdformat 99 | rev: 0.7.14 100 | hooks: 101 | - id: mdformat 102 | args: ["--number"] 103 | additional_dependencies: 104 | - mdformat-gfm 105 | - mdformat-tables 106 | - mdformat_frontmatter 107 | # - mdformat-toc 108 | 
# - mdformat-black 109 | 110 | # jupyter notebook cell output clearing 111 | # - repo: https://github.com/kynan/nbstripout 112 | # rev: 0.5.0 113 | # hooks: 114 | # - id: nbstripout 115 | 116 | # jupyter notebook linting 117 | - repo: https://github.com/nbQA-dev/nbQA 118 | rev: 1.4.0 119 | hooks: 120 | - id: nbqa-black 121 | args: ["--line-length=79"] 122 | - id: nbqa-isort 123 | args: ["--profile=black"] 124 | - id: nbqa-flake8 125 | args: 126 | [ 127 | "--extend-ignore=E203,E402,E501,F401,F841", 128 | "--exclude=logs/*,data/*", 129 | ] 130 | -------------------------------------------------------------------------------- /src/byprot/training_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | from torch import nn 5 | import hydra 6 | from omegaconf import DictConfig 7 | from pytorch_lightning import ( 8 | Callback, 9 | LightningDataModule, 10 | LightningModule, 11 | Trainer, 12 | seed_everything, 13 | 14 | ) 15 | from pytorch_lightning.loggers import LightningLoggerBase 16 | 17 | from byprot import utils 18 | 19 | log = utils.get_logger(__name__) 20 | 21 | 22 | 23 | 24 | def train(config: DictConfig) -> Optional[float]: 25 | """Contains the training pipeline. Can additionally evaluate model on a testset, using best 26 | weights achieved during training. 27 | 28 | Args: 29 | config (DictConfig): Configuration composed by Hydra. 30 | 31 | Returns: 32 | Optional[float]: Metric score for hyperparameter optimization. 33 | """ 34 | 35 | # Set seed for random number generators in pytorch, numpy and python.random 36 | if config.get("seed"): 37 | seed_everything(config.seed, workers=True) 38 | 39 | # Convert relative ckpt path to absolute path if necessary 40 | ckpt_path = not config.train.get("force_restart", False) and config.train.get("ckpt_path") 41 | if ckpt_path: 42 | ckpt_path = utils.resolve_ckpt_path(ckpt_dir=config.paths.ckpt_dir, ckpt_path=ckpt_path) 43 | if os.path.exists(ckpt_path): 44 | log.info(f"Resuming checkpoint from <{ckpt_path}>") 45 | else: 46 | log.info(f"Failed to resume checkpoint from <{ckpt_path}>: file not exists. Skip.") 47 | ckpt_path = None 48 | 49 | # loading pipeline 50 | datamodule, pl_module, logger, callbacks = utils.common_pipeline(config) 51 | 52 | # Init lightning trainer 53 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 54 | trainer: Trainer = hydra.utils.instantiate( 55 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 56 | ) 57 | 58 | # Send some parameters from config to all lightning loggers 59 | log.info("Logging hyperparameters!") 60 | utils.log_hyperparameters( 61 | config=config, 62 | datamodule=datamodule, 63 | # model=model, 64 | model=pl_module, 65 | trainer=trainer, 66 | callbacks=callbacks, 67 | logger=logger, 68 | ) 69 | 70 | # Train the model 71 | if config.get("train"): 72 | log.info("Starting training!") 73 | trainer.fit(model=pl_module, datamodule=datamodule, ckpt_path=ckpt_path) 74 | 75 | # Get metric score for hyperparameter optimization 76 | optimized_metric = config.get("optimized_metric") 77 | if optimized_metric and optimized_metric not in trainer.callback_metrics: 78 | raise Exception( 79 | "Metric for hyperparameter optimization not found! " 80 | "Make sure the `optimized_metric` in `hparams_search` config is correct!" 
81 | ) 82 | score = trainer.callback_metrics.get(optimized_metric) 83 | 84 | # Test the model 85 | if config.get("test"): 86 | log.info("Starting testing!") 87 | best_ckpt_path = os.path.join(config.paths.ckpt_dir, 'best.ckpt') 88 | trainer.test(model=pl_module, datamodule=datamodule, ckpt_path=best_ckpt_path) 89 | 90 | # Make sure everything closed properly 91 | log.info("Finalizing!") 92 | utils.finish( 93 | config=config, 94 | model=pl_module, 95 | datamodule=datamodule, 96 | trainer=trainer, 97 | callbacks=callbacks, 98 | logger=logger, 99 | ) 100 | 101 | # Print path to best checkpoint 102 | if not config.trainer.get("fast_dev_run") and config.get("train"): 103 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}") 104 | 105 | # Return metric score for hyperparameter optimization 106 | return score 107 | -------------------------------------------------------------------------------- /src/byprot/utils/config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from contextlib import contextmanager 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from typing import Any, List, Sequence 7 | import logging 8 | from pytorch_lightning.utilities import rank_zero_only 9 | 10 | import hydra 11 | from omegaconf import DictConfig, OmegaConf 12 | 13 | def get_logger(name=__name__) -> logging.Logger: 14 | """Initializes multi-GPU-friendly python command line logger.""" 15 | 16 | logger = logging.getLogger(name) 17 | 18 | # this ensures all logging levels get marked with the rank zero decorator 19 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 20 | for level in ( 21 | "debug", 22 | "info", 23 | "warning", 24 | "error", 25 | "exception", 26 | "fatal", 27 | "critical", 28 | ): 29 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 30 | 31 | return logger 32 | 33 | 34 | log = get_logger(__name__) 35 | 36 | 37 | def make_config(**kwargs): 38 | return OmegaConf.structured(kwargs) 39 | 40 | 41 | def compose_config(**kwds): 42 | return OmegaConf.create(kwds) 43 | 44 | 45 | def merge_config(default_cfg, override_cfg): 46 | return OmegaConf.merge(default_cfg, override_cfg) 47 | 48 | 49 | def load_yaml_config(fpath: str) -> OmegaConf: 50 | return OmegaConf.load(fpath) 51 | 52 | 53 | def parse_cli_override_args(): 54 | _overrides = OmegaConf.from_cli() 55 | print(_overrides) 56 | overrides = compose_config(**{kk if not kk.startswith('+') else kk[1:]: vv for kk, vv in _overrides.items()}) 57 | return overrides 58 | 59 | 60 | def resolve_experiment_config(config: DictConfig): 61 | # Load train config from existing Hydra experiment 62 | if config.experiment_path is not None: 63 | config.experiment_path = hydra.utils.to_absolute_path(config.experiment_path) 64 | experiment_config = OmegaConf.load(os.path.join(config.experiment_path, '.hydra', 'config.yaml')) 65 | from omegaconf import open_dict 66 | with open_dict(config): 67 | config.datamodule = experiment_config.datamodule 68 | config.model = experiment_config.model 69 | config.task = experiment_config.task 70 | config.train = experiment_config.train 71 | config.paths = experiment_config.paths 72 | config.name = experiment_config.name 73 | config.paths.log_dir = config.experiment_path 74 | 75 | # deal with override args 76 | cli_overrides = parse_cli_override_args() 77 | config = merge_config(config, cli_overrides) 78 | print(cli_overrides) 79 | # chagne current directory 80 | os.chdir(config.paths.log_dir) 
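    # e.g. a hypothetical command-line override such as `+temperature=0.1`:
    # parse_cli_override_args() strips the leading '+' before the merge above, so
    # both plain and '+'-prefixed overrides end up in the returned config.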
81 | return config 82 | 83 | 84 | def _convert_target_to_string(t: Any) -> Any: 85 | if callable(t): 86 | return f"{t.__module__}.{t.__qualname__}" 87 | else: 88 | return t 89 | 90 | 91 | def get_obj_from_str(string, reload=False): 92 | module, cls = string.rsplit(".", 1) 93 | if reload: 94 | module_imp = importlib.import_module(module) 95 | importlib.reload(module_imp) 96 | return getattr(importlib.import_module(module, package=None), cls) 97 | 98 | 99 | def instantiate_from_config(cfg: OmegaConf, group=None, **override_kwargs): 100 | if "_target_" not in cfg: 101 | raise KeyError("Expected key `_target_` to instantiate.") 102 | 103 | if group is None: 104 | return hydra.utils.instantiate(cfg, **override_kwargs) 105 | else: 106 | from . import registry 107 | _target_ = cfg.pop('_target_') 108 | target = registry.get_module(group_name=group, module_name=_target_) 109 | if target is None: 110 | raise KeyError( 111 | f'{_target_} is not a registered <{group}> class [{registry.get_registered_modules(group)}].') 112 | target = _convert_target_to_string(target) 113 | log.info(f" Resolving {group} <{_target_}> -> <{target}>") 114 | 115 | target_cls = get_obj_from_str(target) 116 | try: 117 | return target_cls(**cfg, **override_kwargs) 118 | except: 119 | cfg = merge_config(cfg, override_kwargs) 120 | return target_cls(cfg) -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/pifold/pifold.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from byprot.models import register_model 5 | from byprot.models.fixedbb import FixedBackboneDesignEncoderDecoder 6 | from byprot.datamodules.datasets.data_utils import Alphabet 7 | from byprot.models.fixedbb.generator import new_arange, sample_from_categorical 8 | from torch.nn import functional as F 9 | 10 | from .model import PiFoldModel 11 | 12 | 13 | @dataclass 14 | class PiFoldConfig: 15 | display_step: int = 10 16 | d_model: int = 128 17 | hidden_dim: int = 128 18 | node_features: int = 128 19 | edge_features: int = 128 20 | k_neighbors: int = 30 21 | dropout: float = 0.1 22 | num_encoder_layers: int = 10 23 | updating_edges: int = 4 24 | node_dist: int = 1 25 | node_angle: int = 1 26 | node_direct: int = 1 27 | edge_dist: int = 1 28 | edge_angle: int = 1 29 | edge_direct: int = 1 30 | virtual_num: int = 3 31 | 32 | n_vocab: int = 22 33 | use_esm_alphabet: bool = False 34 | 35 | @register_model('pifold') 36 | class PiFold(FixedBackboneDesignEncoderDecoder): 37 | _default_cfg = PiFoldConfig() 38 | 39 | def __init__(self, cfg) -> None: 40 | super().__init__(cfg) 41 | 42 | if self.cfg.use_esm_alphabet: 43 | alphabet = Alphabet('esm') 44 | self.padding_idx = alphabet.padding_idx 45 | self.mask_idx = alphabet.mask_idx 46 | self.cfg.n_vocab = len(alphabet) 47 | else: 48 | alphabet = None 49 | self.padding_idx = 0 50 | self.mask_idx = 1 51 | 52 | self.model = PiFoldModel(args=self.cfg) 53 | 54 | def forward(self, batch, return_feats=False, **kwargs): 55 | logits, feats = self.model( 56 | X=batch['coords'], 57 | mask=batch['coord_mask'].float(), 58 | S=batch['prev_tokens'], 59 | lengths=batch.get('lengths', None)) 60 | 61 | if return_feats: 62 | return logits, feats 63 | return logits 64 | 65 | def forward_encoder(self, batch): 66 | encoder_out = self.model.encode( 67 | X=batch['coords'], 68 | mask=batch['coord_mask'].float(), 69 | lengths=batch['lengths'] 70 | ) 71 | encoder_out['coord_mask'] = 
batch['coord_mask'].float() 72 | 73 | return encoder_out 74 | 75 | def forward_decoder(self, prev_decoder_out, encoder_out): 76 | output_tokens = prev_decoder_out['output_tokens'] 77 | output_scores = prev_decoder_out['output_scores'] 78 | step, max_step = prev_decoder_out['step'], prev_decoder_out['max_step'] 79 | temperature = prev_decoder_out['temperature'] 80 | history = prev_decoder_out['history'] 81 | 82 | output_masks = output_tokens.eq(self.mask_idx) # & coord_mask 83 | 84 | logits, _ = self.model.decode( 85 | prev_tokens=output_tokens, 86 | encoder_out=encoder_out, 87 | ) 88 | _tokens, _scores = sample_from_categorical(logits, temperature=temperature) 89 | 90 | output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) 91 | output_scores.masked_scatter_(output_masks, _scores[output_masks]) 92 | 93 | history.append(output_tokens.clone()) 94 | 95 | return dict( 96 | output_tokens=output_tokens, 97 | output_scores=output_scores, 98 | step=step + 1, 99 | max_step=max_step, 100 | history=history 101 | ) 102 | 103 | def initialize_output_tokens(self, batch, encoder_out): 104 | # mask = encoder_out.get('coord_mask', None) 105 | 106 | prev_tokens = batch['prev_tokens'] 107 | lengths = prev_tokens.ne(self.padding_idx).sum(1) 108 | 109 | initial_output_tokens = torch.full_like(prev_tokens, self.padding_idx) 110 | initial_output_tokens.masked_fill_(new_arange(prev_tokens) < lengths[:, None], self.mask_idx) 111 | 112 | # if mask is not None: 113 | # initial_output_tokens = torch.where( 114 | # ~mask, prev_tokens, initial_output_tokens 115 | # ) 116 | # initial_output_tokens = prev_tokens.clone() 117 | 118 | initial_output_scores = torch.zeros( 119 | *initial_output_tokens.size(), device=initial_output_tokens.device 120 | ) 121 | 122 | return initial_output_tokens, initial_output_scores 123 | -------------------------------------------------------------------------------- /scripts/design_pdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import deepcopy 3 | import glob 4 | import logging 5 | import os 6 | import random 7 | from pathlib import Path 8 | from pprint import pprint 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | from torch.cuda.amp import autocast 14 | from omegaconf import DictConfig, OmegaConf 15 | import pytorch_lightning as pl 16 | 17 | from byprot import utils 18 | from byprot.datamodules.datasets import DataProcessor as PDBDataProcessor 19 | from byprot.models.fixedbb.generator import IterativeRefinementGenerator 20 | from byprot.utils import io 21 | from byprot.utils.config import compose_config as Cfg 22 | from tqdm.auto import tqdm 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | logger.setLevel(logging.DEBUG) 27 | 28 | 29 | from collections import namedtuple 30 | 31 | GenOut = namedtuple( 32 | 'GenOut', 33 | ['output_tokens', 'output_scores', 'attentions'] 34 | ) 35 | 36 | def setup_generation(args, ckpt): 37 | pl.seed_everything(args.seed) 38 | 39 | pl_module, exp_cfg = utils.load_from_experiment( 40 | args.experiment_path, ckpt=ckpt) 41 | model = pl_module.model 42 | alphabet = pl_module.alphabet 43 | collater = alphabet.featurize 44 | generator = IterativeRefinementGenerator( 45 | alphabet=alphabet, 46 | max_iter=args.max_iter, 47 | strategy=args.strategy, 48 | temperature=args.temperature 49 | ) 50 | return model.eval(), alphabet, collater, generator 51 | 52 | 53 | def _full_mask(target_tokens, coord_mask, alphabet): 54 | target_mask = ( 55 | 
target_tokens.ne(alphabet.padding_idx) # & mask 56 | & target_tokens.ne(alphabet.cls_idx) 57 | & target_tokens.ne(alphabet.eos_idx) 58 | ) 59 | _tokens = target_tokens.masked_fill( 60 | target_mask, alphabet.mask_idx 61 | ) 62 | _mask = _tokens.eq(alphabet.mask_idx) & coord_mask 63 | return _tokens, _mask 64 | 65 | 66 | def prepare_data(pdb_path, alphabet, collator, num_seqs, device): 67 | pdb_id = Path(pdb_path).stem 68 | structure = PDBDataProcessor().parse_PDB(pdb_path) 69 | batch = collator( 70 | [ 71 | deepcopy(structure) for idx in range(num_seqs) 72 | ] 73 | ) 74 | prev_tokens, prev_token_mask = _full_mask( 75 | batch['tokens'], batch['coord_mask'], alphabet 76 | ) 77 | batch['prev_tokens'] = prev_tokens 78 | batch['prev_token_mask'] = prev_tokens.eq(alphabet.mask_idx) 79 | batch = utils.recursive_to(batch, device=device) 80 | return batch, structure['seq'] 81 | 82 | 83 | def generate(args): 84 | model, alphabet, collater, generator = setup_generation(args, args.ckpt) 85 | model = model.cuda(); 86 | device = next(model.parameters()).device 87 | 88 | Path(args.out_dir).mkdir(parents=True, exist_ok=True) 89 | 90 | st = time.time() 91 | pbar = tqdm(glob.glob(f"{args.pdb_dir}/*.pdb")) 92 | for pdb_path in pbar: 93 | pdb_id = Path(pdb_path).stem 94 | fp_saveto_fasta = open(os.path.join(args.out_dir, f"{pdb_id}.fasta"), 'w') 95 | pbar.set_description_str(f"{pdb_id}") 96 | 97 | batch, native_seq = prepare_data( 98 | pdb_path, alphabet, collater, 99 | num_seqs=args.num_seqs, device=device 100 | ) 101 | 102 | with autocast(): 103 | outputs = generator.generate(model=model, batch=batch) 104 | output_tokens = outputs[0] 105 | 106 | # print('final:') 107 | # pprint(alphabet.decode(output_tokens, remove_special=False)) 108 | 109 | recs = [] 110 | for idx, seq in enumerate( 111 | alphabet.decode(output_tokens, remove_special=True) 112 | ): 113 | rec = np.mean([(a==b) for a, b in zip(native_seq, seq)]) 114 | fp_saveto_fasta.write( 115 | f">{pdb_id}: seq_{idx}, recovery={rec}\n") 116 | fp_saveto_fasta.write(f"{seq}\n") 117 | recs.append(rec) 118 | fp_saveto_fasta.close() 119 | print(f"Eta: {time.time() - st}. AAR: {np.mean(recs)}") 120 | 121 | def main(): 122 | parser = argparse.ArgumentParser() 123 | 124 | parser.add_argument('--seed', type=int, default=42) 125 | parser.add_argument('--num_seqs', type=int, default=20) 126 | parser.add_argument('--experiment_path', type=str) 127 | parser.add_argument('--ckpt', type=str, default='best.ckpt') 128 | parser.add_argument('--pdb_dir', type=str, default='./pdbs') 129 | parser.add_argument('--out_dir', type=str, default='./outputs') 130 | parser.add_argument('--temperature', type=float, default=1.0) 131 | parser.add_argument('--strategy', type=str, default='denoise') 132 | parser.add_argument('--max_iter', type=int, default=5) 133 | args = parser.parse_args() 134 | pprint(args) 135 | 136 | generate(args) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /src/byprot/utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
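# Added note (annotation, not part of the original file): structure-I/O helpers --
# load backbone atoms (N, CA, C and optionally O) from PDB/mmCIF files via biotite
# and extract per-residue coordinates together with the one-letter amino-acid sequence.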
5 | 6 | import json 7 | import math 8 | 9 | import biotite.structure 10 | from biotite.structure.io import pdbx, pdb 11 | from biotite.structure.residues import get_residues 12 | from biotite.structure import filter_backbone 13 | from biotite.structure import get_chains 14 | from biotite.sequence import ProteinSequence 15 | import numpy as np 16 | from scipy.spatial import transform 17 | from scipy.stats import special_ortho_group 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.utils.data as data 22 | from typing import Sequence, Tuple, List 23 | 24 | 25 | from biotite.structure import filter_amino_acids 26 | 27 | 28 | def filter_backbone2(array): 29 | """ 30 | Filter all peptide backbone atoms of one array. 31 | 32 | This includes the "N", "CA" and "C" atoms of amino acids. 33 | 34 | Parameters 35 | ---------- 36 | array : AtomArray or AtomArrayStack 37 | The array to be filtered. 38 | 39 | Returns 40 | ------- 41 | filter : ndarray, dtype=bool 42 | This array is `True` for all indices in `array`, where the atom 43 | as an backbone atom. 44 | """ 45 | return ( ((array.atom_name == "N") | 46 | (array.atom_name == "CA") | 47 | (array.atom_name == "C") | 48 | (array.atom_name == "O")) & 49 | filter_amino_acids(array) ) 50 | 51 | def load_structure(fpath, chain=None): 52 | """ 53 | Args: 54 | fpath: filepath to either pdb or cif file 55 | chain: the chain id or list of chain ids to load 56 | Returns: 57 | biotite.structure.AtomArray 58 | """ 59 | if fpath.endswith('cif'): 60 | with open(fpath) as fin: 61 | pdbxf = pdbx.PDBxFile.read(fin) 62 | structure = pdbx.get_structure(pdbxf, model=1) 63 | elif fpath.endswith('pdb'): 64 | with open(fpath) as fin: 65 | pdbf = pdb.PDBFile.read(fin) 66 | structure = pdb.get_structure(pdbf, model=1) 67 | # bbmask = filter_backbone(structure) 68 | bbmask = filter_backbone2(structure) 69 | structure = structure[bbmask] 70 | all_chains = get_chains(structure) 71 | if len(all_chains) == 0: 72 | raise ValueError('No chains found in the input file.') 73 | if chain is None: 74 | chain_ids = all_chains 75 | elif isinstance(chain, list): 76 | chain_ids = chain 77 | else: 78 | chain_ids = [chain] 79 | for chain in chain_ids: 80 | if chain not in all_chains: 81 | raise ValueError(f'Chain {chain} not found in input file') 82 | chain_filter = [a.chain_id in chain_ids for a in structure] 83 | structure = structure[chain_filter] 84 | return structure 85 | 86 | 87 | def extract_coords_from_structure(structure: biotite.structure.AtomArray, atoms=["N", "CA", "C"]): 88 | """ 89 | Args: 90 | structure: An instance of biotite AtomArray 91 | atoms: default ["N", "CA", "C"] 92 | Returns: 93 | Tuple (coords, seq) 94 | - coords is an L x 3 x 3 array for N, CA, C coordinates 95 | - seq is the extracted sequence 96 | """ 97 | # coords = get_atom_coords_residuewise(["N", "CA", "C"], structure) 98 | coords = get_atom_coords_residuewise(atoms, structure) 99 | residue_identities = get_residues(structure)[1] 100 | seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities]) 101 | return coords, seq 102 | 103 | 104 | def load_coords(fpath, chain, atoms=["N", "CA", "C", "O"]): 105 | """ 106 | Args: 107 | fpath: filepath to either pdb or cif file 108 | chain: the chain id 109 | Returns: 110 | Tuple (coords, seq) 111 | - coords is an L x 3 x 3 array for N, CA, C coordinates 112 | - seq is the extracted sequence 113 | """ 114 | structure = load_structure(fpath, chain) 115 | return extract_coords_from_structure(structure, 
atoms=atoms) 116 | 117 | 118 | def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray): 119 | """ 120 | Example for atoms argument: ["N", "CA", "C"] 121 | """ 122 | def filterfn(s, axis=None): 123 | filters = np.stack([s.atom_name == name for name in atoms], axis=1) 124 | sum = filters.sum(0) 125 | if not np.all(sum <= np.ones(filters.shape[1])): 126 | raise RuntimeError("structure has multiple atoms with same name") 127 | index = filters.argmax(0) 128 | coords = s[index].coord 129 | coords[sum == 0] = float("nan") 130 | return coords 131 | 132 | return biotite.structure.apply_residue_wise(struct, struct, filterfn) 133 | 134 | 135 | def save_pdb(path, coords, seq): 136 | pass -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/lm_design/esm_adapter_pifold.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | import torch 5 | from byprot.models import register_model 6 | from byprot.models.fixedbb import FixedBackboneDesignEncoderDecoder 7 | from byprot.models.fixedbb.generator import new_arange, sample_from_categorical 8 | from byprot.models.fixedbb.pifold import PiFold, PiFoldConfig 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .modules.esm_adapter import ProteinBertModelWithStructuralAdatper 13 | 14 | 15 | @dataclass 16 | class ESMAdapterPiFoldConfig: 17 | encoder: PiFoldConfig = field(default=PiFoldConfig()) 18 | adapter_layer_indices: List = field(default_factory=lambda: [32, ]) 19 | separate_loss: bool = False 20 | 21 | 22 | @register_model('esm_adapter_pifold') 23 | class ESMAdapterPiFold(FixedBackboneDesignEncoderDecoder): 24 | # _default_cfg = ESMAdapterPiFoldConfig() 25 | 26 | def __init__(self, cfg) -> None: 27 | super().__init__(cfg) 28 | 29 | self.encoder = PiFold(self.cfg.encoder) 30 | self.decoder = ProteinBertModelWithStructuralAdatper.from_pretrained(args=self.cfg) 31 | 32 | self.padding_idx = self.decoder.padding_idx 33 | self.mask_idx = self.decoder.mask_idx 34 | # self.cls_idx = self.decoder.cls_idx 35 | # self.eos_idx = self.decoder.eos_idx 36 | 37 | def forward(self, batch, **kwargs): 38 | encoder_logits, encoder_out = self.encoder(batch, return_feats=True, **kwargs) 39 | 40 | encoder_out['feats'] = encoder_out['feats'].detach() 41 | 42 | init_pred = encoder_logits.argmax(-1) 43 | init_pred = torch.where(batch['coord_mask'], init_pred, batch['prev_tokens']) 44 | 45 | esm_logits = self.decoder( 46 | tokens=init_pred, 47 | encoder_out=encoder_out, 48 | )['logits'] 49 | 50 | if not getattr(self.cfg, 'separate_loss', False): 51 | logits = encoder_logits + esm_logits 52 | return logits 53 | else: 54 | return esm_logits, encoder_logits 55 | 56 | def forward_encoder(self, batch): 57 | encoder_logits, encoder_out = self.encoder( 58 | batch['coords'], batch['coord_mask'], batch['prev_tokens'], 59 | return_feats=True 60 | ) 61 | 62 | init_pred = encoder_logits.argmax(-1) 63 | init_pred = torch.where(batch['coord_mask'], init_pred, batch['prev_tokens']) 64 | 65 | encoder_out['logits'] = encoder_logits 66 | encoder_out['init_pred'] = init_pred 67 | encoder_out['coord_mask'] = batch['coord_mask'] 68 | return encoder_out 69 | 70 | def forward_decoder(self, prev_decoder_out, encoder_out, need_attn_weights=False): 71 | output_tokens = prev_decoder_out['output_tokens'] 72 | output_scores = prev_decoder_out['output_scores'] 73 | step, max_step = 
prev_decoder_out['step'], prev_decoder_out['max_step'] 74 | history = prev_decoder_out['history'] 75 | 76 | # output_masks = output_tokens.eq(self.mask_idx) # & coord_mask 77 | output_masks = output_tokens.ne(self.padding_idx) # & coord_mask 78 | 79 | esm_logits = self.decoder( 80 | # tokens=encoder_out['init_pred'], 81 | tokens=output_tokens, 82 | encoder_out=encoder_out, 83 | )['logits'] 84 | 85 | if not getattr(self.cfg, 'separate_loss', False): 86 | logits = esm_logits + encoder_out['logits'] 87 | else: 88 | logits = esm_logits # + encoder_out['logits'] 89 | 90 | _tokens, _scores = sample_from_categorical(logits, temperature=None) 91 | output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) 92 | output_scores.masked_scatter_(output_masks, _scores[output_masks]) 93 | 94 | history.append(output_tokens.clone()) 95 | 96 | return dict( 97 | output_tokens=output_tokens, 98 | output_scores=output_scores, 99 | step=step + 1, 100 | max_step=max_step, 101 | history=history 102 | ) 103 | 104 | def initialize_output_tokens(self, batch, encoder_out): 105 | mask = encoder_out.get('coord_mask', None) 106 | 107 | prev_tokens = batch['prev_tokens'] 108 | lengths = prev_tokens.ne(self.padding_idx).sum(1) 109 | 110 | initial_output_tokens = torch.full_like(prev_tokens, self.padding_idx) 111 | initial_output_tokens.masked_fill_(new_arange(prev_tokens) < lengths[:, None], self.mask_idx) 112 | # initial_output_tokens[:, 0] = self.cls_idx 113 | # initial_output_tokens.scatter_(1, lengths[:, None] - 1, self.eos_idx) 114 | 115 | initial_output_tokens = encoder_out['init_pred'].clone() 116 | initial_output_scores = torch.zeros( 117 | *initial_output_tokens.size(), device=initial_output_tokens.device 118 | ) 119 | 120 | return initial_output_tokens, initial_output_scores 121 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/make_tied_positions_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | 5 | import glob 6 | import random 7 | import numpy as np 8 | import json 9 | import itertools 10 | 11 | with open(args.input_path, 'r') as json_file: 12 | json_list = list(json_file) 13 | 14 | homooligomeric_state = args.homooligomer 15 | 16 | if homooligomeric_state == 0: 17 | tied_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")] 18 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 19 | my_dict = {} 20 | for json_str in json_list: 21 | result = json.loads(json_str) 22 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ... 23 | tied_positions_list = [] 24 | for i, pos in enumerate(tied_list[0]): 25 | temp_dict = {} 26 | for j, chain in enumerate(global_designed_chain_list): 27 | temp_dict[chain] = [tied_list[j][i]] #needs to be a list 28 | tied_positions_list.append(temp_dict) 29 | my_dict[result['name']] = tied_positions_list 30 | else: 31 | my_dict = {} 32 | for json_str in json_list: 33 | result = json.loads(json_str) 34 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ... 
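            # Added note: in the homooligomer case, residue i of every chain is tied to
            # residue i of all other chains, so every chain receives the same designed
            # sequence (see the example output at the bottom of this file).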
35 | tied_positions_list = [] 36 | chain_length = len(result[f"seq_chain_{all_chain_list[0]}"]) 37 | for i in range(1,chain_length+1): 38 | temp_dict = {} 39 | for j, chain in enumerate(all_chain_list): 40 | temp_dict[chain] = [i] #needs to be a list 41 | tied_positions_list.append(temp_dict) 42 | my_dict[result['name']] = tied_positions_list 43 | 44 | with open(args.output_path, 'w') as f: 45 | f.write(json.dumps(my_dict) + '\n') 46 | 47 | if __name__ == "__main__": 48 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 50 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 51 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed") 52 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain") 53 | argparser.add_argument("--homooligomer", type=int, default=0, help="If 0 do not use, if 1 then design homooligomer") 54 | 55 | args = argparser.parse_args() 56 | main(args) 57 | 58 | 59 | #e.g. output 60 | #{"5TTA": [], "3LIS": [{"A": [1], "B": [1]}, {"A": [2], "B": [2]}, {"A": [3], "B": [3]}, {"A": [4], "B": [4]}, {"A": [5], "B": [5]}, {"A": [6], "B": [6]}, {"A": [7], "B": [7]}, {"A": [8], "B": [8]}, {"A": [9], "B": [9]}, {"A": [10], "B": [10]}, {"A": [11], "B": [11]}, {"A": [12], "B": [12]}, {"A": [13], "B": [13]}, {"A": [14], "B": [14]}, {"A": [15], "B": [15]}, {"A": [16], "B": [16]}, {"A": [17], "B": [17]}, {"A": [18], "B": [18]}, {"A": [19], "B": [19]}, {"A": [20], "B": [20]}, {"A": [21], "B": [21]}, {"A": [22], "B": [22]}, {"A": [23], "B": [23]}, {"A": [24], "B": [24]}, {"A": [25], "B": [25]}, {"A": [26], "B": [26]}, {"A": [27], "B": [27]}, {"A": [28], "B": [28]}, {"A": [29], "B": [29]}, {"A": [30], "B": [30]}, {"A": [31], "B": [31]}, {"A": [32], "B": [32]}, {"A": [33], "B": [33]}, {"A": [34], "B": [34]}, {"A": [35], "B": [35]}, {"A": [36], "B": [36]}, {"A": [37], "B": [37]}, {"A": [38], "B": [38]}, {"A": [39], "B": [39]}, {"A": [40], "B": [40]}, {"A": [41], "B": [41]}, {"A": [42], "B": [42]}, {"A": [43], "B": [43]}, {"A": [44], "B": [44]}, {"A": [45], "B": [45]}, {"A": [46], "B": [46]}, {"A": [47], "B": [47]}, {"A": [48], "B": [48]}, {"A": [49], "B": [49]}, {"A": [50], "B": [50]}, {"A": [51], "B": [51]}, {"A": [52], "B": [52]}, {"A": [53], "B": [53]}, {"A": [54], "B": [54]}, {"A": [55], "B": [55]}, {"A": [56], "B": [56]}, {"A": [57], "B": [57]}, {"A": [58], "B": [58]}, {"A": [59], "B": [59]}, {"A": [60], "B": [60]}, {"A": [61], "B": [61]}, {"A": [62], "B": [62]}, {"A": [63], "B": [63]}, {"A": [64], "B": [64]}, {"A": [65], "B": [65]}, {"A": [66], "B": [66]}, {"A": [67], "B": [67]}, {"A": [68], "B": [68]}, {"A": [69], "B": [69]}, {"A": [70], "B": [70]}, {"A": [71], "B": [71]}, {"A": [72], "B": [72]}, {"A": [73], "B": [73]}, {"A": [74], "B": [74]}, {"A": [75], "B": [75]}, {"A": [76], "B": [76]}, {"A": [77], "B": [77]}, {"A": [78], "B": [78]}, {"A": [79], "B": [79]}, {"A": [80], "B": [80]}, {"A": [81], "B": [81]}, {"A": [82], "B": [82]}, {"A": [83], "B": [83]}, {"A": [84], "B": [84]}, {"A": [85], "B": [85]}, {"A": [86], "B": [86]}, {"A": [87], "B": [87]}, {"A": [88], "B": [88]}, {"A": [89], "B": [89]}, {"A": [90], "B": [90]}, {"A": [91], "B": [91]}, {"A": [92], "B": [92]}, {"A": [93], "B": [93]}, {"A": [94], "B": [94]}, {"A": [95], "B": [95]}, {"A": [96], "B": [96]}]} 61 | 62 | 
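# Added usage sketch (annotation, not part of the original file): a hypothetical
# invocation of the helper script above, using only the argparse flags it defines;
# the .jsonl paths are placeholders.
#
#   python make_tied_positions_dict.py \
#       --input_path parsed_pdbs.jsonl \
#       --output_path tied_positions.jsonl \
#       --homooligomer 1
#
# With --homooligomer 1, every chain parsed from the input is tied position-by-position,
# yielding a dictionary shaped like the example output shown above.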
-------------------------------------------------------------------------------- /src/byprot/datamodules/datasets/vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import reduce 3 | from typing import List, OrderedDict, Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchtext.vocab import Vocab as V 8 | from torchtext.vocab import build_vocab_from_iterator 9 | from torchtext.vocab import vocab as build_vocab 10 | 11 | from byprot import utils 12 | 13 | log = utils.get_logger(__name__) 14 | 15 | 16 | def _AugmentedVocab(v: V): 17 | class AugmentedVocab(V): 18 | def __init__(self, vocab): 19 | super().__init__(vocab) 20 | self.pad = self[''] 21 | self.unk = self[''] 22 | self.eos = self[''] 23 | self.bos = self[''] 24 | 25 | def encode(self, tokens: List[str]): 26 | return V.lookup_indices(self, tokens) 27 | 28 | def decode( 29 | self, 30 | indices: Union[List[int], List[List[int]]], 31 | remove_special=True, 32 | bpe_symbol='@@ ', 33 | ) -> List[str]: 34 | 35 | if isinstance(indices, Tensor): 36 | indices = indices.detach().cpu().tolist() 37 | 38 | if isinstance(indices[0], List): 39 | return [ 40 | self.decode(_indices, remove_special, bpe_symbol) for _indices in indices 41 | ] 42 | 43 | if remove_special and self.eos in indices: 44 | indices = indices[:indices.index(self.eos)] 45 | text = ' '.join(self.lookup_tokens(indices)) 46 | return post_process(text, bpe_symbol) 47 | 48 | return AugmentedVocab(v.vocab) 49 | 50 | 51 | def load_vocab(root='./data', file_prefix='vocab', lang='en'): 52 | """ 53 | Expected vocab file: 54 | each line contains a , separated by a space 55 | e.g., 56 | (root/vocab.en) 57 | hello 42 58 | world 11 59 | """ 60 | vocab_full_path = os.path.join(root, f"{file_prefix}.{lang}") 61 | if not os.path.exists(vocab_full_path): 62 | return None 63 | 64 | ordered_dict = OrderedDict() 65 | with open(vocab_full_path) as fp: 66 | for line in fp: 67 | token, freq = line.strip().split() 68 | ordered_dict[token] = int(freq) 69 | 70 | vocab: V = build_vocab(ordered_dict, min_freq=0) 71 | vocab.set_default_index(vocab[""]) 72 | return _AugmentedVocab(vocab) 73 | 74 | 75 | def build_vocab_from_alphabet(alphabet: Union[List[str], str], specials=[]): 76 | if isinstance(alphabet, str): 77 | alphabet = list(alphabet) 78 | 79 | ordered_dict = OrderedDict( 80 | {element: 1 for element in alphabet} 81 | ) 82 | 83 | vocab: V = build_vocab(ordered_dict, min_freq=0, specials=['', '', ''] + specials) 84 | vocab.set_default_index(vocab[""]) 85 | return _AugmentedVocab(vocab) 86 | 87 | 88 | def yield_tokens(data_iter, index=0): 89 | if isinstance(index, int): 90 | index = [index] 91 | 92 | for from_to_tuple in data_iter: 93 | tokens = [] 94 | for i in index: 95 | tokens.extend(from_to_tuple[i].split(' ')) 96 | yield tokens 97 | 98 | 99 | def build_vocab_from_datasets(datasets, index): 100 | datasets_concat = reduce(lambda a, b: a + b, datasets) 101 | 102 | vocab: V = build_vocab_from_iterator( 103 | yield_tokens(datasets_concat, index), 104 | min_freq=2, specials=["", "", "", ""], 105 | ) 106 | vocab.set_default_index(vocab[""]) 107 | return _AugmentedVocab(vocab) 108 | 109 | 110 | def save_vocab(vocab: V, root, file_prefix='vocab', lang='en'): 111 | vocab_full_path = os.path.join(root, f"{file_prefix}.{lang}") 112 | tokens = vocab.get_itos() 113 | 114 | log.info(f'Saving vocabulary for {lang} to {vocab_full_path}') 115 | with open(vocab_full_path, 'w') as fp: 116 | for idx, token in enumerate(tokens): 
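            # Added note: the value written next to each token is its index, not a
            # corpus frequency; load_vocab() above parses the second column as a count
            # (`freq`) but builds the vocab with min_freq=0, so only the token order
            # matters when the file is read back.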
117 | fp.write(f"{token} {idx}\n") 118 | 119 | 120 | def post_process(sentence: str, symbol: str): 121 | if symbol == "sentencepiece": 122 | sentence = sentence.replace(" ", "").replace("\u2581", " ").strip() 123 | elif symbol == "wordpiece": 124 | sentence = sentence.replace(" ", "").replace("_", " ").strip() 125 | elif symbol == "letter": 126 | sentence = sentence.replace(" ", "").replace("|", " ").strip() 127 | elif symbol == "silence": 128 | import re 129 | 130 | sentence = sentence.replace("", "") 131 | sentence = re.sub(" +", " ", sentence).strip() 132 | elif symbol == "_EOW": 133 | sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() 134 | elif symbol in {"subword_nmt", "@@ ", "@@"}: 135 | if symbol == "subword_nmt": 136 | symbol = "@@ " 137 | sentence = (sentence + " ").replace(symbol, "").rstrip() 138 | elif symbol == "none": 139 | pass 140 | elif symbol is not None: 141 | raise NotImplementedError(f"Unknown post_process option: {symbol}") 142 | return sentence 143 | -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/lm_design/esm2_adapter_gvptrans.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | import torch 5 | from byprot.models import register_model 6 | from byprot.models.fixedbb import FixedBackboneDesignEncoderDecoder 7 | from byprot.models.fixedbb.generator import sample_from_categorical 8 | from byprot.utils.config import compose_config as Cfg 9 | 10 | from .modules.esm2_adapter import ESM2WithStructuralAdatper 11 | from .modules.gvp_transformer_encoder import GVPTransformerEncoderWrapper 12 | 13 | 14 | ESM2AdapterGVPTransConfig = Cfg( 15 | # encoder: ProteinMPNNConfig = field(default=ProteinMPNNConfig()) 16 | encoder=Cfg( 17 | d_model=512 18 | ), 19 | adapter_layer_indices=[-1, ], 20 | separate_loss=True, 21 | name='esm2_t33_650M_UR50D', 22 | dropout=0.1, 23 | ) 24 | 25 | 26 | @register_model('esm2_adapter_gvptrans') 27 | class ESM2AdapterGVPTrans(FixedBackboneDesignEncoderDecoder): 28 | _default_cfg = ESM2AdapterGVPTransConfig 29 | 30 | def __init__(self, cfg) -> None: 31 | super().__init__(cfg) 32 | 33 | self.decoder = ESM2WithStructuralAdatper.from_pretrained(args=self.cfg, name=self.cfg.name) 34 | self.encoder = GVPTransformerEncoderWrapper(self.decoder.alphabet, freeze=True) 35 | 36 | self.padding_idx = self.decoder.padding_idx 37 | self.mask_idx = self.decoder.mask_idx 38 | self.cls_idx = self.decoder.cls_idx 39 | self.eos_idx = self.decoder.eos_idx 40 | 41 | def forward(self, batch, **kwargs): 42 | encoder_logits, encoder_out = self.encoder(batch, return_feats=True, **kwargs) 43 | 44 | encoder_out['feats'] = encoder_out['feats'].detach() 45 | 46 | init_pred = encoder_logits.argmax(-1) 47 | init_pred = torch.where(batch['coord_mask'], init_pred, batch['prev_tokens']) 48 | 49 | esm_logits = self.decoder( 50 | tokens=init_pred, 51 | encoder_out=encoder_out, 52 | )['logits'] 53 | 54 | if not getattr(self.cfg, 'separate_loss', False): 55 | logits = encoder_logits + esm_logits 56 | return logits, encoder_logits 57 | else: 58 | return esm_logits, encoder_logits 59 | 60 | def forward_encoder(self, batch): 61 | encoder_logits, encoder_out = self.encoder(batch, return_feats=True) 62 | 63 | init_pred = encoder_logits.argmax(-1) 64 | init_pred = torch.where(batch['coord_mask'], init_pred, batch['prev_tokens']) 65 | 66 | encoder_out['logits'] = encoder_logits 67 | encoder_out['init_pred'] = 
init_pred 68 | encoder_out['coord_mask'] = batch['coord_mask'] 69 | return encoder_out 70 | 71 | def forward_decoder(self, prev_decoder_out, encoder_out, need_attn_weights=False): 72 | output_tokens = prev_decoder_out['output_tokens'] 73 | output_scores = prev_decoder_out['output_scores'] 74 | step, max_step = prev_decoder_out['step'], prev_decoder_out['max_step'] 75 | temperature = prev_decoder_out['temperature'] 76 | history = prev_decoder_out['history'] 77 | 78 | # output_masks = output_tokens.eq(self.mask_idx) # & coord_mask 79 | output_masks = output_tokens.ne(self.padding_idx) # & coord_mask 80 | 81 | esm_logits = self.decoder( 82 | # tokens=encoder_out['init_pred'], 83 | tokens=output_tokens, 84 | encoder_out=encoder_out, 85 | )['logits'] 86 | 87 | if not getattr(self.cfg, 'separate_loss', False): 88 | logits = esm_logits + encoder_out['logits'] 89 | else: 90 | logits = esm_logits # + encoder_out['logits'] 91 | 92 | _tokens, _scores = sample_from_categorical(logits, temperature=temperature) 93 | 94 | output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) 95 | output_scores.masked_scatter_(output_masks, _scores[output_masks]) 96 | 97 | history.append(output_tokens.clone()) 98 | 99 | return dict( 100 | output_tokens=output_tokens, 101 | output_scores=output_scores, 102 | step=step + 1, 103 | max_step=max_step, 104 | history=history 105 | ) 106 | 107 | def initialize_output_tokens(self, batch, encoder_out): 108 | mask = encoder_out.get('coord_mask', None) 109 | 110 | prev_tokens = batch['prev_tokens'] 111 | prev_token_mask = batch['prev_token_mask'] 112 | # lengths = prev_tokens.ne(self.padding_idx).sum(1) 113 | 114 | # initial_output_tokens = torch.full_like(prev_tokens, self.padding_idx) 115 | # initial_output_tokens.masked_fill_(new_arange(prev_tokens) < lengths[:, None], self.mask_idx) 116 | # initial_output_tokens[:, 0] = self.cls_idx 117 | # initial_output_tokens.scatter_(1, lengths[:, None] - 1, self.eos_idx) 118 | 119 | # initial_output_tokens = encoder_out['init_pred'].clone() 120 | initial_output_tokens = torch.where( 121 | prev_token_mask, encoder_out['init_pred'], prev_tokens) 122 | initial_output_scores = torch.zeros( 123 | *initial_output_tokens.size(), device=initial_output_tokens.device 124 | ) 125 | 126 | return initial_output_tokens, initial_output_scores 127 | -------------------------------------------------------------------------------- /src/byprot/models/seq2seq/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | """ Transformer encoder """ 2 | 3 | import copy 4 | from typing import Dict 5 | 6 | import torch 7 | from torch import Tensor, nn 8 | from torch.nn import functional as F 9 | 10 | from .modules.ffn import FFN 11 | from .modules.multihead_attention import MHA 12 | from .modules.utils import ResNorm, _get_clones 13 | 14 | 15 | class TransformerDecoderLayer(nn.Module): 16 | def __init__( 17 | self, 18 | d_model, 19 | n_heads, 20 | d_inner=2048, 21 | dropout=0.1, 22 | attn_dropout=0., 23 | normalize_before=False, 24 | ): 25 | super().__init__() 26 | 27 | self.self_attn = ResNorm( 28 | net=MHA(embed_dim=d_model, num_heads=n_heads, dropout=attn_dropout), 29 | fn=lambda net, x, *args, **kwargs: net(query=x, key=x, value=x, *args, **kwargs), 30 | dim=d_model, normalize_before=normalize_before, dropout=dropout 31 | ) 32 | 33 | self.cross_attn = ResNorm( 34 | net=MHA(embed_dim=d_model, num_heads=n_heads, dropout=attn_dropout), 35 | dim=d_model, normalize_before=normalize_before, dropout=dropout 
36 | ) 37 | 38 | self.ffn = ResNorm( 39 | net=FFN(d_model=d_model, d_inner=d_inner, dropout=dropout), 40 | dim=d_model, normalize_before=normalize_before, dropout=dropout 41 | ) 42 | 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | self.self_attn.reset_parameters() 47 | self.cross_attn.reset_parameters() 48 | self.ffn.reset_parameters() 49 | 50 | def forward( 51 | self, 52 | x: Tensor, 53 | memory: Tensor, 54 | self_padding_mask: Tensor = None, 55 | self_attn_mask: Tensor = None, 56 | memory_padding_mask: Tensor = None, 57 | incremental_states: Dict[str, Dict[str, Tensor]] = None, 58 | ): 59 | x, *others = self.self_attn( 60 | x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, 61 | incremental_states=incremental_states) 62 | x, *others = self.cross_attn( 63 | x, key=memory, value=memory, key_padding_mask=memory_padding_mask, 64 | static_kv=True, incremental_states=incremental_states) 65 | x = self.ffn(x) 66 | return x 67 | 68 | 69 | class TransformerDecoder(nn.Module): 70 | def __init__( 71 | self, 72 | n_layers, 73 | d_model, 74 | n_heads, 75 | d_inner=2048, 76 | dropout=0.1, 77 | attn_dropout=0., 78 | normalize_before=False, 79 | causal=True, 80 | layer=None 81 | ): 82 | super().__init__() 83 | 84 | if layer is None: 85 | layer = TransformerDecoderLayer( 86 | d_model, n_heads, d_inner, dropout, attn_dropout, normalize_before 87 | ) 88 | self.layers = _get_clones(layer, N=n_layers) 89 | 90 | self.norm = None 91 | if normalize_before: 92 | self.norm = nn.LayerNorm(d_model) 93 | 94 | self.causal = causal 95 | 96 | self.reset_parameters() 97 | 98 | def reset_parameters(self): 99 | for layer in self.layers: layer.reset_parameters() 100 | 101 | def _maybe_get_causal_mask(self, x, incremental_states=None): 102 | "Mask out subsequent positions." 
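        # Added note: torch.triu(..., diagonal=1) below yields a boolean mask that is
        # True strictly above the diagonal, i.e. True marks future positions that
        # self-attention must not attend to; during incremental (cached) decoding the
        # method returns None, so no causal mask is applied.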
103 | if not self.causal: 104 | return None 105 | if self._inferring and incremental_states is not None: 106 | return None 107 | 108 | size = x.shape[1] 109 | causal_mask = torch.triu( 110 | torch.ones((size, size), dtype=torch.bool, device=x.device), 111 | diagonal=1 112 | ) 113 | return causal_mask 114 | 115 | def forward( 116 | self, 117 | x: Tensor, 118 | memory: Tensor, 119 | self_padding_mask: Tensor = None, 120 | memory_padding_mask: Tensor = None, 121 | incremental_states: Dict[str, Dict[str, Tensor]] = None, 122 | ): 123 | out = x 124 | self_attn_mask = self._maybe_get_causal_mask(x, incremental_states) 125 | 126 | for layer in self.layers: 127 | out = layer( 128 | out, memory, 129 | self_padding_mask=self_padding_mask, 130 | self_attn_mask=self_attn_mask, 131 | memory_padding_mask=memory_padding_mask, 132 | incremental_states=incremental_states 133 | ) 134 | 135 | if self.norm is not None: 136 | out = self.norm(out) 137 | return out 138 | 139 | 140 | if __name__ == '__main__': 141 | from byprot.models.sequence.transformer_decoder import * 142 | 143 | B, L, D = 10, 7, 32 144 | M = 5 145 | decoder = TransformerDecoder(n_layers=6, d_model=D, n_heads=4, d_inner=2*D, normalize_before=True) 146 | 147 | x = torch.randn(B, L, D) 148 | mem = torch.randn(B, M, D) 149 | 150 | key_padding_mask = ~(torch.arange(L)[None] < torch.randint(1, L, (B,))[:, None]) 151 | 152 | # attn_mask = torch.triu(torch.ones(10, 7, 7, dtype=torch.bool), 1) 153 | x = decoder(x, mem, self_padding_mask=key_padding_mask) 154 | -------------------------------------------------------------------------------- /src/byprot/datamodules/cath_datamodule.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from byprot import utils 7 | from byprot.datamodules import register_datamodule 8 | from pytorch_lightning import LightningDataModule 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from .datasets.cath import CATH 12 | from .datasets.data_utils import Alphabet, MaxTokensBatchSampler 13 | 14 | log = utils.get_logger(__name__) 15 | 16 | 17 | # @register_datamodule('struct2seq') 18 | @register_datamodule('cath') 19 | class CATHDataModule(LightningDataModule): 20 | 21 | def __init__( 22 | self, 23 | data_dir: str = "data/", 24 | chain_set_jsonl: str = 'chain_set.jsonl', 25 | chain_set_splits_json: str = 'chain_set_splits.json', 26 | max_length: int = 500, 27 | atoms: List[str] = ('N', 'CA', 'C', 'O'), 28 | alphabet=None, 29 | batch_size: int = 64, 30 | max_tokens: int = 6000, 31 | sort: bool = False, 32 | num_workers: int = 0, 33 | pin_memory: bool = False, 34 | train_split: str = 'train', 35 | valid_split: str = 'valid', 36 | test_split: str = 'test', 37 | ): 38 | super().__init__() 39 | 40 | # this line allows to access init params with 'self.hparams' attribute 41 | self.save_hyperparameters(logger=False) 42 | 43 | self.alphabet = None 44 | 45 | self.train_data: Optional[Dataset] = None 46 | self.valid_data: Optional[Dataset] = None 47 | self.test_data: Optional[Dataset] = None 48 | 49 | def setup(self, stage: Optional[str] = None): 50 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 51 | 52 | This method is called by lightning when doing `trainer.fit()` and `trainer.test()`, 53 | so be careful not to execute the random split twice! 
The `stage` can be used to 54 | differentiate whether it's called before trainer.fit()` or `trainer.test()`. 55 | """ 56 | 57 | # load datasets only if they're not loaded already 58 | if stage == 'fit': 59 | (train, valid), alphabet = CATH( 60 | self.hparams.data_dir, 61 | chain_set_jsonl=self.hparams.chain_set_jsonl, 62 | chain_set_splits_json=self.hparams.chain_set_splits_json, 63 | max_length=self.hparams.max_length, 64 | split=(self.hparams.train_split, self.hparams.valid_split), 65 | ) 66 | self.train_dataset = train 67 | self.valid_dataset = valid 68 | elif stage == 'test' or stage == 'predict': 69 | test, alphabet = CATH( 70 | self.hparams.data_dir, 71 | chain_set_jsonl=self.hparams.chain_set_jsonl, 72 | chain_set_splits_json=self.hparams.chain_set_splits_json, 73 | split=(self.hparams.test_split, ) 74 | ) 75 | self.test_dataset = test 76 | else: 77 | raise ValueError(f"Invalid stage: {stage}.") 78 | 79 | self.alphabet = Alphabet(**self.hparams.alphabet) 80 | 81 | self.collate_batch = self.alphabet.featurizer 82 | 83 | def _build_batch_sampler(self, dataset, max_tokens, shuffle=False, distributed=True): 84 | is_distributed = distributed and torch.distributed.is_initialized() 85 | 86 | batch_sampler = MaxTokensBatchSampler( 87 | dataset=dataset, 88 | shuffle=shuffle, 89 | distributed=is_distributed, 90 | batch_size=self.hparams.batch_size, 91 | max_tokens=max_tokens, 92 | sort=self.hparams.sort, 93 | drop_last=False, 94 | sort_key=lambda i: len(dataset[i]['seq'])) 95 | return batch_sampler 96 | 97 | def train_dataloader(self): 98 | if not hasattr(self, 'train_batch_sampler'): 99 | self.train_batch_sampler = self._build_batch_sampler( 100 | self.train_dataset, 101 | max_tokens=self.hparams.max_tokens, 102 | shuffle=True 103 | ) 104 | return DataLoader( 105 | dataset=self.train_dataset, 106 | batch_sampler=self.train_batch_sampler, 107 | num_workers=self.hparams.num_workers, 108 | pin_memory=self.hparams.pin_memory, 109 | collate_fn=self.collate_batch 110 | ) 111 | 112 | def val_dataloader(self): 113 | return DataLoader( 114 | dataset=self.valid_dataset, 115 | batch_sampler=self._build_batch_sampler( 116 | self.valid_dataset, max_tokens=self.hparams.max_tokens, distributed=False), 117 | num_workers=self.hparams.num_workers, 118 | pin_memory=self.hparams.pin_memory, 119 | collate_fn=self.collate_batch 120 | ) 121 | 122 | def test_dataloader(self): 123 | return DataLoader( 124 | dataset=self.test_dataset, 125 | batch_sampler=self._build_batch_sampler( 126 | self.test_dataset, max_tokens=self.hparams.max_tokens, distributed=False), 127 | num_workers=self.hparams.num_workers, 128 | pin_memory=self.hparams.pin_memory, 129 | collate_fn=self.collate_batch 130 | ) 131 | -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/lm_design/esm_adapter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | import torch 5 | from byprot.models import register_model 6 | from byprot.models.fixedbb import FixedBackboneDesignEncoderDecoder 7 | from byprot.models.fixedbb.generator import sample_from_categorical 8 | from byprot.models.fixedbb.protein_mpnn_cmlm.protein_mpnn import ( 9 | ProteinMPNNCMLM, ProteinMPNNConfig) 10 | 11 | from .modules.esm_adapter import ProteinBertModelWithStructuralAdatper 12 | 13 | 14 | @dataclass 15 | class ESMAdapterConfig: 16 | encoder: ProteinMPNNConfig = field(default=ProteinMPNNConfig()) 17 | 
adapter_layer_indices: List = field(default_factory=lambda: [32, ]) 18 | separate_loss: bool = True 19 | # ensemble_logits: bool = False 20 | initialize_input: bool = True 21 | 22 | 23 | @register_model('esm_adapter') 24 | class ESMAdapter(FixedBackboneDesignEncoderDecoder): 25 | _default_cfg = ESMAdapterConfig() 26 | 27 | def __init__(self, cfg) -> None: 28 | super().__init__(cfg) 29 | 30 | self.encoder = ProteinMPNNCMLM(self.cfg.encoder) 31 | self.decoder = ProteinBertModelWithStructuralAdatper.from_pretrained(args=self.cfg) 32 | 33 | self.padding_idx = self.decoder.padding_idx 34 | self.mask_idx = self.decoder.mask_idx 35 | self.cls_idx = self.decoder.cls_idx 36 | self.eos_idx = self.decoder.eos_idx 37 | 38 | def forward(self, batch, **kwargs): 39 | encoder_logits, encoder_out = self.encoder(batch, return_feats=True, **kwargs) 40 | 41 | encoder_out['feats'] = encoder_out['feats'].detach() 42 | 43 | if self.cfg.initialize_input: 44 | init_pred = encoder_logits.argmax(-1) 45 | init_pred = torch.where(batch['coord_mask'], init_pred, batch['prev_tokens']) 46 | else: 47 | init_pred = batch['prev_tokens'] 48 | 49 | esm_logits = self.decoder( 50 | tokens=init_pred, 51 | encoder_out=encoder_out, 52 | )['logits'] 53 | 54 | if not getattr(self.cfg, 'separate_loss', False): 55 | logits = encoder_logits + esm_logits 56 | return logits, encoder_logits 57 | else: 58 | return esm_logits, encoder_logits 59 | 60 | def forward_encoder(self, batch): 61 | encoder_logits, encoder_out = self.encoder(batch, return_feats=True) 62 | 63 | if self.cfg.initialize_input: 64 | init_pred = encoder_logits.argmax(-1) 65 | init_pred = torch.where(batch['coord_mask'], init_pred, batch['prev_tokens']) 66 | else: 67 | init_pred = batch['prev_tokens'] 68 | 69 | encoder_out['logits'] = encoder_logits 70 | encoder_out['init_pred'] = init_pred 71 | encoder_out['coord_mask'] = batch['coord_mask'] 72 | return encoder_out 73 | 74 | def forward_decoder(self, prev_decoder_out, encoder_out, need_attn_weights=False): 75 | output_tokens = prev_decoder_out['output_tokens'] 76 | output_scores = prev_decoder_out['output_scores'] 77 | step, max_step = prev_decoder_out['step'], prev_decoder_out['max_step'] 78 | temperature = prev_decoder_out['temperature'] 79 | history = prev_decoder_out['history'] 80 | 81 | # output_masks = output_tokens.eq(self.mask_idx) # & coord_mask 82 | output_masks = output_tokens.ne(self.padding_idx) # & coord_mask 83 | 84 | esm_logits = self.decoder( 85 | # tokens=encoder_out['init_pred'], 86 | tokens=output_tokens, 87 | encoder_out=encoder_out, 88 | )['logits'] 89 | 90 | if not getattr(self.cfg, 'separate_loss', False): 91 | logits = 0 * esm_logits + encoder_out['logits'] 92 | else: 93 | logits = esm_logits # + encoder_out['logits'] 94 | 95 | _tokens, _scores = sample_from_categorical(logits, temperature=temperature) 96 | 97 | output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) 98 | output_scores.masked_scatter_(output_masks, _scores[output_masks]) 99 | 100 | history.append(output_tokens.clone()) 101 | 102 | return dict( 103 | output_tokens=output_tokens, 104 | output_scores=output_scores, 105 | step=step + 1, 106 | max_step=max_step, 107 | history=history 108 | ) 109 | 110 | def initialize_output_tokens(self, batch, encoder_out): 111 | mask = encoder_out.get('coord_mask', None) 112 | 113 | prev_tokens = batch['prev_tokens'] 114 | prev_token_mask = batch['prev_token_mask'] 115 | # lengths = prev_tokens.ne(self.padding_idx).sum(1) 116 | 117 | # initial_output_tokens = torch.full_like(prev_tokens, 
self.padding_idx) 118 | # initial_output_tokens.masked_fill_(new_arange(prev_tokens) < lengths[:, None], self.mask_idx) 119 | # initial_output_tokens[:, 0] = self.cls_idx 120 | # initial_output_tokens.scatter_(1, lengths[:, None] - 1, self.eos_idx) 121 | 122 | # initial_output_tokens = encoder_out['init_pred'].clone() 123 | initial_output_tokens = torch.where( 124 | prev_token_mask, encoder_out['init_pred'], prev_tokens) 125 | initial_output_scores = torch.zeros( 126 | *initial_output_tokens.size(), device=initial_output_tokens.device 127 | ) 128 | 129 | return initial_output_tokens, initial_output_scores 130 | -------------------------------------------------------------------------------- /src/byprot/modules/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 7 | flag = False 8 | if target.dim() == lprobs.dim() - 1: 9 | flag = True 10 | target = target.unsqueeze(-1) 11 | 12 | nll_loss = -lprobs.gather(dim=-1, index=target) 13 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 14 | if ignore_index is not None: 15 | pad_mask = target.eq(ignore_index) 16 | nll_loss.masked_fill_(pad_mask, 0.0) 17 | smooth_loss.masked_fill_(pad_mask, 0.0) 18 | 19 | if flag: 20 | nll_loss = nll_loss.squeeze(-1) 21 | smooth_loss = smooth_loss.squeeze(-1) 22 | 23 | if reduce: 24 | nll_loss = nll_loss.sum() 25 | smooth_loss = smooth_loss.sum() 26 | eps_i = epsilon / (lprobs.size(-1) - 1) 27 | loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss 28 | return loss, nll_loss 29 | 30 | 31 | class CrossEntropyLoss(nn.CrossEntropyLoss): 32 | def forward(self, scores: Tensor, target: Tensor, mask=None) -> Tensor: 33 | """ 34 | scores: [N, ..., C], unnormalized scores 35 | target: [N, ...] 
36 | mask: [N, ...], where elements with `True` are allowed and `False` are masked-out 37 | """ 38 | n_tokens = target.numel() 39 | n_nonpad_tokens = target.ne(self.ignore_index).long().sum() 40 | 41 | bsz, num_classes = scores.shape[0], scores.shape[-1] 42 | 43 | if mask is not None: 44 | scores = scores[mask] # [N * len, C] 45 | target = target[mask] # [N] 46 | scores = scores.reshape(-1, num_classes) 47 | target = target.reshape(-1) 48 | 49 | if self.ignore_index is not None: 50 | sample_size = target.ne(self.ignore_index).long().sum() 51 | else: 52 | sample_size = torch.tensor(target.numel(), device=target.device) 53 | 54 | # smooth_loss = F.cross_entropy( 55 | # scores.transpose(1, -1), target, 56 | # weight=self.weight, 57 | # ignore_index=self.ignore_index, reduction=self.reduction, 58 | # label_smoothing=self.label_smoothing) 59 | 60 | loss, nll_loss = label_smoothed_nll_loss( 61 | lprobs=F.log_softmax(scores, dim=-1), 62 | target=target, 63 | epsilon=self.label_smoothing, 64 | ignore_index=self.ignore_index, 65 | reduce=True, 66 | ) 67 | loss_avg = loss / sample_size 68 | ppl = torch.exp(nll_loss / sample_size) 69 | 70 | logging_output = { 71 | 'nll_loss_sum': nll_loss.data, 72 | 'loss_sum': loss.data, 73 | 'ppl': ppl.data, 74 | 'bsz': bsz, 75 | 'sample_size': sample_size, 76 | 'sample_ratio': sample_size / n_tokens, 77 | 'nonpad_ratio': n_nonpad_tokens / n_tokens 78 | } 79 | return loss_avg, logging_output 80 | 81 | 82 | class Coord2SeqCrossEntropyLoss(nn.CrossEntropyLoss): 83 | def forward(self, scores: Tensor, target: Tensor, label_mask=None, coord_mask=None, weights=None) -> Tensor: 84 | """ 85 | scores: [N, L, C], unnormalized scores 86 | target: [N, L] 87 | coord_mask: FloatTensor [N, L], where elements with `True` are allowed and `False` are masked-out 88 | """ 89 | if label_mask is None: 90 | label_mask = coord_mask 91 | 92 | bsz, num_classes = scores.shape[0], scores.shape[-1] 93 | 94 | n_tokens = target.numel() 95 | if self.ignore_index is not None: 96 | sample_size = n_nonpad_tokens = target.ne(self.ignore_index).float().sum() 97 | else: 98 | sample_size = n_nonpad_tokens = n_tokens 99 | 100 | # [N, L] 101 | loss, nll_loss = label_smoothed_nll_loss( 102 | lprobs=F.log_softmax(scores, dim=-1), 103 | target=target, 104 | epsilon=self.label_smoothing, 105 | ignore_index=self.ignore_index, 106 | reduce=False, 107 | ) 108 | if weights is not None: 109 | loss, nll_loss = loss * weights, nll_loss * weights 110 | fullseq_loss = loss.sum() / sample_size 111 | fullseq_nll_loss = nll_loss.sum() / sample_size 112 | 113 | # use coord masked loss for model training, 114 | # ignoring those position with missing coords (as nan) 115 | if label_mask is not None: 116 | label_mask = label_mask.float() 117 | sample_size = label_mask.sum() # sample size should be set to valid coordinates 118 | loss = (loss * label_mask).sum() / sample_size 119 | nll_loss = (nll_loss * label_mask).sum() / sample_size 120 | else: 121 | loss, nll_loss = fullseq_loss, fullseq_nll_loss 122 | 123 | ppl = torch.exp(nll_loss) 124 | 125 | logging_output = { 126 | 'nll_loss': nll_loss.data, 127 | 'ppl': ppl.data, 128 | 'fullseq_loss': fullseq_loss.data, 129 | 'fullseq_nll_loss': fullseq_nll_loss.data, 130 | 'bsz': bsz, 131 | 'sample_size': sample_size, 132 | 'sample_ratio': sample_size / n_tokens, 133 | 'nonpad_ratio': n_nonpad_tokens / n_tokens 134 | } 135 | return loss, logging_output 136 | 137 | -------------------------------------------------------------------------------- 
/src/byprot/datamodules/datasets/parallel_dataset.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import itertools 3 | import os 4 | import random 5 | from functools import partial 6 | from typing import Callable, Iterator, List, TypeVar 7 | 8 | import numpy as np 9 | import torch 10 | from pytorch_lightning.utilities.seed import isolate_rng, seed_everything 11 | from torch import distributed as dist 12 | from torch.utils.data import DataChunk 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch.utils.data.sampler import (BatchSampler, RandomSampler, 15 | SequentialSampler, SubsetRandomSampler) 16 | from torchdata.datapipes.iter import FileOpener, IterableWrapper 17 | from torchtext._internal.module_utils import is_module_available 18 | from torchtext.data.datasets_utils import ( 19 | _clean_files, _create_dataset_directory, 20 | _generate_iwslt_files_for_lang_and_split, _wrap_split_argument) 21 | from torchtext.data.functional import to_map_style_dataset 22 | 23 | 24 | @_wrap_split_argument(("train", "valid", "test")) 25 | def ParallelDataset( 26 | root=".data", 27 | split=("train", "valid", "test"), 28 | language_pair=("de", "en"), 29 | transforms: Callable = (None, None) 30 | ): 31 | src_language, tgt_language = language_pair 32 | src_transform, tgt_transform = transforms 33 | 34 | file_path_by_lang_and_split = { 35 | src_language: { 36 | "train": f"train.{src_language}", 37 | "valid": f"valid.{src_language}", 38 | "test": f"test.{src_language}", 39 | }, 40 | tgt_language: { 41 | "train": f"train.{tgt_language}", 42 | "valid": f"valid.{tgt_language}", 43 | "test": f"test.{tgt_language}", 44 | } 45 | } 46 | 47 | src_filename = file_path_by_lang_and_split[src_language][split] 48 | full_src_filepath = os.path.join(root, src_filename) 49 | 50 | tgt_filename = file_path_by_lang_and_split[tgt_language][split] 51 | full_tgt_filepath = os.path.join(root, tgt_filename) 52 | 53 | # src_data_dp = FileOpener(full_src_filepath, encoding="utf-8") 54 | # tgt_data_dp = FileOpener(full_tgt_filepath, encoding="utf-8") 55 | 56 | src_lines = FileOpener([full_src_filepath], encoding="utf-8").readlines(return_path=False, strip_newline=True) 57 | tgt_lines = FileOpener([full_tgt_filepath], encoding="utf-8").readlines(return_path=False, strip_newline=True) 58 | 59 | # from itertools import count 60 | # count_src, count_tgt = count(), count() 61 | # src_lines = src_lines.to_map_datapipe( 62 | # key_value_fn=lambda line: (next(count_src), line)) 63 | # print(f'source dataset size: {len(src_lines)}') 64 | # tgt_lines = tgt_lines.to_map_datapipe( 65 | # key_value_fn=lambda line: (next(count_tgt), line)) 66 | # print(f'target dataset size: {len(tgt_lines)}') 67 | # print(len(src_lines)) 68 | 69 | if src_transform is not None: 70 | src_lines = src_lines.map(src_transform) 71 | if tgt_transform is not None: 72 | tgt_lines = tgt_lines.map(tgt_transform) 73 | 74 | return src_lines.zip(tgt_lines).shuffle().sharding_filter() 75 | 76 | 77 | T_co = TypeVar("T_co", covariant=True) 78 | 79 | def identity(example): 80 | return example 81 | 82 | class MaxTokensBatchSamplerOld(BatchSampler): 83 | def __init__(self, 84 | sampler, 85 | batch_size, 86 | max_tokens, 87 | drop_last, 88 | sort_key: Callable = None, 89 | bucket_size_multiplier=100, 90 | shuffle=True): 91 | super().__init__(sampler, batch_size, drop_last) 92 | self.max_tokens = max_tokens 93 | self.sort_key = sort_key 94 | self.bucket_size_multiplier = bucket_size_multiplier 95 | 
self.shuffle = shuffle 96 | 97 | # Not a clean solution 98 | self.bucket_batches = [] 99 | self._build_buckets() 100 | 101 | def __iter__(self): 102 | self._build_buckets() 103 | # Iterate over buckets 104 | for batches, batch_sizes in self.bucket_batches: 105 | # Shuffle bucket-batch order 106 | batches = SubsetRandomSampler(batches) if self.shuffle else batches 107 | for batch in batches: 108 | # if self.shuffle: # Shuffle inner batch 109 | # random.shuffle(batch) 110 | yield batch # Batch indexes [sent1_idx, sent2_idx,...] 111 | 112 | def __len__(self): 113 | return sum([len(x[0]) for x in self.bucket_batches]) 114 | 115 | def _build_buckets(self): 116 | # Randomize samples 117 | tmp_sampler = RandomSampler(self.sampler) if self.shuffle else self.sampler 118 | 119 | # Split samples in N batches (or "buckets") 120 | tmp_sampler = BatchSampler(tmp_sampler, min(self.batch_size * self.bucket_size_multiplier, len(self.sampler)), False) 121 | 122 | # Sort samples 123 | self.bucket_batches = [] 124 | for bucket in tmp_sampler: 125 | bucket_sorted = sorted([(i, self.sort_key(i)) for i in bucket], key=lambda x: x[1]) 126 | 127 | # Create batches constrained 128 | batches = [] 129 | batch_sizes = [] 130 | 131 | last_batch = [] 132 | last_batch_size = 0 133 | for i, (sample_i, length_i) in enumerate(bucket_sorted): 134 | if (last_batch_size + length_i) < self.max_tokens: 135 | last_batch.append(sample_i) 136 | last_batch_size += length_i 137 | else: 138 | # Add batch 139 | batches.append(last_batch) 140 | batch_sizes.append(last_batch_size) 141 | 142 | # Add new sample 143 | last_batch = [sample_i] 144 | last_batch_size = length_i 145 | 146 | # Add last batch 147 | batches.append(last_batch) 148 | batch_sizes.append(last_batch_size) 149 | 150 | # Add bucket batches 151 | self.bucket_batches.append((batches, batch_sizes)) 152 | 153 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/make_pos_neg_tied_positions_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | 5 | import glob 6 | import random 7 | import numpy as np 8 | import json 9 | import itertools 10 | 11 | with open(args.input_path, 'r') as json_file: 12 | json_list = list(json_file) 13 | 14 | homooligomeric_state = args.homooligomer 15 | 16 | if homooligomeric_state == 0: 17 | tied_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")] 18 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 19 | my_dict = {} 20 | for json_str in json_list: 21 | result = json.loads(json_str) 22 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ... 
23 | tied_positions_list = [] 24 | for i, pos in enumerate(tied_list[0]): 25 | temp_dict = {} 26 | for j, chain in enumerate(global_designed_chain_list): 27 | temp_dict[chain] = [tied_list[j][i]] #needs to be a list 28 | tied_positions_list.append(temp_dict) 29 | my_dict[result['name']] = tied_positions_list 30 | else: 31 | if args.pos_neg_chain_list: 32 | chain_list_input = [[str(item) for item in one.split()] for one in args.pos_neg_chain_list.split(",")] 33 | chain_betas_input = [[float(item) for item in one.split()] for one in args.pos_neg_chain_betas.split(",")] 34 | chain_list_flat = [item for sublist in chain_list_input for item in sublist] 35 | chain_betas_flat = [item for sublist in chain_betas_input for item in sublist] 36 | chain_betas_dict = dict(zip(chain_list_flat, chain_betas_flat)) 37 | my_dict = {} 38 | for json_str in json_list: 39 | result = json.loads(json_str) 40 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ... 41 | tied_positions_list = [] 42 | chain_length = len(result[f"seq_chain_{all_chain_list[0]}"]) 43 | for chains in chain_list_input: 44 | for i in range(1,chain_length+1): 45 | temp_dict = {} 46 | for j, chain in enumerate(chains): 47 | if args.pos_neg_chain_list and chain in chain_list_flat: 48 | temp_dict[chain] = [[i], [chain_betas_dict[chain]]] 49 | else: 50 | temp_dict[chain] = [[i], [1.0]] #first list is for residue numbers, second list is for weights for the energy, +ive and -ive design 51 | tied_positions_list.append(temp_dict) 52 | my_dict[result['name']] = tied_positions_list 53 | 54 | with open(args.output_path, 'w') as f: 55 | f.write(json.dumps(my_dict) + '\n') 56 | 57 | if __name__ == "__main__": 58 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 59 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 60 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 61 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed") 62 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain") 63 | argparser.add_argument("--homooligomer", type=int, default=0, help="If 0 do not use, if 1 then design homooligomer") 64 | argparser.add_argument("--pos_neg_chain_list", type=str, default='', help="Chain lists to be tied together") 65 | argparser.add_argument("--pos_neg_chain_betas", type=str, default='', help="Chain beta list for the chain lists provided; 1.0 for the positive design, -0.1 or -0.5 for negative, 0.0 means do not use that chain info") 66 | 67 | args = argparser.parse_args() 68 | main(args) 69 | 70 | 71 | #e.g. 
output 72 | #{"5TTA": [], "3LIS": [{"A": [1], "B": [1]}, {"A": [2], "B": [2]}, {"A": [3], "B": [3]}, {"A": [4], "B": [4]}, {"A": [5], "B": [5]}, {"A": [6], "B": [6]}, {"A": [7], "B": [7]}, {"A": [8], "B": [8]}, {"A": [9], "B": [9]}, {"A": [10], "B": [10]}, {"A": [11], "B": [11]}, {"A": [12], "B": [12]}, {"A": [13], "B": [13]}, {"A": [14], "B": [14]}, {"A": [15], "B": [15]}, {"A": [16], "B": [16]}, {"A": [17], "B": [17]}, {"A": [18], "B": [18]}, {"A": [19], "B": [19]}, {"A": [20], "B": [20]}, {"A": [21], "B": [21]}, {"A": [22], "B": [22]}, {"A": [23], "B": [23]}, {"A": [24], "B": [24]}, {"A": [25], "B": [25]}, {"A": [26], "B": [26]}, {"A": [27], "B": [27]}, {"A": [28], "B": [28]}, {"A": [29], "B": [29]}, {"A": [30], "B": [30]}, {"A": [31], "B": [31]}, {"A": [32], "B": [32]}, {"A": [33], "B": [33]}, {"A": [34], "B": [34]}, {"A": [35], "B": [35]}, {"A": [36], "B": [36]}, {"A": [37], "B": [37]}, {"A": [38], "B": [38]}, {"A": [39], "B": [39]}, {"A": [40], "B": [40]}, {"A": [41], "B": [41]}, {"A": [42], "B": [42]}, {"A": [43], "B": [43]}, {"A": [44], "B": [44]}, {"A": [45], "B": [45]}, {"A": [46], "B": [46]}, {"A": [47], "B": [47]}, {"A": [48], "B": [48]}, {"A": [49], "B": [49]}, {"A": [50], "B": [50]}, {"A": [51], "B": [51]}, {"A": [52], "B": [52]}, {"A": [53], "B": [53]}, {"A": [54], "B": [54]}, {"A": [55], "B": [55]}, {"A": [56], "B": [56]}, {"A": [57], "B": [57]}, {"A": [58], "B": [58]}, {"A": [59], "B": [59]}, {"A": [60], "B": [60]}, {"A": [61], "B": [61]}, {"A": [62], "B": [62]}, {"A": [63], "B": [63]}, {"A": [64], "B": [64]}, {"A": [65], "B": [65]}, {"A": [66], "B": [66]}, {"A": [67], "B": [67]}, {"A": [68], "B": [68]}, {"A": [69], "B": [69]}, {"A": [70], "B": [70]}, {"A": [71], "B": [71]}, {"A": [72], "B": [72]}, {"A": [73], "B": [73]}, {"A": [74], "B": [74]}, {"A": [75], "B": [75]}, {"A": [76], "B": [76]}, {"A": [77], "B": [77]}, {"A": [78], "B": [78]}, {"A": [79], "B": [79]}, {"A": [80], "B": [80]}, {"A": [81], "B": [81]}, {"A": [82], "B": [82]}, {"A": [83], "B": [83]}, {"A": [84], "B": [84]}, {"A": [85], "B": [85]}, {"A": [86], "B": [86]}, {"A": [87], "B": [87]}, {"A": [88], "B": [88]}, {"A": [89], "B": [89]}, {"A": [90], "B": [90]}, {"A": [91], "B": [91]}, {"A": [92], "B": [92]}, {"A": [93], "B": [93]}, {"A": [94], "B": [94]}, {"A": [95], "B": [95]}, {"A": [96], "B": [96]}]} 73 | 74 | -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/protein_mpnn_cmlm/protein_mpnn.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from byprot.models import register_model 5 | from byprot.models.fixedbb import FixedBackboneDesignEncoderDecoder 6 | from byprot.models.fixedbb.generator import new_arange, sample_from_categorical 7 | from byprot.datamodules.datasets.data_utils import Alphabet 8 | 9 | from .decoder import MPNNSequenceDecoder 10 | from .encoder import MPNNEncoder 11 | 12 | 13 | @dataclass 14 | class ProteinMPNNConfig: 15 | d_model: int = 128 16 | d_node_feats: int = 128 17 | d_edge_feats: int = 128 18 | k_neighbors: int = 48 19 | augment_eps: float = 0.0 20 | n_enc_layers: int = 3 21 | dropout: float = 0.1 22 | 23 | # decoder-only 24 | n_vocab: int = 22 25 | n_dec_layers: int = 3 26 | random_decoding_order: bool = True 27 | nar: bool = True 28 | crf: bool = False 29 | use_esm_alphabet: bool = False 30 | 31 | 32 | @register_model('protein_mpnn_cmlm') 33 | class ProteinMPNNCMLM(FixedBackboneDesignEncoderDecoder): 34 | _default_cfg = 
ProteinMPNNConfig() 35 | 36 | def __init__(self, cfg) -> None: 37 | super().__init__(cfg) 38 | 39 | self.encoder = MPNNEncoder( 40 | node_features=self.cfg.d_node_feats, 41 | edge_features=self.cfg.d_edge_feats, 42 | hidden_dim=self.cfg.d_model, 43 | num_encoder_layers=self.cfg.n_enc_layers, 44 | k_neighbors=self.cfg.k_neighbors, 45 | augment_eps=self.cfg.augment_eps, 46 | dropout=self.cfg.dropout 47 | ) 48 | 49 | if self.cfg.use_esm_alphabet: 50 | alphabet = Alphabet('esm', 'cath') 51 | self.padding_idx = alphabet.padding_idx 52 | self.mask_idx = alphabet.mask_idx 53 | else: 54 | alphabet = None 55 | self.padding_idx = 0 56 | self.mask_idx = 1 57 | 58 | self.decoder = MPNNSequenceDecoder( 59 | n_vocab=self.cfg.n_vocab, 60 | d_model=self.cfg.d_model, 61 | n_layers=self.cfg.n_dec_layers, 62 | random_decoding_order=self.cfg.random_decoding_order, 63 | dropout=self.cfg.dropout, 64 | nar=self.cfg.nar, 65 | crf=self.cfg.crf, 66 | alphabet=alphabet 67 | ) 68 | 69 | def _forward(self, coords, coord_mask, prev_tokens, token_padding_mask=None, target_tokens=None, return_feats=False, **kwargs): 70 | coord_mask = coord_mask.float() 71 | encoder_out = self.encoder(X=coords, mask=coord_mask) 72 | 73 | logits, feats = self.decoder( 74 | prev_tokens=prev_tokens, 75 | memory=encoder_out, memory_mask=coord_mask, 76 | target_tokens=target_tokens, 77 | **kwargs 78 | ) 79 | 80 | if return_feats: 81 | return logits, feats 82 | return logits 83 | 84 | def forward(self, batch, return_feats=False, **kwargs): 85 | coord_mask = batch['coord_mask'].float() 86 | 87 | encoder_out = self.encoder( 88 | X=batch['coords'], 89 | mask=coord_mask, 90 | residue_idx=batch.get('residue_idx', None), 91 | chain_idx=batch.get('chain_idx', None) 92 | ) 93 | 94 | logits, feats = self.decoder( 95 | prev_tokens=batch['prev_tokens'], 96 | memory=encoder_out, 97 | memory_mask=coord_mask, 98 | target_tokens=batch.get('tokens'), 99 | **kwargs 100 | ) 101 | 102 | if return_feats: 103 | return logits, feats 104 | return logits 105 | 106 | def forward_encoder(self, batch): 107 | encoder_out = self.encoder( 108 | X=batch['coords'], 109 | mask=batch['coord_mask'].float(), 110 | residue_idx=batch.get('residue_idx', None), 111 | chain_idx=batch.get('chain_idx', None) 112 | ) 113 | encoder_out['coord_mask'] = batch['coord_mask'].float() 114 | 115 | return encoder_out 116 | 117 | def forward_decoder(self, prev_decoder_out, encoder_out, need_attn_weights=False): 118 | output_tokens = prev_decoder_out['output_tokens'] 119 | output_scores = prev_decoder_out['output_scores'] 120 | step, max_step = prev_decoder_out['step'], prev_decoder_out['max_step'] 121 | temperature = prev_decoder_out['temperature'] 122 | history = prev_decoder_out['history'] 123 | 124 | output_masks = output_tokens.eq(self.mask_idx) # & coord_mask 125 | 126 | logits, _ = self.decoder( 127 | prev_tokens=output_tokens, 128 | memory=encoder_out, 129 | memory_mask=encoder_out['coord_mask'].float(), 130 | ) 131 | # log_probs = torch.log_softmax(logits, dim=-1) 132 | # _scores, _tokens = log_probs.max(dim=-1) 133 | _tokens, _scores = sample_from_categorical(logits, temperature=temperature) 134 | 135 | output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) 136 | output_scores.masked_scatter_(output_masks, _scores[output_masks]) 137 | 138 | history.append(output_tokens.clone()) 139 | 140 | return dict( 141 | output_tokens=output_tokens, 142 | output_scores=output_scores, 143 | step=step + 1, 144 | max_step=max_step, 145 | history=history 146 | ) 147 | 148 | def 
initialize_output_tokens(self, batch, encoder_out): 149 | # mask = encoder_out.get('coord_mask', None) 150 | 151 | prev_tokens = batch['prev_tokens'] 152 | lengths = prev_tokens.ne(self.padding_idx).sum(1) 153 | 154 | initial_output_tokens = torch.full_like(prev_tokens, self.padding_idx) 155 | initial_output_tokens.masked_fill_(new_arange(prev_tokens) < lengths[:, None], self.mask_idx) 156 | 157 | # if mask is not None: 158 | # initial_output_tokens = torch.where( 159 | # ~mask, prev_tokens, initial_output_tokens 160 | # ) 161 | # initial_output_tokens = prev_tokens.clone() 162 | 163 | initial_output_scores = torch.zeros( 164 | *initial_output_tokens.size(), device=initial_output_tokens.device 165 | ) 166 | 167 | return initial_output_tokens, initial_output_scores 168 | -------------------------------------------------------------------------------- /examples/inspect_data_and_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 45, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## Load and process a single-chain or multi-chain protein from a PDB file" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 46, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from byprot.datamodules.datasets import DataProcessor" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 58, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "pdb_path = \"/root/research/projects/ByProt_public/examples/3f4m.pdb\"\n", 45 | "pdb_path = \"/root/research/projects/ByProt_public/examples/3uat.pdb\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 67, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "dp = DataProcessor()\n", 55 | "structure = dp.parse_PDB(\n", 56 | " pdb_path,\n", 57 | " # input_chain_list=['A', 'B'] -> load which chains\n", 58 | " # masked_chain_list=['A'] -> which chains to predict while the remaining chains serve as conditioning\n", 59 | ")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 66, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "dict_keys(['seq_chain_A', 'coords_chain_A', 'seq_chain_B', 'coords_chain_B', 'name', 'num_of_chains', 'seq', 'coords', 'masked_list', 'visible_list'])\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "print(structure.keys())\n", 77 | "print(structure)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "#### single-chain protein " 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "from byprot.datamodules.datasets.data_utils import Alphabet\n", 94 | "alphabet = Alphabet('esm', 'cath')\n", 95 | "alphabet.featurize([structure])" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "#### multi-chain protein" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | 
"source": [ 111 | "alphabet = Alphabet('esm', 'multichain')\n", 112 | "alphabet.featurize([structure])" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Design sequences for structures" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "from byprot.utils.config import compose_config as Cfg\n", 129 | "from byprot.tasks.fixedbb.designer import Designer\n", 130 | "\n", 131 | "# 1. instantialize designer\n", 132 | "cfg = Cfg(\n", 133 | " cuda=True,\n", 134 | " generator=Cfg(\n", 135 | " max_iter=5,\n", 136 | " strategy='denoise',\n", 137 | " temperature=0,\n", 138 | " eval_sc=False,\n", 139 | " )\n", 140 | ")\n", 141 | "exp_path = \"/root/research/projects/ByProt_public/logs/cath4.2/lm_design_esm2_650m\"\n", 142 | "designer_cath = Designer(experiment_path=exp_path, cfg=cfg)\n", 143 | "\n", 144 | "exp_path = \"/root/research/projects/ByProt_public/logs/fixedbb_multichain/lm_design_esm2_650m\"\n", 145 | "designer_complex = Designer(experiment_path=exp_path, cfg=cfg)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# multi-chain complex\n", 155 | "pdb_path = \"/root/research/projects/ByProt_public/examples/3uat.pdb\"\n", 156 | "\n", 157 | "print(f\"designed by cath-trained LM-Design\")\n", 158 | "designer_cath.set_structure(pdb_path)\n", 159 | "print(designer_cath.generate()[0]); designer_cath.calculate_metrics()\n", 160 | "\n", 161 | "print(f\"designed by pdb complex-trained LM-Design\")\n", 162 | "designer_complex.set_structure(\n", 163 | " pdb_path\n", 164 | " # chain_list=['A', 'B'] -> load which chains\n", 165 | " # masked_chain_list=['A'] -> which chains to predict while the remaining chains serve as conditioning\n", 166 | ")\n", 167 | "print(designer_complex.generate()[0]); designer_complex.calculate_metrics()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# single chain\n", 177 | "pdb_path = \"/root/research/projects/ByProt_public/examples/3f4m.pdb\"\n", 178 | "\n", 179 | "print(f\"designed by cath-trained LM-Design\")\n", 180 | "designer_cath.set_structure(pdb_path)\n", 181 | "print(designer_cath.generate()[0]); designer_cath.calculate_metrics()\n", 182 | "\n", 183 | "print(f\"designed by pdb complex-trained LM-Design\")\n", 184 | "designer_complex.set_structure(pdb_path)\n", 185 | "print(designer_complex.generate()[0]); designer_complex.calculate_metrics()" 186 | ] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "ByProt_public", 192 | "language": "python", 193 | "name": "python3" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 3 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython3", 205 | "version": "3.7.16" 206 | } 207 | }, 208 | "nbformat": 4, 209 | "nbformat_minor": 2 210 | } 211 | -------------------------------------------------------------------------------- /src/byprot/models/seq2seq/transformer.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from functools import partial 3 | 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from 
.modules.embedding import Embedding, PositionEmbedding 8 | 9 | 10 | def bert_init_params(module): 11 | """ 12 | Initialize the weights specific to the BERT Model. 13 | This overrides the default initializations depending on the specified arguments. 14 | 1. If normal_init_linear_weights is set then weights of linear 15 | layer will be initialized using the normal distribution and 16 | bais will be set to the specified value. 17 | 2. If normal_init_embed_weights is set then weights of embedding 18 | layer will be initialized using the normal distribution. 19 | 3. If normal_init_proj_weights is set then weights of 20 | in_project_weight for MultiHeadAttention initialized using 21 | the normal distribution (to be validated). 22 | """ 23 | 24 | def normal_(data): 25 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 26 | # so that the RNG is consistent with and without FSDP 27 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 28 | 29 | if isinstance(module, nn.Linear): 30 | normal_(module.weight.data) 31 | if module.bias is not None: 32 | module.bias.data.zero_() 33 | 34 | if isinstance(module, nn.Embedding): 35 | normal_(module.weight.data) 36 | if module.padding_idx is not None: 37 | module.weight.data[module.padding_idx].zero_() 38 | 39 | # if isinstance(module, MultiheadAttention): 40 | # normal_(module.q_proj.weight.data) 41 | # normal_(module.k_proj.weight.data) 42 | # normal_(module.v_proj.weight.data) 43 | 44 | 45 | def fairseq_init_params(module): 46 | if isinstance(module, nn.Linear): 47 | nn.init.xavier_uniform_(module.weight) 48 | if module.bias is not None: 49 | nn.init.constant_(module.bias, 0.0) 50 | 51 | if isinstance(module, nn.Embedding): 52 | nn.init.normal_(module.weight, mean=0.0, std=module.embedding_dim ** -0.5) 53 | if module.padding_idx is not None: 54 | nn.init.constant_(module.weight[module.padding_idx], 0.0) 55 | 56 | 57 | def create_padding_mask(x, pad=1): 58 | return x.ne(pad) 59 | 60 | 61 | def _set_inferring_flag(mod, flag=True): 62 | mod._inferring = flag 63 | 64 | 65 | class TransformerEncoderDecoder(nn.Module): 66 | def __init__( 67 | self, 68 | cfg, 69 | embed_src=None, 70 | embed_tgt=None, 71 | pos_embed_src=None, 72 | pos_embed_tgt=None, 73 | encoder=None, 74 | decoder=None, 75 | out_proj=None, 76 | ): 77 | super().__init__() 78 | 79 | self.cfg = cfg 80 | self.d_model = self.cfg.d_model 81 | 82 | self.embed_src, self.embed_tgt = embed_src, embed_tgt 83 | self.out_proj = out_proj 84 | 85 | self.pos_embed_src, self.pos_embed_tgt = pos_embed_src, pos_embed_tgt 86 | 87 | self.encoder = encoder 88 | self.decoder = decoder 89 | 90 | self.reset_parameters() 91 | 92 | @classmethod 93 | def build( 94 | cls, cfg, 95 | vocab_src, vocab_tgt, 96 | token_embedding, position_embedding, 97 | encoder, decoder, 98 | ) -> "TransformerEncoderDecoder": 99 | 100 | src_embed = token_embedding(len(vocab_src), cfg.d_model, padding_idx=vocab_src.pad) 101 | tgt_embed = token_embedding(len(vocab_tgt), cfg.d_model, padding_idx=vocab_tgt.pad) 102 | out_proj = nn.Linear(cfg.d_model, len(vocab_tgt), bias=False) 103 | 104 | src_pos_embed = position_embedding(cfg.d_model) 105 | tgt_pos_embed = position_embedding(cfg.d_model) 106 | 107 | model = cls(cfg, src_embed, tgt_embed, 108 | src_pos_embed, tgt_pos_embed, encoder, decoder, out_proj) 109 | model.apply(partial(_set_inferring_flag, flag=False)) 110 | 111 | # initializtion model with BERT style, which was found magically good... 
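# Note: bert_init_params (defined above) draws nn.Linear / nn.Embedding weights from
# N(0, 0.02), zeroes biases, and zeroes the padding embedding row. The weight-tying
# block just below runs after this init, so when share_input_output_embedding is set,
# out_proj simply reuses the already-initialized target-embedding tensor rather than
# being re-initialized on its own (and likewise for share_source_target_embedding).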
112 | model.apply(bert_init_params) 113 | 114 | if cfg.share_input_output_embedding: 115 | out_proj.weight = tgt_embed.weight 116 | if cfg.share_source_target_embedding: 117 | src_embed.weight = tgt_embed.weight 118 | return model 119 | 120 | def reset_parameters(self): 121 | for child in self.children(): 122 | child.reset_parameters() 123 | 124 | @contextlib.contextmanager 125 | def inference_mode(self, mode=True): 126 | self.apply(partial(_set_inferring_flag, flag=mode)) 127 | yield 128 | self.apply(partial(_set_inferring_flag, flag=False)) 129 | 130 | def forward(self, src_tokens, tgt_tokens, src_padding_mask, tgt_padding_mask): 131 | """ 132 | Args: 133 | src_tokens (LongTensor): source tokens [bsz, slen] 134 | tgt_tokens (LongTensor): target tokens [bsz, tlen] 135 | """ 136 | encoder_out = self.encode(src_tokens, src_padding_mask) 137 | decoder_out = self.decode(tgt_tokens, encoder_out, tgt_padding_mask, src_padding_mask) 138 | logits = self.output(decoder_out, normalize=False) 139 | return logits 140 | 141 | def encode(self, src_tokens, src_padding_mask=None): 142 | src_emb = self.embed_src(src_tokens) 143 | encoder_input = self.pos_embed_src(src_emb) 144 | encoder_out = self.encoder(encoder_input, src_padding_mask) 145 | return encoder_out 146 | 147 | def decode(self, tgt_tokens, encoder_out, 148 | tgt_padding_mask=None, src_padding_mask=None, 149 | incremental_states=None): 150 | tgt_emb = self.embed_tgt(tgt_tokens) 151 | decoder_input = self.pos_embed_tgt(tgt_emb) 152 | 153 | if incremental_states is not None: 154 | # [bsz, len, d] -> [bsz, 1, d] 155 | decoder_input = decoder_input[:, -1:] 156 | tgt_padding_mask = None 157 | 158 | decoder_out = self.decoder(decoder_input, encoder_out, 159 | tgt_padding_mask, src_padding_mask, 160 | incremental_states=incremental_states) 161 | return decoder_out 162 | 163 | def output(self, decoder_out, normalize=True): 164 | logits = self.out_proj(decoder_out) 165 | return F.log_softmax(logits, dim=-1) if normalize else logits 166 | -------------------------------------------------------------------------------- /src/byprot/models/fixedbb/generator.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | import os 4 | import pickle 5 | import re 6 | import shutil 7 | from pathlib import Path 8 | from typing import List, Sequence, Tuple, Union, Mapping 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | def _skeptical_unmasking(output_scores, output_masks, p): 15 | sorted_index = output_scores.sort(-1)[1] 16 | boundary_len = ( 17 | (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p 18 | ).long() 19 | # `length * p`` positions with lowest scores get kept 20 | skeptical_mask = new_arange(output_masks) < boundary_len 21 | return skeptical_mask.scatter(1, sorted_index, skeptical_mask) 22 | 23 | 24 | def exists(obj): 25 | return obj is not None 26 | 27 | 28 | def new_arange(x, *size): 29 | """ 30 | Return a Tensor of `size` filled with a range function on the device of x. 31 | If size is empty, using the size of the variable x. 
32 | """ 33 | if len(size) == 0: 34 | size = x.size() 35 | return torch.arange(size[-1], device=x.device).expand(*size).contiguous() 36 | 37 | 38 | def maybe_remove_batch_dim(tensor): 39 | if len(tensor.shape) > 1 and tensor.shape[0] == 1: 40 | tensor = tensor.squeeze(0) 41 | return tensor 42 | 43 | 44 | class IterativeRefinementGenerator(object): 45 | def __init__(self, 46 | alphabet=None, 47 | max_iter=1, 48 | strategy='denoise', 49 | temperature=None, 50 | **kwargs 51 | ): 52 | 53 | self.alphabet = alphabet 54 | self.padding_idx = alphabet.padding_idx 55 | self.mask_idx = alphabet.mask_idx 56 | 57 | self.max_iter = max_iter 58 | self.strategy = strategy 59 | self.temperature = temperature 60 | 61 | @torch.no_grad() 62 | def generate(self, model, batch, alphabet=None, 63 | max_iter=None, strategy=None, temperature=None, replace_visible_tokens=False, 64 | need_attn_weights=False): 65 | alphabet = alphabet or self.alphabet 66 | max_iter = max_iter or self.max_iter 67 | strategy = strategy or self.strategy 68 | temperature = temperature or self.temperature 69 | 70 | # 0) encoding 71 | encoder_out = model.forward_encoder(batch) 72 | 73 | # 1) initialized from all mask tokens 74 | initial_output_tokens, initial_output_scores = model.initialize_output_tokens( 75 | batch, encoder_out=encoder_out) 76 | prev_decoder_out = dict( 77 | output_tokens=initial_output_tokens, 78 | output_scores=initial_output_scores, 79 | output_masks=None, 80 | attentions=None, 81 | step=0, 82 | max_step=max_iter, 83 | history=[initial_output_tokens.clone()], 84 | temperature=temperature, 85 | ) 86 | 87 | if need_attn_weights: 88 | attns = [] # list of {'in', 'out', 'attn'} for all iteration 89 | 90 | if strategy == 'discrete_diffusion': 91 | prev_decoder_out['output_masks'] = model.get_non_special_sym_mask(batch['prev_tokens']) 92 | 93 | # iterative refinement 94 | for step in range(max_iter): 95 | 96 | # 2.1: predict 97 | decoder_out = model.forward_decoder( 98 | prev_decoder_out=prev_decoder_out, 99 | encoder_out=encoder_out, 100 | need_attn_weights=need_attn_weights 101 | ) 102 | 103 | output_tokens = decoder_out['output_tokens'] 104 | output_scores = decoder_out['output_scores'] 105 | 106 | # 2.2: re-mask skeptical parts of low confidence 107 | # skeptical decoding (depend on the maximum decoding steps.) 
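# Note: with p = 1 - (step + 1) / max_iter, _skeptical_unmasking (defined above) flags
# roughly the (length * p) lowest-scoring positions, where length excludes the two
# special tokens; those positions are reset to the mask token below and re-predicted on
# the next pass. For max_iter=5 this re-masks about 80% of positions after the first
# pass, 60% after the second, and so on, with no re-masking on the final iteration
# because of the (step + 1) < max_iter guard.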
108 | if ( 109 | strategy == 'mask_predict' 110 | and (step + 1) < max_iter 111 | ): 112 | skeptical_mask = _skeptical_unmasking( 113 | output_scores=output_scores, 114 | output_masks=output_tokens.ne(self.padding_idx), # & coord_mask, 115 | p=1 - (step + 1) / max_iter 116 | ) 117 | 118 | output_tokens.masked_fill_(skeptical_mask, self.mask_idx) 119 | output_scores.masked_fill_(skeptical_mask, 0.0) 120 | 121 | elif strategy == 'denoise' or strategy == 'no': 122 | pass 123 | elif strategy == 'discrete_diffusion': 124 | pass 125 | else: 126 | pass 127 | 128 | if replace_visible_tokens: 129 | visible_token_mask = ~batch['prev_token_mask'] 130 | visible_tokens = batch['prev_tokens'] 131 | output_tokens = torch.where( 132 | visible_token_mask, visible_tokens, output_tokens) 133 | 134 | if need_attn_weights: 135 | attns.append( 136 | dict(input=maybe_remove_batch_dim(prev_decoder_out['output_tokens']), 137 | output=maybe_remove_batch_dim(output_tokens), 138 | attn_weights=maybe_remove_batch_dim(decoder_out['attentions'])) 139 | ) 140 | 141 | prev_decoder_out.update( 142 | output_tokens=output_tokens, 143 | output_scores=output_scores, 144 | step=step + 1, 145 | history=decoder_out['history'] 146 | ) 147 | 148 | # skeptical_mask = _skeptical_unmasking( 149 | # output_scores=output_scores, 150 | # output_masks=output_tokens.ne(self.padding_idx), # & coord_mask, 151 | # p=0.08 152 | # ) 153 | 154 | # output_tokens.masked_fill_(skeptical_mask, self.alphabet.unk_idx) 155 | # output_scores.masked_fill_(skeptical_mask, 0.0) 156 | decoder_out = prev_decoder_out 157 | 158 | if need_attn_weights: 159 | return decoder_out['output_tokens'], decoder_out['output_scores'], attns 160 | return decoder_out['output_tokens'], decoder_out['output_scores'] 161 | 162 | 163 | def sample_from_categorical(logits=None, temperature=1.0): 164 | if temperature: 165 | dist = torch.distributions.Categorical(logits=logits.div(temperature)) 166 | tokens = dist.sample() 167 | scores = dist.log_prob(tokens) 168 | else: 169 | scores, tokens = logits.log_softmax(dim=-1).max(dim=-1) 170 | return tokens, scores 171 | -------------------------------------------------------------------------------- /examples/pmpnn_compatible/helper_scripts/parse_multiple_chains.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | 5 | import numpy as np 6 | import os, time, gzip, json 7 | import glob 8 | 9 | folder_with_pdbs_path = args.input_path 10 | save_path = args.output_path 11 | ca_only = args.ca_only 12 | 13 | alpha_1 = list("ARNDCQEGHILKMFPSTWYV-") 14 | states = len(alpha_1) 15 | alpha_3 = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE', 16 | 'LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL','GAP'] 17 | 18 | aa_1_N = {a:n for n,a in enumerate(alpha_1)} 19 | aa_3_N = {a:n for n,a in enumerate(alpha_3)} 20 | aa_N_1 = {n:a for n,a in enumerate(alpha_1)} 21 | aa_1_3 = {a:b for a,b in zip(alpha_1,alpha_3)} 22 | aa_3_1 = {b:a for a,b in zip(alpha_1,alpha_3)} 23 | 24 | def AA_to_N(x): 25 | # ["ARND"] -> [[0,1,2,3]] 26 | x = np.array(x); 27 | if x.ndim == 0: x = x[None] 28 | return [[aa_1_N.get(a, states-1) for a in y] for y in x] 29 | 30 | def N_to_AA(x): 31 | # [[0,1,2,3]] -> ["ARND"] 32 | x = np.array(x); 33 | if x.ndim == 1: x = x[None] 34 | return ["".join([aa_N_1.get(a,"-") for a in y]) for y in x] 35 | 36 | 37 | def parse_PDB_biounits(x, atoms=['N','CA','C'], chain=None): 38 | ''' 39 | input: x = PDB filename 40 | atoms = atoms to extract 
(optional) 41 | output: (length, atoms, coords=(x,y,z)), sequence 42 | ''' 43 | xyz,seq,min_resn,max_resn = {},{},1e6,-1e6 44 | for line in open(x,"rb"): 45 | line = line.decode("utf-8","ignore").rstrip() 46 | 47 | if line[:6] == "HETATM" and line[17:17+3] == "MSE": 48 | line = line.replace("HETATM","ATOM ") 49 | line = line.replace("MSE","MET") 50 | 51 | if line[:4] == "ATOM": 52 | ch = line[21:22] 53 | if ch == chain or chain is None: 54 | atom = line[12:12+4].strip() 55 | resi = line[17:17+3] 56 | resn = line[22:22+5].strip() 57 | x,y,z = [float(line[i:(i+8)]) for i in [30,38,46]] 58 | 59 | if resn[-1].isalpha(): 60 | resa,resn = resn[-1],int(resn[:-1])-1 61 | else: 62 | resa,resn = "",int(resn)-1 63 | # resn = int(resn) 64 | if resn < min_resn: 65 | min_resn = resn 66 | if resn > max_resn: 67 | max_resn = resn 68 | if resn not in xyz: 69 | xyz[resn] = {} 70 | if resa not in xyz[resn]: 71 | xyz[resn][resa] = {} 72 | if resn not in seq: 73 | seq[resn] = {} 74 | if resa not in seq[resn]: 75 | seq[resn][resa] = resi 76 | 77 | if atom not in xyz[resn][resa]: 78 | xyz[resn][resa][atom] = np.array([x,y,z]) 79 | 80 | # convert to numpy arrays, fill in missing values 81 | seq_,xyz_ = [],[] 82 | try: 83 | for resn in range(min_resn,max_resn+1): 84 | if resn in seq: 85 | for k in sorted(seq[resn]): seq_.append(aa_3_N.get(seq[resn][k],20)) 86 | else: seq_.append(20) 87 | if resn in xyz: 88 | for k in sorted(xyz[resn]): 89 | for atom in atoms: 90 | if atom in xyz[resn][k]: xyz_.append(xyz[resn][k][atom]) 91 | else: xyz_.append(np.full(3,np.nan)) 92 | else: 93 | for atom in atoms: xyz_.append(np.full(3,np.nan)) 94 | return np.array(xyz_).reshape(-1,len(atoms),3), N_to_AA(np.array(seq_)) 95 | except TypeError: 96 | return 'no_chain', 'no_chain' 97 | 98 | 99 | 100 | pdb_dict_list = [] 101 | c = 0 102 | 103 | if folder_with_pdbs_path[-1]!='/': 104 | folder_with_pdbs_path = folder_with_pdbs_path+'/' 105 | 106 | 107 | init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z'] 108 | extra_alphabet = [str(item) for item in list(np.arange(300))] 109 | chain_alphabet = init_alphabet + extra_alphabet 110 | 111 | biounit_names = glob.glob(folder_with_pdbs_path+'*.pdb') 112 | for biounit in biounit_names: 113 | my_dict = {} 114 | s = 0 115 | concat_seq = '' 116 | concat_N = [] 117 | concat_CA = [] 118 | concat_C = [] 119 | concat_O = [] 120 | concat_mask = [] 121 | coords_dict = {} 122 | for letter in chain_alphabet: 123 | if ca_only: 124 | sidechain_atoms = ['CA'] 125 | else: 126 | sidechain_atoms = ['N', 'CA', 'C', 'O'] 127 | xyz, seq = parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter) 128 | if type(xyz) != str: 129 | concat_seq += seq[0] 130 | my_dict['seq_chain_'+letter]=seq[0] 131 | coords_dict_chain = {} 132 | if ca_only: 133 | coords_dict_chain['CA_chain_'+letter]=xyz.tolist() 134 | else: 135 | coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist() 136 | coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist() 137 | coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist() 138 | coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist() 139 | my_dict['coords_chain_'+letter]=coords_dict_chain 140 | s += 1 141 | fi = biounit.rfind("/") 142 | my_dict['name']=biounit[(fi+1):-4] 143 | my_dict['num_of_chains'] = s 144 | my_dict['seq'] = concat_seq 145 | if s < 
len(chain_alphabet): 146 | pdb_dict_list.append(my_dict) 147 | c+=1 148 | 149 | 150 | with open(save_path, 'w') as f: 151 | for entry in pdb_dict_list: 152 | f.write(json.dumps(entry) + '\n') 153 | 154 | 155 | if __name__ == "__main__": 156 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 157 | 158 | argparser.add_argument("--input_path", type=str, help="Path to a folder with pdb files, e.g. /home/my_pdbs/") 159 | argparser.add_argument("--output_path", type=str, help="Path where to save .jsonl dictionary of parsed pdbs") 160 | argparser.add_argument("--ca_only", action="store_true", default=False, help="parse a backbone-only structure (default: false)") 161 | 162 | args = argparser.parse_args() 163 | main(args) 164 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: ByProt 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - certifi=2022.12.7=py37h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.4=h6a678d5_0 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=1.1.1t=h7f8727e_0 16 | - pip=22.3.1=py37h06a4308_0 17 | - python=3.7.16=h7a1cb2a_0 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py37h06a4308_0 20 | - sqlite=3.41.2=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.38.4=py37h06a4308_0 23 | - xz=5.4.2=h5eee18b_0 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - absl-py==1.4.0 27 | - aiohttp==3.8.4 28 | - aiosignal==1.3.1 29 | - alabaster==0.7.13 30 | - alembic==1.11.1 31 | - antlr4-python3-runtime==4.9.3 32 | - anyio==3.7.0 33 | - argon2-cffi==21.3.0 34 | - argon2-cffi-bindings==21.2.0 35 | - async-timeout==4.0.2 36 | - asynctest==0.13.0 37 | - attrs==23.1.0 38 | - autopage==0.5.1 39 | - babel==2.12.1 40 | - backcall==0.2.0 41 | - beautifulsoup4==4.12.2 42 | - biopython==1.79 43 | - biotite==0.37.0 44 | - black==23.3.0 45 | - bleach==6.0.0 46 | - byted-torch==1.12.0.post0 47 | - byteps==0.3.0 48 | - cachetools==5.3.1 49 | - cffi==1.15.1 50 | - cfgv==3.3.1 51 | - charset-normalizer==3.1.0 52 | - click==8.1.3 53 | - cliff==3.10.1 54 | - cloudpickle==2.2.1 55 | - cmaes==0.9.1 56 | - cmd2==2.4.3 57 | - colorlog==6.7.0 58 | - cycler==0.11.0 59 | - debugpy==1.6.7 60 | - decorator==5.1.1 61 | - defusedxml==0.7.1 62 | - dgl-cu113==0.9.1.post1 63 | - dglgo==0.0.2 64 | - distlib==0.3.6 65 | - docutils==0.19 66 | - e3nn==0.5.1 67 | - einops==0.6.1 68 | - entrypoints==0.4 69 | - exceptiongroup==1.1.1 70 | - fair-esm==2.0.1 71 | - fairscale==0.4.6 72 | - fastjsonschema==2.17.1 73 | - filelock==3.12.0 74 | - flake8==5.0.4 75 | - fonttools==4.38.0 76 | - frozenlist==1.3.3 77 | - fsspec==2023.1.0 78 | - google-auth==2.19.1 79 | - google-auth-oauthlib==0.4.6 80 | - greenlet==2.0.2 81 | - grpcio==1.54.2 82 | - hydra-colorlog==1.2.0 83 | - hydra-core==1.2.0 84 | - hydra-optuna-sweeper==1.2.0 85 | - identify==2.5.24 86 | - idna==3.4 87 | - imagesize==1.4.1 88 | - importlib-metadata==6.7.0 89 | - importlib-resources==5.12.0 90 | - iniconfig==2.0.0 91 | - ipykernel==6.16.2 92 | - ipython==7.34.0 93 | - ipython-genutils==0.2.0 94 | - isort==5.11.5 95 | - jedi==0.18.2 96 | - jinja2==3.1.2 97 | - joblib==1.2.0 98 | - jsonschema==4.17.3 99 | - jupyter-client==7.4.9 100 | - jupyter-core==4.12.0 101 | - 
jupyter-server==1.24.0 102 | - jupyterlab-pygments==0.2.2 103 | - kiwisolver==1.4.4 104 | - littleutils==0.2.2 105 | - lmdb==1.4.1 106 | - mako==1.2.4 107 | - markdown==3.3.4 108 | - markdown-it-py==2.2.0 109 | - markupsafe==2.1.3 110 | - matplotlib==3.5.3 111 | - matplotlib-inline==0.1.6 112 | - mccabe==0.7.0 113 | - mdurl==0.1.2 114 | - mistune==2.0.5 115 | - mpmath==1.3.0 116 | - msgpack==1.0.5 117 | - multidict==6.0.4 118 | - mypy-extensions==1.0.0 119 | - nbclassic==1.0.0 120 | - nbclient==0.7.4 121 | - nbconvert==7.4.0 122 | - nbformat==5.8.0 123 | - nbstripout==0.6.1 124 | - nest-asyncio==1.5.6 125 | - networkx==2.6.3 126 | - nodeenv==1.8.0 127 | - notebook==6.5.4 128 | - notebook-shim==0.2.3 129 | - numpy==1.21.6 130 | - numpydoc==1.5.0 131 | - oauthlib==3.2.2 132 | - ogb==1.3.6 133 | - omegaconf==2.3.0 134 | - opt-einsum==3.3.0 135 | - opt-einsum-fx==0.1.4 136 | - optuna==2.10.1 137 | - outdated==0.2.2 138 | - packaging==23.1 139 | - pandas==1.3.5 140 | - pandocfilters==1.5.0 141 | - parso==0.8.3 142 | - pathspec==0.11.1 143 | - pbr==5.11.1 144 | - pexpect==4.8.0 145 | - pickleshare==0.7.5 146 | - pillow==9.5.0 147 | - pkgutil-resolve-name==1.3.10 148 | - platformdirs==2.6.2 149 | - pluggy==1.0.0 150 | - portalocker==2.7.0 151 | - pre-commit==2.21.0 152 | - prettytable==3.7.0 153 | - prometheus-client==0.17.0 154 | - prompt-toolkit==3.0.38 155 | - protobuf==3.20.3 156 | - psutil==5.9.5 157 | - ptyprocess==0.7.0 158 | - pudb==2022.1.3 159 | - pyasn1==0.5.0 160 | - pyasn1-modules==0.3.0 161 | - pycodestyle==2.9.1 162 | - pycparser==2.21 163 | - pydantic==1.10.9 164 | - pydeprecate==0.3.2 165 | - pyflakes==2.5.0 166 | - pygments==2.15.1 167 | - pyparsing==3.0.9 168 | - pyperclip==1.8.2 169 | - pyrootutils==1.0.4 170 | - pyrsistent==0.19.3 171 | - pytest==7.3.1 172 | - python-dateutil==2.8.2 173 | - python-dotenv==0.21.1 174 | - pytorch-lightning==1.7.3 175 | - pytz==2023.3 176 | - pyyaml==6.0 177 | - pyzmq==25.1.0 178 | - rdkit-pypi==2022.9.5 179 | - requests==2.31.0 180 | - requests-oauthlib==1.3.1 181 | - rich==13.4.1 182 | - rsa==4.9 183 | - ruamel-yaml==0.17.32 184 | - ruamel-yaml-clib==0.2.7 185 | - scikit-learn==1.0.2 186 | - scipy==1.7.3 187 | - seaborn==0.12.2 188 | - send2trash==1.8.2 189 | - sh==1.14.3 190 | - six==1.16.0 191 | - sniffio==1.3.0 192 | - snowballstemmer==2.2.0 193 | - soupsieve==2.4.1 194 | - sphinx==5.3.0 195 | - sphinxcontrib-applehelp==1.0.2 196 | - sphinxcontrib-devhelp==1.0.2 197 | - sphinxcontrib-htmlhelp==2.0.0 198 | - sphinxcontrib-jsmath==1.0.1 199 | - sphinxcontrib-qthelp==1.0.3 200 | - sphinxcontrib-serializinghtml==1.1.5 201 | - sqlalchemy==2.0.15 202 | - stevedore==3.5.2 203 | - sympy==1.10.1 204 | - tensorboard==2.11.2 205 | - tensorboard-data-server==0.6.1 206 | - tensorboard-plugin-wit==1.8.1 207 | - terminado==0.17.1 208 | - threadpoolctl==3.1.0 209 | - tinycss2==1.2.1 210 | - tomli==2.0.1 211 | - torch==1.12.0 212 | - torch-cluster==1.6.1 213 | - torch-geometric==2.3.1 214 | - torch-scatter==2.1.1 215 | - torch-sparse==0.6.17 216 | - torch-spline-conv==1.2.2 217 | - torchdata==0.4.0 218 | - torchmetrics==0.11.4 219 | - torchtext==0.13.0 220 | - tornado==6.2 221 | - tqdm==4.65.0 222 | - traitlets==5.9.0 223 | - typed-ast==1.5.4 224 | - typer==0.9.0 225 | - typing-extensions==4.6.3 226 | - urllib3==1.26.16 227 | - urwid==2.1.2 228 | - urwid-readline==0.13 229 | - virtualenv==20.16.2 230 | - wcwidth==0.2.6 231 | - webencodings==0.5.1 232 | - websocket-client==1.5.2 233 | - werkzeug==2.2.3 234 | - yarl==1.9.2 235 | - zipp==3.15.0 236 | 
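# Usage note: an exported conda environment like this one is typically recreated with
#   conda env create -f env.yml
#   conda activate ByProt
# The pinned dgl-cu113 and torch==1.12.0 wheels suggest a CUDA 11.3-era setup; the
# repository also ships an install.sh, which may be the intended setup entry point.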
-------------------------------------------------------------------------------- /src/byprot/datamodules/multichain_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import partial 4 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple 5 | 6 | import joblib 7 | import numpy as np 8 | import torch 9 | from byprot import utils 10 | from byprot.datamodules import register_datamodule 11 | from pytorch_lightning import LightningDataModule 12 | from torch.utils.data import Dataset 13 | from tqdm.auto import tqdm 14 | 15 | from .datasets.data_utils import Alphabet, MaxTokensBatchSampler 16 | from .datasets.multichain import (PDB_dataset2, StructureDataset, 17 | StructureLoader, build_training_clusters, 18 | featurize, loader_pdb, parse_pdb, 19 | worker_init_fn) 20 | 21 | log = utils.get_logger(__name__) 22 | 23 | 24 | MAP_SIZE = 32 * (1024 * 1024 * 1024) # 32GB 25 | 26 | 27 | def make_dataset( 28 | split, args, alphabet, 29 | params=None, 30 | load_params=None, 31 | data_path=None, 32 | deterministic=False 33 | ): 34 | if data_path is None: 35 | data_path = args.data_dir 36 | if params is None: 37 | params = { 38 | "LIST": f"{data_path}/list.csv", 39 | "VAL": f"{data_path}/valid_clusters.txt", 40 | "TEST": f"{data_path}/test_clusters.txt", 41 | "DIR": f"{data_path}", 42 | "DATCUT": "2030-Jan-01", 43 | "RESCUT": args.rescut, # resolution cutoff for PDBs 44 | "HOMO": 0.70, # min seq.id. to detect homo chains 45 | "MAXLEN": args.max_length, 46 | } 47 | params["DETERMINISTIC"] = deterministic 48 | 49 | pdb_dataset = PDB_dataset2( 50 | list(split.keys()), loader_pdb, 51 | split, params, 52 | random_select=not deterministic 53 | ) 54 | pdb_loader = torch.utils.data.DataLoader( 55 | pdb_dataset, batch_size=1, 56 | worker_init_fn=worker_init_fn, 57 | pin_memory=False, shuffle=not deterministic) 58 | 59 | pdb_dict_list = joblib.Parallel( 60 | n_jobs=max(int(joblib.cpu_count() * 0.8), 1), 61 | )( 62 | joblib.delayed(parse_pdb)( 63 | task={ 64 | 'entry': entry, 65 | 'max_length': args.max_length, 66 | 'params': params 67 | } 68 | ) 69 | for entry, _ in tqdm(pdb_loader, dynamic_ncols=True, desc='Parse PDBs') 70 | ) 71 | pdb_dict_list = filter(None, pdb_dict_list) 72 | 73 | dataset = StructureDataset( 74 | pdb_dict_list, 75 | alphabet=alphabet, 76 | truncate=None, 77 | max_length=args.max_length 78 | ) 79 | return dataset 80 | 81 | 82 | @register_datamodule('multichain') 83 | class MultichainDataModule(LightningDataModule): 84 | def __init__( 85 | self, 86 | data_dir: str = "data/", 87 | max_length=500, 88 | rescut=3.5, 89 | atoms: List[str] = ('N', 'CA', 'C', 'O'), 90 | alphabet=None, 91 | batch_size: int = 64, 92 | max_tokens: int = 6000, 93 | sort: bool = False, 94 | num_workers: int = 0, 95 | pin_memory: bool = False, 96 | train_split: str = 'train', 97 | valid_split: str = 'valid', 98 | test_split: str = 'test', 99 | to_sabdab_format: bool = False, 100 | to_pifold_format: bool = False, 101 | debug=False 102 | ): 103 | super().__init__() 104 | 105 | # this line allows to access init params with 'self.hparams' attribute 106 | self.save_hyperparameters(logger=False) 107 | 108 | self.alphabet = None 109 | 110 | self.train_dataset: Optional[Dataset] = None 111 | self.valid_dataset: Optional[Dataset] = None 112 | self.test_dataset: Optional[Dataset] = None 113 | self.predict_dataset: Optional[Dataset] = None 114 | 115 | def prepare_data(self): 116 | """Download data if needed. 
117 | 118 | This method is called only from a single GPU. 119 | Do not use it to assign state (self.x = y). 120 | """ 121 | # MNIST(self.hparams.data_dir, train=True, download=True) 122 | # MNIST(self.hparams.data_dir, train=False, download=True) 123 | pass 124 | 125 | def setup(self, stage: Optional[str] = None): 126 | t0 = time.time() 127 | 128 | data_path = self.hparams.data_dir 129 | params = { 130 | "LIST": f"{data_path}/list.csv", 131 | "VAL": f"{data_path}/valid_clusters.txt", 132 | "TEST": f"{data_path}/test_clusters.txt", 133 | "DIR": f"{data_path}", 134 | "DATCUT": "2030-Jan-01", 135 | "RESCUT": self.hparams.rescut, # resolution cutoff for PDBs 136 | "HOMO": 0.70, # min seq.id. to detect homo chains 137 | "MAXLEN": self.hparams.max_length 138 | } 139 | 140 | self.hparams.data_cache = f"{data_path}/cache.db" 141 | self.hparams.batch_size = self.hparams.max_tokens 142 | if self.hparams.debug: 143 | self.hparams.num_examples_per_epoch = 50 144 | self.hparams.max_protein_length = 1000 145 | self.hparams.batch_size = 1000 146 | 147 | if os.path.exists(self.hparams.data_cache): 148 | train, valid, test = load_cache(self.hparams.data_cache, ['train', 'valid', 'test']) 149 | else: 150 | train, valid, test = build_training_clusters(params, debug=self.hparams.debug) 151 | 152 | _deterministic = False 153 | if stage != 'fit': 154 | _deterministic = True 155 | 156 | # if self.hparams.use_esm_alphabet: 157 | # self.alphabet = Alphabet(name='esm', dataset='multichain') 158 | # else: 159 | # self.alphabet = Alphabet(name='protein_mpnn', dataset='multichain') 160 | self.alphabet = Alphabet(**self.hparams.alphabet) 161 | self.collate_fn = partial( 162 | self.alphabet.featurize, 163 | deterministic=_deterministic) 164 | 165 | _make_dataset = partial( 166 | make_dataset, 167 | args=self.hparams, 168 | params=params, 169 | alphabet=self.alphabet.all_toks, 170 | deterministic=_deterministic 171 | ) 172 | if stage == 'fit': 173 | self.train_dataset = _make_dataset(split=train) 174 | self.valid_dataset = _make_dataset(split=valid) 175 | elif stage == 'test' or stage == 'predict': 176 | self.test_dataset = _make_dataset(split=test) 177 | else: 178 | raise ValueError(f"Invalid stage: {stage}.") 179 | 180 | log.info(f'Data loaded (elapsed {time.time() - t0}.') 181 | 182 | def _build_batch_sampler(self, dataset, max_tokens, shuffle=False, distributed=True): 183 | is_distributed = distributed and torch.distributed.is_initialized() 184 | 185 | batch_sampler = MaxTokensBatchSampler( 186 | dataset=dataset, 187 | shuffle=shuffle, 188 | distributed=is_distributed, 189 | batch_size=self.hparams.batch_size, 190 | max_tokens=max_tokens, 191 | sort=self.hparams.sort, 192 | drop_last=False, 193 | sort_key=lambda i: len(dataset[i]['seq'])) 194 | return batch_sampler 195 | 196 | def train_dataloader(self): 197 | return StructureLoader(self.train_dataset, batch_size=self.hparams.max_tokens, collate_fn=self.collate_fn) 198 | 199 | def val_dataloader(self): 200 | return StructureLoader(self.valid_dataset, batch_size=self.hparams.max_tokens, collate_fn=self.collate_fn, shuffle=False) 201 | 202 | def test_dataloader(self): 203 | return StructureLoader(self.test_dataset, batch_size=self.hparams.max_tokens, collate_fn=self.collate_fn, shuffle=False) 204 | --------------------------------------------------------------------------------
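The notebook above drives sequence design through the high-level Designer task. The lower-level pieces collected in this listing (DataProcessor, Alphabet, ProteinMPNNCMLM, IterativeRefinementGenerator) can in principle be wired together directly; the sketch below is illustrative and untested. It assumes that Alphabet.featurize returns a batch dict with the keys consumed by ProteinMPNNCMLM.forward_encoder and initialize_output_tokens (coords, coord_mask, prev_tokens, ...), that the model accepts a ProteinMPNNConfig dataclass directly, and that the PDB path is a placeholder. In a real run the model config and weights would come from a trained experiment checkpoint (e.g. loaded via Designer, as in the notebook), so the randomly initialized model here only demonstrates the plumbing.

import torch

from byprot.datamodules.datasets import DataProcessor
from byprot.datamodules.datasets.data_utils import Alphabet
from byprot.models.fixedbb.generator import IterativeRefinementGenerator
from byprot.models.fixedbb.protein_mpnn_cmlm.protein_mpnn import (
    ProteinMPNNCMLM, ProteinMPNNConfig)

# Parse a single PDB into the structure dict shown in the notebook (placeholder path).
structure = DataProcessor().parse_PDB("examples/3uat.pdb")

# ESM-style alphabet so padding/mask indices match use_esm_alphabet=True below.
alphabet = Alphabet('esm', 'cath')
batch = alphabet.featurize([structure])  # assumed: dict with coords, coord_mask, prev_tokens, ...

# Randomly initialized CMLM-style ProteinMPNN; real weights come from a trained checkpoint.
model = ProteinMPNNCMLM(ProteinMPNNConfig(use_esm_alphabet=True))
model.eval()

# Iterative refinement as implemented in generator.py above (settings mirror the notebook).
generator = IterativeRefinementGenerator(
    alphabet=alphabet, max_iter=5, strategy='denoise', temperature=0)

with torch.no_grad():
    output_tokens, output_scores = generator.generate(model, batch)
print(output_tokens.shape, output_scores.mean())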