├── .gitignore ├── LICENSE ├── README.md ├── environment.yaml ├── eval.py ├── figs └── architecture.png ├── infer.py ├── pdb2json.py ├── script ├── eval_case_study.sh ├── eval_example.pbs └── run_example.pbs ├── src ├── data │ ├── add_noise_to_backbone.py │ ├── get_esm3_structure_seq.py │ ├── get_ss_seq.py │ └── processors │ │ ├── descriptor_features.py │ │ ├── foldseek.py │ │ ├── protein_features.py │ │ └── structure_features.py ├── esm │ ├── __init__.py │ ├── layers │ │ ├── attention.py │ │ ├── blocks.py │ │ ├── codebook.py │ │ ├── ffn.py │ │ ├── geom_attention.py │ │ ├── regression_head.py │ │ ├── rotary.py │ │ ├── structure_proj.py │ │ └── transformer_stack.py │ ├── models │ │ ├── esm3.py │ │ ├── function_decoder.py │ │ └── vqvae.py │ ├── pretrained.py │ ├── sdk │ │ └── api.py │ ├── tokenization │ │ ├── __init__.py │ │ ├── function_tokenizer.py │ │ ├── residue_tokenizer.py │ │ ├── sasa_tokenizer.py │ │ ├── sequence_tokenizer.py │ │ ├── ss_tokenizer.py │ │ ├── structure_tokenizer.py │ │ └── tokenizer_base.py │ └── utils │ │ ├── constants │ │ ├── esm3.py │ │ ├── models.py │ │ └── physics.py │ │ ├── decoding.py │ │ ├── encoding.py │ │ ├── function │ │ ├── encode_decode.py │ │ ├── interpro.py │ │ ├── lsh.py │ │ └── tfidf.py │ │ ├── generation.py │ │ ├── misc.py │ │ ├── noise_schedules.py │ │ ├── residue_constants.py │ │ ├── sampling.py │ │ ├── structure │ │ ├── affine3d.py │ │ ├── aligner.py │ │ ├── lddt.py │ │ ├── normalize_coordinates.py │ │ ├── predicted_aligned_error.py │ │ ├── protein_chain.py │ │ └── protein_structure.py │ │ └── types.py ├── esmfold.py ├── models │ ├── __pycache__ │ │ ├── adapter.cpython-312.pyc │ │ └── pooling.cpython-312.pyc │ ├── adapter.py │ └── pooling.py └── utils │ ├── __pycache__ │ ├── data_utils.cpython-312.pyc │ ├── loss_fn.cpython-312.pyc │ └── metrics.cpython-312.pyc │ ├── data_utils.py │ ├── loss_fn.py │ └── metrics.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 
Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .idea/* 163 | wandb/ 164 | ckpt/ 165 | dataset/ 166 | result/ 167 | src/data/weights/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VenusVaccine 2 | 3 |
4 | 5 | [![GitHub](https://img.shields.io/badge/GitHub-VenusVaccine-blue)](https://github.com/ai4protein/VenusVaccine) 6 | 7 | [![Python](https://img.shields.io/badge/Python-3.10%2B-blue)](https://www.python.org/) 8 | [![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-red)](https://pytorch.org/) 9 | [![License: CC-BY-NC-ND-4.0](https://img.shields.io/badge/License-CC--BY--NC--ND%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc-nd/4.0/) 10 | 11 |
12 | 13 | ## 📋 Overview 14 | 15 | VenusVaccine is a deep learning-based immunogenicity prediction tool that classifies antigens as protective or non-protective. The project leverages advanced pre-trained language models and adapter architectures to interpret immunogenicity based on the multimodal encoding of antigens, including their sequences, structures, and physico-chemical properties. 16 | 17 |
18 | VenusVaccine Architecture 19 |
20 | 21 | ### 🌟 Key Features 22 | 23 | - 🔬 **Versatile Data Processing** 24 | - Support for multiple protein database formats 25 | - Efficient data preprocessing and feature extraction 26 | - Flexible data augmentation strategies 27 | 28 | - 🧬 **Protein Feature Extraction** 29 | - E-descriptor and Z-descriptor physicochemical features 30 | - Foldseek secondary structure prediction 31 | - ESM3 structure sequence encoding 32 | 33 | - 🤖 **Advanced Model Architecture** 34 | - Integration with pre-trained protein language models 35 | - Innovative adapter design 36 | - Support for multiple PLM types (ESM, BERT, Ankh, etc.) 37 | 38 | - 📊 **Comprehensive Training Framework** 39 | - Cross-validation support 40 | - Early stopping strategy 41 | - Wandb experiment tracking 42 | - Automated model evaluation 43 | 44 | - 🚀 **High-Performance Computing** 45 | - GPU acceleration support 46 | - Distributed training 47 | - Gradient accumulation optimization 48 | 49 | ## 🛠️ Installation Guide 50 | 51 | ### Requirements 52 | 53 | - Python 3.10+ 54 | - CUDA 11.0+ (for GPU training) 55 | - 8GB+ RAM 56 | 57 | ### Setup Steps 58 | 59 | 1. Clone the repository: 60 | ```bash 61 | git clone https://github.com/songleee/VenusVaccine.git 62 | cd VenusVaccine 63 | ``` 64 | 65 | 2. Create the conda environment (named `venusvaccine`): 66 | ```bash 67 | conda env create -f environment.yaml 68 | ``` 69 | 70 | 3. Download data and checkpoints: 71 | Download the pre-trained model files, training data, and model evaluation results from [Google Drive](https://drive.google.com/drive/folders/1VLEGpFv7jFyWGChzxchxv-D99QUBlqOA?usp=sharing) 72 | 73 | Pre-trained model files should be placed in the `ckpt` directory: 74 | - `ckpt/Bacteria.pt`: Model for bacterial protective antigens 75 | - `ckpt/Virus.pt`: Model for viral protective antigens 76 | - `ckpt/Tumor.pt`: Model for tumor protective antigens 77 | 78 | 4.
Download and install dependencies: 79 | - [Foldseek](https://github.com/steineggerlab/foldseek/releases/tag/10-941cd33) 80 | - [ESM3_encoder](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1/blob/main/data/weights/esm3_structure_encoder_v0.pth) 81 | ```bash 82 | wget https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1/resolve/main/data/weights/esm3_structure_encoder_v0.pth 83 | mkdir -p ./src/data/weights 84 | mv esm3_structure_encoder_v0.pth ./src/data/weights 85 | ``` 86 | 87 | ## 📊 Data Processing 88 | 89 | ### Predict protein structures with ESMFold 90 | 91 | ```bash 92 | # Predict single protein sequence 93 | python src/esmfold.py --sequence "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" --out_file output.pdb 94 | 95 | # Predict multiple proteins from FASTA file 96 | python src/esmfold.py --fasta_file proteins.fasta --out_dir pdb_structures --fold_chunk_size 128 97 | 98 | ``` 99 | 100 | ### PDB to JSON Conversion 101 | 102 | First obtain a PDB file for each protein of interest (an experimental structure, e.g. from cryo-EM, or one predicted by AF2 or ESMFold), then use `pdb2json.py` to convert a directory of PDB files into a feature-rich JSON Lines file (one record per line): 103 | 104 | ```bash 105 | python pdb2json.py <pdb_dir> <output_json_file> 106 | ``` 107 | 108 | This tool automatically extracts: 109 | - Amino acid sequence 110 | - ESM3 structure sequence 111 | - Foldseek secondary structure prediction 112 | - E-descriptor (5-dimensional) features 113 | - Z-descriptor (3-dimensional) features 114 | 115 | ## 🚀 Quick Start 116 | 117 | ### Basic Usage 118 | 119 | ```bash 120 | python infer.py -i input.json -t Bacteria 121 | ``` 122 | 123 | ### Command Line Arguments 124 | 125 | ```bash 126 | python infer.py [-h] -i INPUT -t {Bacteria,Virus,Tumor} [--structure_seqs STRUCTURE_SEQS] 127 | [--max_seq_len MAX_SEQ_LEN] [--max_batch_token MAX_BATCH_TOKEN] 128 | [--num_workers NUM_WORKERS] [-o OUTPUT] 129 | ``` 130 | 131 | Arguments: 132 | - `-i, --input`: Path to input JSON file (required) 133 | - `-t, --type`: Pathogen type, choose
from: Bacteria, Virus, Tumor (required) 134 | - `--structure_seqs`: Types of structure sequences, comma-separated (default: e_descriptor,z_descriptor,foldseek_seq,esm3_structure_seq) 135 | - `--max_seq_len`: Maximum sequence length (default: 1024) 136 | - `--max_batch_token`: Maximum tokens per batch (default: 10000) 137 | - `--num_workers`: Number of data loading workers (default: 4) 138 | - `-o, --output`: Path to output CSV file (default: results_{type}.csv) 139 | 140 | ### Input Format 141 | 142 | The input should be a JSON file with one sample per line. Fields required depend on the specified structure_seqs parameter: 143 | 144 | ```json 145 | { 146 | "name": "protein1", 147 | "aa_seq": "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", 148 | "foldseek_seq": "HHHEEELLCCHHHHHHHHHHHHSTTHHHHHHHHHHHHHHHHHHHHHHHHEETTEEHHHHHH", 149 | "esm3_structure_seq": [1, 2, 3, \...], 150 | "e_descriptor": [[0.1, 0.2, 0.3, 0.4, 0.5], \...], 151 | "z_descriptor": [[0.1, 0.2, 0.3], \...] 152 | } 153 | ``` 154 | 155 | Required fields: 156 | - `name`: Protein sequence identifier 157 | - `aa_seq`: Amino acid sequence 158 | 159 | Optional fields (depending on structure_seqs parameter): 160 | - `foldseek_seq`: Secondary structure sequence predicted by Foldseek 161 | - `esm3_structure_seq`: Structure sequence predicted by ESM3 162 | - `e_descriptor`: E-descriptor features (5-dimensional) 163 | - `z_descriptor`: Z-descriptor features (3-dimensional) 164 | 165 | ### Output Format 166 | 167 | The output is a CSV file containing: 168 | - `name`: Protein sequence identifier 169 | - `aa_seq`: Amino acid sequence 170 | - `pred_label`: Prediction label (0: non-protective antigen, 1: protective antigen) 171 | - `pred_proba`: Prediction probability of being a protective antigen 172 | 173 | ### Examples 174 | 175 | 1. Predict using all structural features: 176 | ```bash 177 | python infer.py -i proteins.json -t Bacteria 178 | ``` 179 | 180 | 2. 
Use only specific structural features: 181 | ```bash 182 | python infer.py -i proteins.json -t Virus --structure_seqs "e_descriptor,z_descriptor" 183 | ``` 184 | 185 | 3. Specify output file: 186 | ```bash 187 | python infer.py -i proteins.json -t Tumor -o predictions.csv 188 | ``` 189 | 190 | 4. Adjust sequence length and batch size: 191 | ```bash 192 | python infer.py -i proteins.json -t Bacteria --max_seq_len 512 --max_batch_token 5000 193 | ``` 194 | 195 | ## ⚠️ Important Notes 196 | 197 | 1. Ensure all required dependencies are installed 198 | 2. Make sure corresponding model files exist in the `ckpt` directory (`Bacteria.pt`, `Virus.pt`, or `Tumor.pt`) 199 | 3. Make sure the PLM checkpoints downloaded from huggingface are set up correctly if the network failed 200 | 4. GPU is recommended for better inference performance 201 | 202 | ## 📝 Citation 203 | 204 | If you find this tool helpful, please cite our work: 205 | ``` 206 | @inproceedings{ 207 | li2025immunogenicity, 208 | title={Immunogenicity Prediction with Dual Attention Enables Vaccine Target Selection}, 209 | author={Song Li and Yang Tan and Song Ke and Liang Hong and Bingxin Zhou}, 210 | booktitle={The Thirteenth International Conference on Learning Representations}, 211 | year={2025}, 212 | url={https://openreview.net/forum?id=hWmwL9gizZ} 213 | } 214 | ``` 215 | 216 | ## 📝 License 217 | 218 | This project is licensed under the terms of the [CC-BY-NC-ND-4.0](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) license. 219 | 220 | ## 📮 Contact 221 | 222 | - Project Maintainer: Song Li, Yang Tan 223 | - Email: songlee@sjtu.edu.cn 224 | - Issue Tracking: [Issue Page](https://github.com/songleee/VenusVaccine/issues) 225 | 226 | --- 227 | 228 |
229 | ⭐️ If you find this project helpful, please give it a star! 230 |
-------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: venusvaccine 2 | channels: 3 | - conda-forge 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - bzip2=1.0.8=h4bc722e_7 11 | - ca-certificates=2024.7.4=hbcca054_0 12 | - ld_impl_linux-64=2.40=hf3520f5_7 13 | - libffi=3.4.2=h7f98852_5 14 | - libgcc-ng=14.1.0=h77fa898_0 15 | - libgomp=14.1.0=h77fa898_0 16 | - libnsl=2.0.1=hd590300_0 17 | - libsqlite=3.46.0=hde9e2c9_0 18 | - libuuid=2.38.1=h0b41bf4_0 19 | - libxcrypt=4.4.36=hd590300_1 20 | - libzlib=1.3.1=h4ab18f5_1 21 | - ncurses=6.5=h59595ed_0 22 | - openssl=3.3.1=h4bc722e_2 23 | - pip=24.2=pyhd8ed1ab_0 24 | - python=3.10.14=hd12c33a_0_cpython 25 | - readline=8.2=h8228510_1 26 | - setuptools=72.1.0=pyhd8ed1ab_0 27 | - tk=8.6.13=noxft_h4845f30_101 28 | - wheel=0.44.0=pyhd8ed1ab_0 29 | - xz=5.2.6=h166bdaf_0 30 | - pip: 31 | - accelerate==0.33.0 32 | - aiohappyeyeballs==2.3.5 33 | - aiohttp==3.10.3 34 | - aiosignal==1.3.1 35 | - async-timeout==4.0.3 36 | - attrs==24.2.0 37 | - biotite==0.41.2 38 | - certifi==2024.7.4 39 | - charset-normalizer==3.3.2 40 | - click==8.1.7 41 | - datasets==2.20.0 42 | - dill==0.3.8 43 | - docker-pycreds==0.4.0 44 | - filelock==3.15.4 45 | - frozenlist==1.4.1 46 | - fsspec==2024.5.0 47 | - gitdb==4.0.11 48 | - gitpython==3.1.43 49 | - huggingface-hub==0.24.5 50 | - idna==3.7 51 | - jinja2==3.1.4 52 | - lightning-utilities==0.11.6 53 | - markupsafe==2.1.5 54 | - mpmath==1.3.0 55 | - msgpack==1.0.8 56 | - multidict==6.0.5 57 | - multiprocess==0.70.16 58 | - networkx==3.3 59 | - numpy==1.26.4 60 | - nvidia-cublas-cu12==12.1.3.1 61 | - nvidia-cuda-cupti-cu12==12.1.105 62 | - nvidia-cuda-nvrtc-cu12==12.1.105 63 | - 
nvidia-cuda-runtime-cu12==12.1.105 64 | - nvidia-cudnn-cu12==9.1.0.70 65 | - nvidia-cufft-cu12==11.0.2.54 66 | - nvidia-curand-cu12==10.3.2.106 67 | - nvidia-cusolver-cu12==11.4.5.107 68 | - nvidia-cusparse-cu12==12.1.0.106 69 | - nvidia-nccl-cu12==2.20.5 70 | - nvidia-nvjitlink-cu12==12.6.20 71 | - nvidia-nvtx-cu12==12.1.105 72 | - packaging==24.1 73 | - pandas==2.2.2 74 | - platformdirs==4.2.2 75 | - protobuf==5.27.3 76 | - psutil==6.0.0 77 | - pyarrow==17.0.0 78 | - pyarrow-hotfix==0.6 79 | - python-dateutil==2.9.0.post0 80 | - pytz==2024.1 81 | - pyyaml==6.0.2 82 | - regex==2024.7.24 83 | - requests==2.32.3 84 | - safetensors==0.4.4 85 | - sentry-sdk==2.12.0 86 | - setproctitle==1.3.3 87 | - six==1.16.0 88 | - smmap==5.0.1 89 | - sympy==1.13.2 90 | - tokenizers==0.19.1 91 | - torch==2.4.0 92 | - torchmetrics==1.4.1 93 | - tqdm==4.66.5 94 | - transformers==4.44.0 95 | - triton==3.0.0 96 | - typing-extensions==4.12.2 97 | - tzdata==2024.1 98 | - urllib3==2.2.2 99 | - wandb==0.17.6 100 | - xxhash==3.4.1 101 | - yarl==1.9.4 102 | prefix: /home/lisong/software/anaconda3/envs/venusvaccine 103 | -------------------------------------------------------------------------------- /figs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/figs/architecture.png -------------------------------------------------------------------------------- /pdb2json.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pandas as pd 4 | from tqdm import tqdm 5 | import warnings 6 | from src.data.processors.descriptor_features import DescriptorFeatureProcessor 7 | from src.data.processors.structure_features import StructureFeatureProcessor 8 | from src.data.processors.foldseek import FoldseekProcessor 9 | 10 | warnings.filterwarnings("ignore") 11 | 12 | 13 | def 
process_pdb_folder(pdb_dir, output_json_file): 14 | # Initialize processors 15 | descriptor_processor = DescriptorFeatureProcessor() 16 | structure_processor = StructureFeatureProcessor() 17 | 18 | # Get Foldseek features 19 | foldseek_dict = FoldseekProcessor.run_foldseek_commands(pdb_dir) 20 | 21 | results = [] 22 | for pdb_file in tqdm(os.listdir(pdb_dir)): 23 | if not pdb_file.endswith(".pdb"): 24 | continue 25 | 26 | pdb_path = os.path.join(pdb_dir, pdb_file) 27 | name = pdb_file[:-4] 28 | 29 | # Get structure features 30 | esm3_structure_seq, sequence = structure_processor.get_esm3_structure_seq(pdb_path) 31 | 32 | # Get other features 33 | foldseek_seq = foldseek_dict.get(name) 34 | e_descriptor = descriptor_processor.e_descriptor_embedding(sequence) 35 | z_descriptor = descriptor_processor.z_descriptor_embedding(sequence) 36 | 37 | result = { 38 | "name": name, 39 | "aa_seq": sequence, 40 | "esm3_structure_seq": esm3_structure_seq, 41 | "foldseek_seq": foldseek_seq, 42 | "e_descriptor": e_descriptor, 43 | "z_descriptor": z_descriptor 44 | } 45 | results.append(result) 46 | 47 | # Save results 48 | pd.DataFrame(results).to_json(output_json_file, orient="records", lines=True) 49 | print("JSON file created successfully!") 50 | 51 | if __name__ == "__main__": 52 | if len(sys.argv) != 3: 53 | print("Usage: python pdb2json.py ") 54 | sys.exit(1) 55 | 56 | pdb_dir = sys.argv[1] 57 | output_json_file = sys.argv[2] 58 | process_pdb_folder(pdb_dir, output_json_file) 59 | -------------------------------------------------------------------------------- /script/eval_case_study.sh: -------------------------------------------------------------------------------- 1 | # --------------------case study-------------------- 2 | # ElnaggarLab/ankh-large 3 | # facebook/esm2_t33_650M_UR50D 4 | # Rostlab/prot_bert 5 | dataset=BacteriaBinary 6 | pdb_type=ESMFold 7 | seqs=ez_descriptor,foldseek_seq,esm3_structure_seq 8 | seqs_type=full 9 | plm_group=facebook 10 | 
plm_model=esm2_t33_650M_UR50D 11 | pooling_head=attention1d 12 | lr=5e-4 13 | num_labels=2 14 | 15 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 16 | --plm_model ${plm_group}/${plm_model} \ 17 | --dataset $dataset \ 18 | --problem_type single_label_classification \ 19 | --num_labels $num_labels \ 20 | --pooling_method $pooling_head \ 21 | --return_attentions \ 22 | --test_file dataset/Case_1_Helicobacter_pylori/case.json \ 23 | --test_result_dir result_random/$plm_model/case1 \ 24 | --metrics auc,accuracy,precision,recall,f1,mcc \ 25 | --structure_seqs $seqs \ 26 | --max_batch_token 10000 \ 27 | --ckpt_root result_random \ 28 | --ckpt_dir $plm_model/$dataset \ 29 | --model_name "$pdb_type"_"$plm_model"_"$pooling_head"_"$lr"_"$seqs_type".pt 30 | -------------------------------------------------------------------------------- /script/eval_example.pbs: -------------------------------------------------------------------------------- 1 | #PBS -q ai 2 | #PBS -l walltime=72:00:00 3 | #PBS -l ncpus=6 4 | #PBS -l ngpus=1 5 | #PBS -l host=ai1 6 | #PBS -l mem=100gb 7 | #PBS -N VenusVaccine 8 | #PBS -o out.log 9 | #PBS -e out.log 10 | 11 | cd "$PBS_O_WORKDIR" || exit 1 12 | #module purge 13 | #module load Anaconda3 14 | export PATH=/home/lisong/software/anaconda3/bin:$PATH 15 | export PATH=/home/lisong/local/bin:$PATH 16 | export HF_ENDPOINT=https://hf-mirror.com 17 | source activate venusvaccine 18 | 19 | dataset=BacteriaBinary 20 | pdb_type=ESMFold 21 | seqs=ez_descriptor,foldseek_seq,esm3_structure_seq 22 | seqs_type=full 23 | plm_group=Rostlab 24 | plm_model=prot_bert 25 | checkpoint=/home/lisong/huggingface/checkpoints/Rostlab/prot_bert 26 | 27 | pooling_head=attention1d 28 | lr=5e-4 29 | num_labels=2 30 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 31 | --plm_model $checkpoint \ 32 | --dataset $dataset \ 33 | --problem_type single_label_classification \ 34 | --num_labels $num_labels \ 35 | --pooling_method $pooling_head \ 36 | --test_file dataset/$dataset/$pdb_type/test.json \ 37 |
--test_result_dir result/$plm_model/$dataset/${seqs_type} \ 38 | --metrics auc,accuracy,precision,recall,f1,mcc \ 39 | --structure_seqs $seqs \ 40 | --max_batch_token 10000 \ 41 | --ckpt_root result \ 42 | --ckpt_dir $plm_model/$dataset \ 43 | --model_name "$pdb_type"_"$plm_model"_"$pooling_head"_"$lr"_"$seqs_type".pt -------------------------------------------------------------------------------- /script/run_example.pbs: -------------------------------------------------------------------------------- 1 | #PBS -q ai 2 | #PBS -l walltime=72:00:00 3 | #PBS -l ncpus=6 4 | #PBS -l ngpus=1 5 | #PBS -l host=ai1 6 | #PBS -l mem=100gb 7 | #PBS -N VenusVaccine 8 | #PBS -o out.log 9 | #PBS -e out.log 10 | 11 | cd "$PBS_O_WORKDIR" || exit 1 12 | #module purge 13 | #module load Anaconda3 14 | export PATH=/home/lisong/software/anaconda3/bin:$PATH 15 | export PATH=/home/lisong/local/bin:$PATH 16 | export HF_ENDPOINT=https://hf-mirror.com 17 | source activate venusvaccine 18 | 19 | # ElnaggarLab/ankh-large 20 | # facebook/esm2_t33_650M_UR50D 21 | # Rostlab/prot_bert 22 | dataset=BacteriaBinary 23 | pdb_type=ESMFold 24 | seqs=ez_descriptor,foldseek_seq,esm3_structure_seq 25 | seqs_type=full 26 | plm_group=Rostlab 27 | plm_model=prot_bert 28 | checkpoint=/home/lisong/huggingface/checkpoints/Rostlab/prot_bert 29 | 30 | pooling_head=attention1d 31 | lr=5e-4 32 | 33 | CUDA_VISIBLE_DEVICES=0 python train.py \ 34 | --plm_model $checkpoint \ 35 | --num_attention_heads 8 \ 36 | --pooling_method $pooling_head \ 37 | --pooling_dropout 0.1 \ 38 | --dataset_config dataset/$dataset/"$dataset"_"$pdb_type".json \ 39 | --lr $lr \ 40 | --num_workers 4 \ 41 | --gradient_accumulation_steps 1 \ 42 | --max_train_epochs 50 \ 43 | --max_batch_token 40000 \ 44 | --patience 5 \ 45 | --structure_seqs $seqs \ 46 | --ckpt_root result \ 47 | --ckpt_dir $plm_model/$dataset \ 48 | --model_name "$pdb_type"_"$plm_model"_"$pooling_head"_"$lr"_"$seqs_type".pt \ 49 | # --wandb \ 50 | # --wandb_entity your/wandb/name \ 51 | #
--wandb_project VenusVaccine \ 52 | # --wandb_run_name "$dataset"_"$pdb_type"_"$plm_model"_"$pooling_head"_"$lr"_"$seqs_type" 53 | -------------------------------------------------------------------------------- /src/data/add_noise_to_backbone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | 6 | pdbs = os.listdir('alphafold_pdb') 7 | for pdb in tqdm(pdbs): 8 | pdb_lines = open(f"alphafold_pdb/{pdb}").read().splitlines() 9 | 10 | def add_noise_and_save(variance, file_name): 11 | with open(file_name, "w") as file: 12 | for line in pdb_lines: 13 | if line.startswith("ATOM"): 14 | parts = line.split() 15 | try: 16 | coords = np.array([float(parts[6]), float(parts[7]), float(parts[8])]) 17 | noise = np.random.normal(0, variance, coords.shape) 18 | new_coords = coords + noise 19 | new_line = f"{line[:30]}{new_coords[0]:8.3f}{new_coords[1]:8.3f}{new_coords[2]:8.3f}{line[54:]}" 20 | file.write(new_line + "\n") 21 | except: 22 | file.write(line + "\n") 23 | else: 24 | file.write(line + "\n") 25 | 26 | variances = [0.5] 27 | 28 | for variance in variances: 29 | file_name = f"alphafold_pdb_noise_{variance}/{pdb}" 30 | try: 31 | add_noise_and_save(variance, file_name) 32 | except Exception as e: 33 | print(e) 34 | print(pdb) 35 | 36 | -------------------------------------------------------------------------------- /src/data/get_esm3_structure_seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | import json 6 | import argparse 7 | import pandas as pd 8 | import numpy as np 9 | from tqdm import tqdm 10 | from biotite.structure.io.pdb import PDBFile 11 | from torch.nn import functional as F 12 | from src.esm.utils.structure.protein_chain import ProteinChain 13 | from src.esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS 14 | from 
src.esm.tokenization.structure_tokenizer import StructureTokenizer 15 | from src.esm.models.vqvae import ( 16 | StructureTokenDecoder, 17 | StructureTokenEncoder, 18 | ) 19 | import torch._dynamo 20 | torch._dynamo.config.suppress_errors = True 21 | 22 | VQVAE_CODEBOOK_SIZE = 4096 23 | VQVAE_SPECIAL_TOKENS = { 24 | "MASK": VQVAE_CODEBOOK_SIZE, 25 | "EOS": VQVAE_CODEBOOK_SIZE + 1, 26 | "BOS": VQVAE_CODEBOOK_SIZE + 2, 27 | "PAD": VQVAE_CODEBOOK_SIZE + 3, 28 | "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4, 29 | } 30 | 31 | def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"): 32 | model = ( 33 | StructureTokenEncoder( 34 | d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096 35 | ) 36 | .to(device) 37 | .eval() 38 | ) 39 | state_dict = torch.load( 40 | "data/weights/esm3_structure_encoder_v0.pth", map_location=device 41 | ) 42 | model.load_state_dict(state_dict) 43 | return model 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--pdb_file", type=str, default=None) 48 | parser.add_argument("--pdb_dir", type=str, default=None) 49 | parser.add_argument("--out_file", type=str, default='structure_tokens.json') 50 | args = parser.parse_args() 51 | 52 | device="cuda:0" 53 | results = [] 54 | # result_dict = {'name':[], 'aa_seq':[], 'esm3_structure_tokens':[], 'plddt':[], 'residue_index':[]} 55 | 56 | encoder = ESM3_structure_encoder_v0(device) 57 | 58 | if args.pdb_file is not None: 59 | # Extract Unique Chain IDs 60 | chain_ids = np.unique(PDBFile.read(args.pdb_file).get_structure().chain_id) 61 | # print(chain_ids) 62 | # ['L', 'H'] 63 | 64 | # By Default, ProteinChain takes first one 65 | chain = ProteinChain.from_pdb(args.pdb_file, chain_id=chain_ids[0]) 66 | sequence = chain.sequence 67 | 68 | # Encoder 69 | coords, plddt, residue_index = chain.to_structure_encoder_inputs() 70 | coords = coords.to(device) 71 | #plddt = plddt.cuda() 72 | residue_index = residue_index.to(device) 73 | _, 
structure_tokens = encoder.encode(coords, residue_index=residue_index) 74 | 75 | result = {'name':args.pdb_file, 'aa_seq':sequence, 'esm3_structure_seq':structure_tokens.cpu().numpy().tolist()[0]} 76 | results.append(result) 77 | 78 | with open(args.out_file, "w") as f: 79 | f.write("\n".join([json.dumps(r) for r in results])) 80 | 81 | elif args.pdb_dir is not None: 82 | pdb_files = os.listdir(args.pdb_dir) 83 | for pdb_file in tqdm(pdb_files): 84 | # Extract Unique Chain IDs 85 | chain_ids = np.unique(PDBFile.read(os.path.join(args.pdb_dir, pdb_file)).get_structure().chain_id) 86 | # print(chain_ids) 87 | # ['L', 'H'] 88 | 89 | # By Default, ProteinChain takes first one 90 | chain = ProteinChain.from_pdb(os.path.join(args.pdb_dir, pdb_file), chain_id=chain_ids[0]) 91 | sequence = chain.sequence 92 | 93 | # Encoder 94 | coords, plddt, residue_index = chain.to_structure_encoder_inputs() 95 | coords = coords.to(device) 96 | #plddt = pldt.cuda() 97 | residue_index = residue_index.to(device) 98 | _, structure_tokens = encoder.encode(coords, residue_index=residue_index) 99 | 100 | result = {'name':pdb_file, 'aa_seq':sequence, 'esm3_structure_seq':structure_tokens.cpu().numpy().tolist()[0]} 101 | results.append(result) 102 | 103 | with open(args.out_file, "w") as f: 104 | f.write("\n".join([json.dumps(r) for r in results])) 105 | -------------------------------------------------------------------------------- /src/data/get_ss_seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import argparse 5 | import json 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from Bio import PDB 9 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed 10 | from src.utils.data_utils import extract_seq_from_pdb 11 | 12 | 13 | ss_alphabet = ['H', 'E', 'C'] 14 | ss_alphabet_dic = { 15 | "H": "H", "G": "H", "E": "E", 16 | "B": "E", "I": "C", "T": "C", 17 | 
"S": "C", "L": "C", "-": "C", 18 | "P": "C" 19 | } 20 | 21 | def generate_feature(pdb_file): 22 | try: 23 | # extract amino acid sequence 24 | aa_seq = extract_seq_from_pdb(pdb_file) 25 | pdb_parser = PDB.PDBParser(QUIET=True) 26 | structure = pdb_parser.get_structure("protein", pdb_file) 27 | model = structure[0] 28 | dssp = PDB.DSSP(model, pdb_file) 29 | # extract secondary structure sequence 30 | sec_structures = [] 31 | for i, dssp_res in enumerate(dssp): 32 | sec_structures.append(dssp_res[2]) 33 | 34 | except Exception as e: 35 | return pdb_file, e 36 | 37 | sec_structure_str_8 = ''.join(sec_structures) 38 | sec_structure_str_8 = sec_structure_str_8.replace('-', 'L') 39 | if len(aa_seq) != len(sec_structure_str_8): 40 | return pdb_file, f"aa_seq {len(aa_seq)} and sec_structure_str_8 {len(sec_structure_str_8)} length mismatch" 41 | 42 | sec_structure_str_3 = ''.join([ss_alphabet_dic[ss] for ss in sec_structures]) 43 | 44 | final_feature = {} 45 | final_feature["name"] = pdb_file.split('/')[-1] 46 | final_feature["aa_seq"] = aa_seq 47 | final_feature["ss8_seq"] = sec_structure_str_8 48 | final_feature["ss3_seq"] = sec_structure_str_3 49 | 50 | return final_feature, None 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--pdb_dir', type=str, help='pdb dir') 55 | parser.add_argument('--pdb_file', type=str, help='pdb file') 56 | 57 | # multi processing 58 | parser.add_argument('--num_workers', type=int, default=4, help='number of workers') 59 | 60 | # index pdb for large scale inference 61 | parser.add_argument("--pdb_index_file", default=None, type=str, help="pdb index file") 62 | parser.add_argument("--pdb_index_level", default=1, type=int, help="pdb index level") 63 | 64 | # save file 65 | parser.add_argument('--error_file', type=str, help='save error file') 66 | parser.add_argument('--out_file', type=str, help='save file') 67 | args = parser.parse_args() 68 | 69 | out_dir = os.path.dirname(args.out_file) 70 
| os.makedirs(out_dir, exist_ok=True) 71 | 72 | if args.pdb_dir is not None: 73 | # load pdb index file 74 | if args.pdb_index_file: 75 | pdbs = open(args.pdb_index_file).read().splitlines() 76 | pdb_files = [] 77 | for pdb in pdbs: 78 | pdb_relative_dir = args.pdb_dir 79 | for i in range(1, args.pdb_index_level+1): 80 | pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i]) 81 | pdb_files.append(os.path.join(pdb_relative_dir, pdb+".pdb")) 82 | 83 | # regular pdb dir 84 | else: 85 | pdb_files = sorted([os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir)]) 86 | 87 | results, error_pdbs, error_messages = [], [], [] 88 | with ThreadPoolExecutor(max_workers=args.num_workers) as executor: 89 | futures = [executor.submit(generate_feature, pdb_file) for pdb_file in pdb_files] 90 | 91 | with tqdm(total=len(pdb_files), desc="Processing pdb") as progress: 92 | for future in as_completed(futures): 93 | result, message = future.result() 94 | if message is None: 95 | results.append(result) 96 | else: 97 | error_pdbs.append(result) 98 | error_messages.append(message) 99 | progress.update(1) 100 | progress.close() 101 | 102 | if error_pdbs: 103 | if args.error_file is None: 104 | args.error_file = args.out_file.split(".")[0]+"_error.csv" 105 | error_dir = os.path.dirname(args.error_file) 106 | os.makedirs(error_dir, exist_ok=True) 107 | error_info = {"error_pdbs": error_pdbs, "error_messages": error_messages} 108 | pd.DataFrame(error_info).to_csv(args.error_file, index=False) 109 | 110 | with open(args.out_file, "w") as f: 111 | f.write("\n".join([json.dumps(r) for r in results])) 112 | 113 | elif args.pdb_file is not None: 114 | result, message = generate_feature(args.pdb_file) 115 | with open(args.out_file, "w") as f: 116 | json.dump(result, f) 117 | -------------------------------------------------------------------------------- /src/data/processors/descriptor_features.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | class DescriptorFeatureProcessor: 4 | def __init__(self): 5 | self.e_descriptors = self._init_e_descriptors() 6 | self.z_descriptors = self._init_z_descriptors() 7 | 8 | def _init_e_descriptors(self): 9 | e1 = {'A': 0.008, 'R': 0.171, 'N': 0.255, 'D': 0.303, 'C': -0.132, 'Q': 0.149, 'E': 0.221, 'G': 0.218, 10 | 'H': 0.023, 'I': -0.353, 'L': -0.267, 'K': 0.243, 'M': -0.239, 'F': -0.329, 'P': 0.173, 'S': 0.199, 11 | 'T': 0.068, 'W': -0.296, 'Y': -0.141, 'V': -0.274} 12 | e2 = {'A': 0.134, 'R': -0.361, 'N': 0.038, 'D': -0.057, 'C': 0.174, 'Q': -0.184, 'E': -0.28, 'G': 0.562, 13 | 'H': -0.177, 'I': 0.071, 'L': 0.018, 'K': -0.339, 'M': -0.141, 'F': -0.023, 'P': 0.286, 'S': 0.238, 14 | 'T': 0.147, 'W': -0.186, 'Y': -0.057, 'V': 0.136} 15 | e3 = {'A': -0.475, 'R': 0.107, 'N': 0.117, 'D': -0.014, 'C': 0.07, 'Q': -0.03, 'E': -0.315, 'G': -0.024, 16 | 'H': 0.041, 'I': -0.088, 'L': -0.265, 'K': -0.044, 'M': -0.155, 'F': 0.072, 'P': 0.407, 'S': -0.015, 17 | 'T': -0.015, 'W': 0.389, 'Y': 0.425, 'V': -0.187} 18 | e4 = {'A': -0.039, 'R': -0.258, 'N': 0.118, 'D': 0.225, 'C': 0.565, 'Q': 0.035, 'E': 0.157, 'G': 0.018, 19 | 'H': 0.28, 'I': -0.195, 'L': -0.274, 'K': -0.325, 'M': 0.321, 'F': -0.002, 'P': -0.215, 'S': -0.068, 20 | 'T': -0.132, 'W': 0.083, 'Y': -0.096, 'V': -0.196} 21 | e5 = {'A': 0.181, 'R': -0.364, 'N': -0.055, 'D': 0.156, 'C': -0.374, 'Q': -0.112, 'E': 0.303, 'G': 0.106, 22 | 'H': -0.021, 'I': -0.107, 'L': 0.206, 'K': -0.027, 'M': 0.077, 'F': 0.208, 'P': 0.384, 'S': -0.196, 23 | 'T': -0.274, 'W': 0.297, 'Y': -0.091, 'V': -0.299} 24 | return [e1, e2, e3, e4, e5] 25 | 26 | def _init_z_descriptors(self): 27 | z1 = {'A': 0.07, 'R': 2.88, 'N': 3.22, 'D': 3.64, 'C': 0.71, 'Q': 2.18, 'E': 3.08, 'G': 2.23, 'H': 2.41, 28 | 'I': -4.44, 'L': -4.19, 'K': 2.84, 'M': -2.49, 'F': -4.92, 'P': -1.22, 'S': 1.96, 'T': 0.92, 'W': -4.75, 29 | 'Y': -1.39, 'V': -2.69} 30 | z2 = {'A': -1.73, 'R': 2.52, 'N': 1.45, 'D': 1.13, 'C': -0.97, 'Q': 0.53, 'E': 0.39, 'G': -5.36, 
'H': 1.74, 31 | 'I': -1.68, 'L': -1.03, 'K': 1.41, 'M': -0.27, 'F': 1.30, 'P': 0.88, 'S': -1.63, 'T': -2.09, 'W': 3.65, 32 | 'Y': 2.32, 'V': -2.53} 33 | z3 = {'A': 0.09, 'R': -3.44, 'N': 0.84, 'D': 2.36, 'C': 4.13, 'Q': -1.14, 'E': -0.07, 'G': 0.30, 'H': 1.11, 34 | 'I': -1.03, 'L': -0.98, 'K': -3.14, 'M': -0.41, 'F': 0.45, 'P': 2.23, 'S': 0.57, 'T': -1.40, 'W': 0.85, 35 | 'Y': 0.01, 'V': -1.29} 36 | return [z1, z2, z3] 37 | 38 | def e_descriptor_embedding(self, seq): 39 | descriptors = {aa: [d[aa] for d in self.e_descriptors] for aa in self.e_descriptors[0].keys()} 40 | return [descriptors.get(aa, [0.0]*5) for aa in seq] 41 | 42 | def z_descriptor_embedding(self, seq): 43 | descriptors = {aa: [d[aa] for d in self.z_descriptors] for aa in self.z_descriptors[0].keys()} 44 | return [descriptors.get(aa, [0.0]*3) for aa in seq] -------------------------------------------------------------------------------- /src/data/processors/foldseek.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from Bio import SeqIO 4 | 5 | class FoldseekProcessor: 6 | @staticmethod 7 | def run_foldseek_commands(pdb_dir): 8 | temp_dir = "temp" 9 | fasta_file = "foldseek_seq.fasta" 10 | 11 | os.makedirs(temp_dir, exist_ok=True) 12 | 13 | try: 14 | subprocess.run(["foldseek", "createdb", pdb_dir, f"{temp_dir}/db"], check=True) 15 | subprocess.run(["foldseek", "lndb", f"{temp_dir}/db_h", f"{temp_dir}/db_ss_h"], check=True) 16 | subprocess.run(["foldseek", "convert2fasta", f"{temp_dir}/db_ss", fasta_file], check=True) 17 | 18 | foldseek_dict = {record.id: str(record.seq) for record in SeqIO.parse(fasta_file, "fasta")} 19 | 20 | return foldseek_dict 21 | finally: 22 | if os.path.exists(temp_dir): 23 | subprocess.run(["rm", "-rf", temp_dir], check=True) 24 | if os.path.exists(fasta_file): 25 | os.remove(fasta_file) -------------------------------------------------------------------------------- 
/src/data/processors/protein_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DescriptorFeatureProcessor: 4 | def __init__(self): 5 | self.e_descriptors = self._init_e_descriptors() 6 | self.z_descriptors = self._init_z_descriptors() 7 | 8 | def _init_e_descriptors(self): 9 | e1 = {'A': 0.008, 'R': 0.171, 'N': 0.255, 'D': 0.303, 'C': -0.132, 'Q': 0.149, 'E': 0.221, 'G': 0.218, 10 | 'H': 0.023, 'I': -0.353, 'L': -0.267, 'K': 0.243, 'M': -0.239, 'F': -0.329, 'P': 0.173, 'S': 0.199, 11 | 'T': 0.068, 'W': -0.296, 'Y': -0.141, 'V': -0.274} 12 | e2 = {'A': 0.134, 'R': -0.361, 'N': 0.038, 'D': -0.057, 'C': 0.174, 'Q': -0.184, 'E': -0.28, 'G': 0.562, 13 | 'H': -0.177, 'I': 0.071, 'L': 0.018, 'K': -0.339, 'M': -0.141, 'F': -0.023, 'P': 0.286, 'S': 0.238, 14 | 'T': 0.147, 'W': -0.186, 'Y': -0.057, 'V': 0.136} 15 | e3 = {'A': -0.475, 'R': 0.107, 'N': 0.117, 'D': -0.014, 'C': 0.07, 'Q': -0.03, 'E': -0.315, 'G': -0.024, 16 | 'H': 0.041, 'I': -0.088, 'L': -0.265, 'K': -0.044, 'M': -0.155, 'F': 0.072, 'P': 0.407, 'S': -0.015, 17 | 'T': -0.015, 'W': 0.389, 'Y': 0.425, 'V': -0.187} 18 | e4 = {'A': -0.039, 'R': -0.258, 'N': 0.118, 'D': 0.225, 'C': 0.565, 'Q': 0.035, 'E': 0.157, 'G': 0.018, 19 | 'H': 0.28, 'I': -0.195, 'L': -0.274, 'K': -0.325, 'M': 0.321, 'F': -0.002, 'P': -0.215, 'S': -0.068, 20 | 'T': -0.132, 'W': 0.083, 'Y': -0.096, 'V': -0.196} 21 | e5 = {'A': 0.181, 'R': -0.364, 'N': -0.055, 'D': 0.156, 'C': -0.374, 'Q': -0.112, 'E': 0.303, 'G': 0.106, 22 | 'H': -0.021, 'I': -0.107, 'L': 0.206, 'K': -0.027, 'M': 0.077, 'F': 0.208, 'P': 0.384, 'S': -0.196, 23 | 'T': -0.274, 'W': 0.297, 'Y': -0.091, 'V': -0.299} 24 | return [e1, e2, e3, e4, e5] 25 | 26 | def _init_z_descriptors(self): 27 | z1 = {'A': 0.07, 'R': 2.88, 'N': 3.22, 'D': 3.64, 'C': 0.71, 'Q': 2.18, 'E': 3.08, 'G': 2.23, 'H': 2.41, 28 | 'I': -4.44, 'L': -4.19, 'K': 2.84, 'M': -2.49, 'F': -4.92, 'P': -1.22, 'S': 1.96, 'T': 0.92, 'W': -4.75, 29 | 
'Y': -1.39, 'V': -2.69} 30 | z2 = {'A': -1.73, 'R': 2.52, 'N': 1.45, 'D': 1.13, 'C': -0.97, 'Q': 0.53, 'E': 0.39, 'G': -5.36, 'H': 1.74, 31 | 'I': -1.68, 'L': -1.03, 'K': 1.41, 'M': -0.27, 'F': 1.30, 'P': 0.88, 'S': -1.63, 'T': -2.09, 'W': 3.65, 32 | 'Y': 2.32, 'V': -2.53} 33 | z3 = {'A': 0.09, 'R': -3.44, 'N': 0.84, 'D': 2.36, 'C': 4.13, 'Q': -1.14, 'E': -0.07, 'G': 0.30, 'H': 1.11, 34 | 'I': -1.03, 'L': -0.98, 'K': -3.14, 'M': -0.41, 'F': 0.45, 'P': 2.23, 'S': 0.57, 'T': -1.40, 'W': 0.85, 35 | 'Y': 0.01, 'V': -1.29} 36 | return [z1, z2, z3] 37 | 38 | def e_descriptor_embedding(self, seq): 39 | descriptors = {aa: [d[aa] for d in self.e_descriptors] for aa in self.e_descriptors[0].keys()} 40 | return [descriptors.get(aa, [0.0]*5) for aa in seq] 41 | 42 | def z_descriptor_embedding(self, seq): 43 | descriptors = {aa: [d[aa] for d in self.z_descriptors] for aa in self.z_descriptors[0].keys()} 44 | return [descriptors.get(aa, [0.0]*3) for aa in seq] -------------------------------------------------------------------------------- /src/data/processors/structure_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from biotite.structure.io.pdb import PDBFile 4 | from src.esm.utils.structure.protein_chain import ProteinChain 5 | from src.esm.models.vqvae import StructureTokenEncoder 6 | 7 | class StructureFeatureProcessor: 8 | def __init__(self, device="cpu"): 9 | self.device = device 10 | self.encoder = self._load_esm3_encoder() 11 | 12 | def _load_esm3_encoder(self): 13 | model = ( 14 | StructureTokenEncoder( 15 | d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096 16 | ) 17 | .to(self.device) 18 | .eval() 19 | ) 20 | state_dict = torch.load( 21 | "src/data/weights/esm3_structure_encoder_v0.pth", map_location=self.device 22 | ) 23 | model.load_state_dict(state_dict) 24 | return model 25 | 26 | def get_esm3_structure_seq(self, pdb_file): 27 | chain_ids = 
self._get_chain_ids(pdb_file) 28 | chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_ids[0]) 29 | 30 | coords, plddt, residue_index = chain.to_structure_encoder_inputs() 31 | coords = coords.to(self.device) 32 | residue_index = residue_index.to(self.device) 33 | 34 | _, structure_tokens = self.encoder.encode(coords, residue_index=residue_index) 35 | return structure_tokens.cpu().numpy().tolist()[0], chain.sequence 36 | 37 | @staticmethod 38 | def _get_chain_ids(pdb_file): 39 | return np.unique(PDBFile.read(pdb_file).get_structure().chain_id) -------------------------------------------------------------------------------- /src/esm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/esm/__init__.py -------------------------------------------------------------------------------- /src/esm/layers/attention.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import einops 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from src.esm.layers.rotary import RotaryEmbedding 9 | 10 | 11 | class MultiHeadAttention(nn.Module): 12 | def __init__( 13 | self, 14 | d_model: int, 15 | n_heads: int, 16 | bias: bool = False, 17 | qk_layernorm: bool = True, 18 | ): 19 | super().__init__() 20 | 21 | self.d_model = d_model 22 | self.n_heads = n_heads 23 | 24 | self.d_head = self.d_model // self.n_heads 25 | self.layernorm_qkv = nn.Sequential( 26 | nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias) 27 | ) 28 | self.out_proj = nn.Linear(d_model, d_model, bias=bias) 29 | 30 | if qk_layernorm: 31 | self.q_ln = nn.LayerNorm(d_model, bias=bias) 32 | self.k_ln = nn.LayerNorm(d_model, bias=bias) 33 | else: 34 | self.q_ln = nn.Identity() 35 | self.k_ln = nn.Identity() 36 | 37 | self.rotary = RotaryEmbedding(d_model // n_heads) 38 | 39 | 
def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor): 40 | q = q.unflatten(-1, (self.n_heads, self.d_head)) 41 | k = k.unflatten(-1, (self.n_heads, self.d_head)) 42 | q, k = self.rotary(q, k) 43 | q = q.flatten(-2, -1) 44 | k = k.flatten(-2, -1) 45 | return q, k 46 | 47 | def forward(self, x, seq_id): 48 | qkv_BLD3 = self.layernorm_qkv(x) 49 | query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1) 50 | query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD) 51 | query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD) 52 | 53 | n_heads = self.n_heads 54 | reshaper = functools.partial( 55 | einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads 56 | ) 57 | 58 | query_BHLD, key_BHLD, value_BHLD = map( 59 | reshaper, (query_BLD, key_BLD, value_BLD) 60 | ) 61 | 62 | # Where True, enable participation in attention. 63 | mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2) 64 | mask_BHLL = mask_BLL.unsqueeze(1) 65 | 66 | context_BHLD = F.scaled_dot_product_attention( 67 | query_BHLD, key_BHLD, value_BHLD, mask_BHLL 68 | ) 69 | context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)") 70 | return self.out_proj(context_BLD) 71 | -------------------------------------------------------------------------------- /src/esm/layers/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from src.esm.layers.attention import MultiHeadAttention 6 | from src.esm.layers.geom_attention import ( 7 | GeometricReasoningOriginalImpl, 8 | ) 9 | from src.esm.utils.structure.affine3d import Affine3D 10 | 11 | 12 | def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int: 13 | # set hidden dimesion to nearest multiple of 256 after expansion ratio 14 | return int(((expansion_ratio * d_model) + 255) // 256 * 256) 15 | 16 | 17 | class SwiGLU(nn.Module): 18 | """ 19 | SwiGLU activation function as an nn.Module, allowing it 
to be used within nn.Sequential. 20 | This module splits the input tensor along the last dimension and applies the SiLU (Swish) 21 | activation function to the first half, then multiplies it by the second half. 22 | """ 23 | 24 | def __init__(self): 25 | super(SwiGLU, self).__init__() 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | x1, x2 = x.chunk(2, dim=-1) 29 | return F.silu(x1) * x2 30 | 31 | 32 | def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool): 33 | return nn.Sequential( 34 | nn.LayerNorm(d_model), 35 | nn.Linear( 36 | d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias 37 | ), 38 | SwiGLU(), 39 | nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias), 40 | ) 41 | 42 | 43 | def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool): 44 | hidden_dim = int(expansion_ratio * d_model) 45 | return nn.Sequential( 46 | nn.LayerNorm(d_model), 47 | nn.Linear(d_model, hidden_dim, bias=bias), 48 | nn.GELU(), 49 | nn.Linear(hidden_dim, d_model, bias=bias), 50 | ) 51 | 52 | 53 | class UnifiedTransformerBlock(nn.Module): 54 | """ 55 | A unified transformer block that can optionally incorporate geometric attention. 56 | 57 | This class defines a transformer block that can be configured to use geometric attention 58 | alongside the standard multi-head attention mechanism. It is designed to be a flexible 59 | component of transformer-based models, allowing for the integration of geometric reasoning. 60 | 61 | Parameters 62 | ---------- 63 | d_model : int 64 | The dimensionality of the input and output features of the transformer block. 65 | n_heads : int 66 | The number of attention heads in the multi-head attention mechanism. 67 | n_layers : int 68 | The number of layers in the transformer block. 69 | use_geom_attn : bool, optional 70 | Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False. 
71 | v_heads : int, optional 72 | The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True. 73 | """ 74 | 75 | def __init__( 76 | self, 77 | d_model: int, 78 | n_heads: int, 79 | use_geom_attn: bool = False, 80 | use_plain_attn: bool = True, 81 | v_heads: int | None = None, 82 | bias: bool = False, 83 | expansion_ratio: float = 4.0, 84 | residue_scaling_factor: float = 1, 85 | mask_and_zero_frameless: bool = False, 86 | qk_layernorm: bool = True, 87 | ffn_type: str = "swiglu", # swiglu | gelu 88 | ): 89 | super().__init__() 90 | self.use_plain_attn = use_plain_attn 91 | if self.use_plain_attn: 92 | self.attn = MultiHeadAttention( 93 | d_model, n_heads, bias, qk_layernorm=qk_layernorm 94 | ) 95 | self.use_geom_attn = use_geom_attn 96 | if self.use_geom_attn: 97 | if v_heads is None: 98 | raise ValueError("v_heads must be specified when use_geom_attn is True") 99 | self.geom_attn = GeometricReasoningOriginalImpl( 100 | c_s=d_model, 101 | v_heads=v_heads, 102 | bias=bias, 103 | mask_and_zero_frameless=mask_and_zero_frameless, 104 | ) 105 | if ffn_type == "swiglu": 106 | self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias) 107 | elif ffn_type == "gelu": 108 | self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias) 109 | else: 110 | raise ValueError(f"Unknown ffn_type: {ffn_type}") 111 | self.scaling_factor = residue_scaling_factor 112 | 113 | def forward( 114 | self, 115 | x: torch.Tensor, 116 | sequence_id: torch.Tensor, 117 | frames: Affine3D, 118 | frames_mask: torch.Tensor, 119 | chain_id: torch.Tensor, 120 | ) -> torch.Tensor: 121 | """ 122 | Forward pass for the UnifiedTransformerBlock. 123 | 124 | Parameters 125 | ---------- 126 | x : torch.Tensor[float] 127 | Input tensor to the transformer block, typically the output from the previous layer. 128 | sequence_id : torch.Tensor[int] 129 | Tensor containing sequence IDs for each element in the batch, used for attention masking. 
130 | frames : Affine3D 131 | Affine3D containing geometric frame information for geometric attention. 132 | frames_mask : torch.Tensor[bool] 133 | Boolean mask tensor indicating valid frames for geometric attention. 134 | chain_id : torch.Tensor[int] 135 | Tensor containing chain IDs for each element, used for attention masking in geometric attention. 136 | 137 | Returns 138 | ------- 139 | torch.Tensor[float] 140 | The output tensor after applying the transformer block operations. 141 | """ 142 | if self.use_plain_attn: 143 | r1 = self.attn(x, sequence_id) 144 | x = x + r1 / self.scaling_factor 145 | 146 | if self.use_geom_attn: 147 | r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id) 148 | x = x + r2 / self.scaling_factor 149 | 150 | r3 = self.ffn(x) / self.scaling_factor 151 | x = x + r3 152 | 153 | return x 154 | -------------------------------------------------------------------------------- /src/esm/layers/codebook.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class EMACodebook(nn.Module): 9 | def __init__( 10 | self, 11 | n_codes, 12 | embedding_dim, 13 | no_random_restart=True, 14 | restart_thres=1.0, 15 | ema_decay=0.99, 16 | ): 17 | super().__init__() 18 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) 19 | self.register_buffer("N", torch.zeros(n_codes)) 20 | self.register_buffer("z_avg", self.embeddings.data.clone()) 21 | 22 | self.n_codes = n_codes 23 | self.embedding_dim = embedding_dim 24 | self._need_init = True 25 | self.no_random_restart = no_random_restart 26 | self.restart_thres = restart_thres 27 | self.freeze_codebook = False 28 | self.ema_decay = ema_decay 29 | 30 | def reset_parameters(self): 31 | # For meta init 32 | pass 33 | 34 | def _tile(self, x): 35 | d, ew = x.shape 36 | if d < self.n_codes: 37 | n_repeats = 
(self.n_codes + d - 1) // d 38 | std = 0.01 / np.sqrt(ew) 39 | x = x.repeat(n_repeats, 1) 40 | x = x + torch.randn_like(x) * std 41 | return x 42 | 43 | def _init_embeddings(self, z): 44 | # z: [b, t, c] 45 | self._need_init = False 46 | flat_inputs = z.view(-1, self.embedding_dim) 47 | y = self._tile(flat_inputs) 48 | 49 | y.shape[0] 50 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 51 | if dist.is_initialized(): 52 | dist.broadcast(_k_rand, 0) 53 | self.embeddings.data.copy_(_k_rand) 54 | self.z_avg.data.copy_(_k_rand) 55 | self.N.data.copy_(torch.ones(self.n_codes)) 56 | 57 | def forward(self, z): 58 | # z: [b, t, c] 59 | if self._need_init and self.training and not self.freeze_codebook: 60 | self._init_embeddings(z) 61 | # z is of shape [batch_size, sequence length, channels] 62 | flat_inputs = z.view(-1, self.embedding_dim) 63 | distances = ( 64 | (flat_inputs**2).sum(dim=1, keepdim=True) 65 | - 2 * flat_inputs @ self.embeddings.t() 66 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) 67 | ) # [bt, c] 68 | 69 | encoding_indices = torch.argmin(distances, dim=1) 70 | encoding_indices = encoding_indices.view(*z.shape[:2]) # [b, t, ncode] 71 | 72 | embeddings = F.embedding(encoding_indices, self.embeddings) # [b, t, c] 73 | 74 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) 75 | 76 | # EMA codebook update 77 | if self.training and not self.freeze_codebook: 78 | assert False, "Not implemented" 79 | embeddings_st = (embeddings - z).detach() + z 80 | 81 | return embeddings_st, encoding_indices, commitment_loss 82 | 83 | def dictionary_lookup(self, encodings): 84 | embeddings = F.embedding(encodings, self.embeddings) 85 | return embeddings 86 | 87 | def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor: 88 | return weights @ self.embeddings 89 | -------------------------------------------------------------------------------- /src/esm/layers/ffn.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | 5 | # NOT CURRENTLY USED 6 | 7 | 8 | class SwiGLU(nn.Module): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | 12 | def forward(self, x: Tensor) -> Tensor: 13 | x1, x2 = x.chunk(2, dim=-1) 14 | hidden = F.silu(x1) * x2 15 | return hidden 16 | 17 | 18 | class FFN(nn.Module): 19 | def __init__(self, in_proj, activation, out_proj) -> None: 20 | super().__init__() 21 | self.in_proj = in_proj 22 | self.activation = activation 23 | self.out_proj = out_proj 24 | 25 | def forward(self, x: Tensor) -> Tensor: 26 | x = self.in_proj(x) 27 | x = self.activation(x) 28 | x = self.out_proj(x) 29 | return x 30 | -------------------------------------------------------------------------------- /src/esm/layers/geom_attention.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class GeometricReasoningOriginalImpl(nn.Module): 10 | def __init__( 11 | self, 12 | c_s: int, 13 | v_heads: int, 14 | num_vector_messages: int = 1, 15 | mask_and_zero_frameless: bool = True, 16 | divide_residual_by_depth: bool = False, 17 | bias: bool = False, 18 | ): 19 | """Approximate implementation: 20 | 21 | ATTN(A, v) := (softmax_j A_ij) v_j 22 | make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3) 23 | make_vectors(x) := T(i->g) Linear(x).reshape(..., 3) 24 | 25 | v <- make_rot_vectors(x) 26 | q_dir, k_dir <- make_rot_vectors(x) 27 | q_dist, k_dist <- make_vectors(x) 28 | 29 | A_ij <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2 30 | x <- x + Linear(T(g->i) ATTN(A, v)) 31 | """ 32 | super().__init__() 33 | self.c_s = c_s 34 | self.v_heads = v_heads 35 | self.num_vector_messages = num_vector_messages 36 | 
self.mask_and_zero_frameless = mask_and_zero_frameless 37 | 38 | self.s_norm = nn.LayerNorm(c_s, bias=bias) 39 | dim_proj = ( 40 | 4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages 41 | ) # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z) 42 | self.proj = nn.Linear(c_s, dim_proj, bias=bias) 43 | channels_out = self.v_heads * 3 * self.num_vector_messages 44 | self.out_proj = nn.Linear(channels_out, c_s, bias=bias) 45 | 46 | # The basic idea is for some attention heads to pay more or less attention to rotation versus distance, 47 | # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues 48 | # very nearby or should there be shallower dropoff in attention weight?) 49 | self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads))) 50 | self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads))) 51 | 52 | def forward(self, s, affine, affine_mask, sequence_id, chain_id): 53 | attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2) 54 | attn_bias = attn_bias.unsqueeze(1).float() 55 | attn_bias = attn_bias.masked_fill( 56 | ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min 57 | ) 58 | chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2) 59 | attn_bias = attn_bias.masked_fill( 60 | chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min 61 | ) 62 | 63 | ns = self.s_norm(s) 64 | vec_rot, vec_dist = self.proj(ns).split( 65 | [ 66 | self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages, 67 | self.v_heads * 2 * 3, 68 | ], 69 | dim=-1, 70 | ) 71 | 72 | # Rotate the queries and keys for the rotation term. We also rotate the values. 73 | # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change 74 | # this in the future. 75 | query_rot, key_rot, value = ( 76 | affine.rot[..., None] 77 | .apply(rearrange(vec_rot, "... (h c) -> ... 
h c", c=3)) 78 | .split( 79 | [ 80 | self.v_heads, 81 | self.v_heads, 82 | self.v_heads * self.num_vector_messages, 83 | ], 84 | dim=-2, 85 | ) 86 | ) 87 | 88 | # Rotate and translate the queries and keys for the distance term 89 | # NOTE(thayes): a simple speedup would be to apply all rotations together, then 90 | # separately apply the translations. 91 | query_dist, key_dist = ( 92 | affine[..., None] 93 | .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3)) 94 | .chunk(2, dim=-2) 95 | ) 96 | 97 | query_dist = rearrange(query_dist, "b s h d -> b h s 1 d") 98 | key_dist = rearrange(key_dist, "b s h d -> b h 1 s d") 99 | query_rot = rearrange(query_rot, "b s h d -> b h s d") 100 | key_rot = rearrange(key_rot, "b s h d -> b h d s") 101 | value = rearrange( 102 | value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages 103 | ) 104 | 105 | distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3) 106 | rotation_term = query_rot.matmul(key_rot) / sqrt(3) 107 | distance_term_weight = rearrange( 108 | F.softplus(self.distance_scale_per_head), "h -> h 1 1" 109 | ) 110 | rotation_term_weight = rearrange( 111 | F.softplus(self.rotation_scale_per_head), "h -> h 1 1" 112 | ) 113 | 114 | attn_weight = ( 115 | rotation_term * rotation_term_weight - distance_term * distance_term_weight 116 | ) 117 | 118 | if attn_bias is not None: 119 | # we can re-use the attention bias from the transformer layers 120 | # NOTE(thayes): This attention bias is expected to handle two things: 121 | # 1. Masking attention on padding tokens 122 | # 2. 
Masking cross sequence attention in the case of bin packing 123 | s_q = attn_weight.size(2) 124 | s_k = attn_weight.size(3) 125 | _s_q = max(0, attn_bias.size(2) - s_q) 126 | _s_k = max(0, attn_bias.size(3) - s_k) 127 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 128 | attn_weight = attn_weight + attn_bias 129 | 130 | attn_weight = torch.softmax(attn_weight, dim=-1) 131 | 132 | attn_out = attn_weight.matmul(value) 133 | 134 | attn_out = ( 135 | affine.rot[..., None] 136 | .invert() 137 | .apply( 138 | rearrange( 139 | attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages 140 | ) 141 | ) 142 | ) 143 | 144 | attn_out = rearrange( 145 | attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages 146 | ) 147 | if self.mask_and_zero_frameless: 148 | attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0) 149 | s = self.out_proj(attn_out) 150 | 151 | return s 152 | -------------------------------------------------------------------------------- /src/esm/layers/regression_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def RegressionHead( 5 | d_model: int, 6 | output_dim: int, 7 | hidden_dim: int | None = None, 8 | ) -> nn.Module: 9 | """Single-hidden layer MLP for supervised output. 10 | 11 | Args: 12 | d_model: input dimension 13 | output_dim: dimensionality of the output. 14 | hidden_dim: optional dimension of hidden layer, defaults to d_model. 15 | Returns: 16 | output MLP module. 17 | """ 18 | hidden_dim = hidden_dim if hidden_dim is not None else d_model 19 | return nn.Sequential( 20 | nn.Linear(d_model, hidden_dim), 21 | nn.GELU(), 22 | nn.LayerNorm(hidden_dim), 23 | nn.Linear(hidden_dim, output_dim), 24 | ) 25 | -------------------------------------------------------------------------------- /src/esm/layers/rotary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. 
All rights reserved. 2 | # 3 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 4 | # and OPT implementations in this library. It has been modified from its 5 | # original forms to accommodate minor architectural differences compared 6 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | # NOTE: this implementation is from LLaMA 2: 20 | # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114 21 | # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary` 22 | 23 | from typing import Tuple 24 | 25 | import torch 26 | from einops import rearrange, repeat 27 | 28 | 29 | def rotate_half(x, interleaved=False): 30 | if not interleaved: 31 | x1, x2 = x.chunk(2, dim=-1) 32 | return torch.cat((-x2, x1), dim=-1) 33 | else: 34 | x1, x2 = x[..., ::2], x[..., 1::2] 35 | return rearrange( 36 | torch.stack((-x2, x1), dim=-1), "... d two -> ... 
(d two)", two=2 37 | ) 38 | 39 | 40 | def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False): 41 | """ 42 | x: (batch_size, seqlen, nheads, headdim) 43 | cos, sin: (seqlen, rotary_dim / 2) 44 | """ 45 | ro_dim = cos.shape[-1] * 2 46 | assert ro_dim <= x.shape[-1] 47 | seqlen = x.size(1) 48 | cos = cos[:seqlen] 49 | sin = sin[:seqlen] 50 | cos = repeat(cos, "s d -> s 1 (2 d)") 51 | sin = repeat(sin, "s d -> s 1 (2 d)") 52 | return torch.cat( 53 | [ 54 | x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, 55 | x[..., ro_dim:], 56 | ], 57 | dim=-1, 58 | ) 59 | 60 | 61 | class RotaryEmbedding(torch.nn.Module): 62 | """ 63 | The rotary position embeddings from RoFormer_ (Su et. al). 64 | A crucial insight from the method is that the query and keys are 65 | transformed by rotation matrices which depend on the relative positions. 66 | Other implementations are available in the Rotary Transformer repo_ and in 67 | GPT-NeoX_, GPT-NeoX was an inspiration 68 | .. _RoFormer: https://arxiv.org/abs/2104.09864 69 | .. _repo: https://github.com/ZhuiyiTechnology/roformer 70 | .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox 71 | If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 72 | A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 73 | Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py 74 | """ 75 | 76 | def __init__( 77 | self, 78 | dim: int, 79 | base=10000.0, 80 | interleaved=False, 81 | scale_base=None, 82 | scaling_factor=1.0, 83 | pos_idx_in_fp32=True, 84 | device=None, 85 | ): 86 | """ 87 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 88 | of 1st half and 2nd half (GPT-NeoX style). 89 | pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, 90 | otherwise they might be in lower precision. 
91 | This option was added because previously (before 2023-07-02), when we construct 92 | the position indices, we use the dtype of self.inv_freq. In most cases this would 93 | be fp32, but if the model is trained in pure bf16 (not mixed precision), then 94 | self.inv_freq would be bf16, and the position indices are also in bf16. 95 | Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the 96 | embeddings for some positions will coincide. 97 | To maintain compatibility with models previously trained in pure bf16, 98 | we add this option. 99 | scaling_factor: RotaryEmbedding extended with linear scaling. 100 | """ 101 | super().__init__() 102 | self.dim = dim 103 | self.base = float(base) 104 | self.pos_idx_in_fp32 = pos_idx_in_fp32 105 | # Generate and save the inverse frequency buffer (non trainable) 106 | self.interleaved = interleaved 107 | self.scale_base = scale_base 108 | self.scaling_factor = scaling_factor 109 | self.device = device 110 | 111 | self._seq_len_cached = 0 112 | self._cos_cached = None 113 | self._sin_cached = None 114 | self._cos_k_cached = None 115 | self._sin_k_cached = None 116 | self.reset_parameters() 117 | 118 | def reset_parameters(self): 119 | inv_freq = self._compute_inv_freq(self.device) 120 | self.register_buffer("inv_freq", inv_freq, persistent=False) 121 | arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) 122 | scale = ( 123 | (arange + 0.4 * self.dim) / (1.4 * self.dim) 124 | if self.scale_base is not None 125 | else None 126 | ) 127 | self.register_buffer("scale", scale) 128 | 129 | def _compute_inv_freq(self, device=None): 130 | return 1 / ( 131 | self.base 132 | ** ( 133 | torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) 134 | / self.dim 135 | ) 136 | ) 137 | 138 | def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): 139 | # Reset the tables if the sequence length has changed, 140 | # if we're on a new device (possibly due to tracing for 
instance), 141 | # or if we're switching from inference mode to training 142 | if ( 143 | seqlen > self._seq_len_cached 144 | or self._cos_cached is None 145 | or self._cos_cached.device != device 146 | or self._cos_cached.dtype != dtype 147 | or (self.training and self._cos_cached.is_inference()) 148 | ): 149 | self._seq_len_cached = seqlen 150 | # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 151 | # And the output of arange can be quite large, so bf16 would lose a lot of precision. 152 | # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. 153 | if self.pos_idx_in_fp32: 154 | t = torch.arange(seqlen, device=device, dtype=torch.float32) 155 | t /= self.scaling_factor 156 | # We want fp32 here as well since inv_freq will be multiplied with t, and the output 157 | # will be large. Having it in bf16 will lose a lot of precision and cause the 158 | # cos & sin output to change significantly. 159 | # We want to recompute self.inv_freq if it was not loaded in fp32 160 | if self.inv_freq.dtype != torch.float32: 161 | inv_freq = self.inv_freq.to(torch.float32) 162 | else: 163 | inv_freq = self.inv_freq 164 | else: 165 | t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) 166 | t /= self.scaling_factor 167 | inv_freq = self.inv_freq 168 | # Don't do einsum, it converts fp32 to fp16 under AMP 169 | # freqs = torch.einsum("i,j->ij", t, self.inv_freq) 170 | freqs = torch.outer(t, inv_freq) 171 | 172 | if self.scale is None: 173 | self._cos_cached = torch.cos(freqs).to(dtype) 174 | self._sin_cached = torch.sin(freqs).to(dtype) 175 | else: 176 | power = ( 177 | torch.arange( 178 | seqlen, dtype=self.scale.dtype, device=self.scale.device 179 | ) 180 | - seqlen // 2 181 | ) / self.scale_base 182 | scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) 183 | # We want the multiplication by scale to happen in fp32 184 | self._cos_cached = (torch.cos(freqs) * scale).to(dtype) 185 | 
self._sin_cached = (torch.sin(freqs) * scale).to(dtype) 186 | self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) 187 | self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) 188 | 189 | def forward( 190 | self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0 191 | ) -> Tuple[torch.Tensor, torch.Tensor]: 192 | """ 193 | q: (batch, seqlen, nheads, headdim) 194 | k: (batch, seqlen, nheads, headdim) 195 | seqlen_offset: can be used in generation where the qkv being passed in is only the last 196 | token in the batch. 197 | """ 198 | self._update_cos_sin_cache( 199 | q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype 200 | ) 201 | assert self._cos_cached is not None 202 | assert self._sin_cached is not None 203 | if self.scale is None: 204 | return ( 205 | apply_rotary_emb_torch( 206 | q, 207 | self._cos_cached[seqlen_offset:], 208 | self._sin_cached[seqlen_offset:], 209 | self.interleaved, 210 | True, # inplace=True 211 | ), 212 | apply_rotary_emb_torch( 213 | k, 214 | self._cos_cached[seqlen_offset:], 215 | self._sin_cached[seqlen_offset:], 216 | self.interleaved, 217 | True, # inplace=True 218 | ), 219 | ) # type: ignore 220 | else: 221 | assert False 222 | -------------------------------------------------------------------------------- /src/esm/layers/structure_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.esm.utils.constants.physics import ( 5 | BB_COORDINATES, 6 | ) 7 | from src.esm.utils.structure.affine3d import ( 8 | Affine3D, 9 | RotationMatrix, 10 | ) 11 | 12 | 13 | class Dim6RotStructureHead(nn.Module): 14 | # Normally, AF2 uses quaternions to specify rotations. 
There's some evidence that 15 | # other representations are more well behaved - the best one according to 16 | # https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf 17 | # is using graham schmidt on 2 vectors, which is implemented here. 18 | def __init__( 19 | self, 20 | input_dim: int, 21 | trans_scale_factor: float = 10, 22 | norm_type: str = "layernorm", 23 | activation_fn: str = "esm_gelu", 24 | predict_torsion_angles: bool = True, 25 | ): 26 | super().__init__() 27 | self.ffn1 = nn.Linear(input_dim, input_dim) 28 | self.activation_fn = nn.GELU() 29 | self.norm = nn.LayerNorm(input_dim) 30 | self.proj = nn.Linear(input_dim, 9 + 7 * 2) 31 | self.trans_scale_factor = trans_scale_factor 32 | self.predict_torsion_angles = predict_torsion_angles 33 | self.bb_local_coords = torch.tensor(BB_COORDINATES).float() 34 | 35 | def forward(self, x, affine, affine_mask, **kwargs): 36 | if affine is None: 37 | rigids = Affine3D.identity( 38 | x.shape[:-1], 39 | dtype=x.dtype, 40 | device=x.device, 41 | requires_grad=self.training, 42 | rotation_type=RotationMatrix, 43 | ) 44 | else: 45 | rigids = affine 46 | 47 | # [*, N] 48 | x = self.ffn1(x) 49 | x = self.activation_fn(x) 50 | x = self.norm(x) 51 | trans, x, y, angles = self.proj(x).split([3, 3, 3, 7 * 2], dim=-1) 52 | trans = trans * self.trans_scale_factor 53 | x = x / (x.norm(dim=-1, keepdim=True) + 1e-5) 54 | y = y / (y.norm(dim=-1, keepdim=True) + 1e-5) 55 | update = Affine3D.from_graham_schmidt(x + trans, trans, y + trans) 56 | rigids = rigids.compose(update.mask(affine_mask)) 57 | affine = rigids.tensor 58 | 59 | # We approximate the positions of the backbone atoms in the global frame by applying the rigid 60 | # transformation to the mean of the backbone atoms in the local frame. 
61 | all_bb_coords_local = ( 62 | self.bb_local_coords[None, None, :, :] 63 | .expand(*x.shape[:-1], 3, 3) 64 | .to(x.device) 65 | ) 66 | pred_xyz = rigids[..., None].apply(all_bb_coords_local) 67 | 68 | return affine, pred_xyz 69 | -------------------------------------------------------------------------------- /src/esm/layers/transformer_stack.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from src.esm.layers.blocks import UnifiedTransformerBlock 7 | from src.esm.utils.structure.affine3d import Affine3D 8 | 9 | 10 | class TransformerStack(nn.Module): 11 | """ 12 | A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock, 13 | which can either be geometric attention or standard multi-head attention. 14 | 15 | Args: 16 | d_model (int): The dimensionality of the input and output feature vectors. 17 | n_heads (int): The number of attention heads. 18 | v_heads (int): The number of voting heads. 19 | n_layers (int): The number of transformer blocks in the stack. 20 | n_layers_geom (int, optional): The number of transformer blocks that use geometric attention. 21 | scale_residue (bool, optional): Whether to scale the residue connections in each transformer block. 22 | mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input. 
23 | Only applies in the geometric attention blocks, which is conditioned on the structure 24 | """ 25 | 26 | def __init__( 27 | self, 28 | d_model: int, 29 | n_heads: int, 30 | v_heads: int | None, 31 | n_layers: int, 32 | n_layers_geom: int = 1, 33 | scale_residue: bool = True, 34 | mask_and_zero_frameless: bool = False, 35 | bias: bool = False, 36 | qk_layernorm: bool = True, 37 | ffn_type: str = "swiglu", # swiglu | gelu 38 | expansion_ratio: float = 8 / 3, 39 | ): 40 | super().__init__() 41 | self.blocks = nn.ModuleList( 42 | [ 43 | UnifiedTransformerBlock( 44 | d_model, 45 | n_heads, 46 | v_heads=v_heads, 47 | use_geom_attn=i < n_layers_geom, 48 | residue_scaling_factor=( 49 | math.sqrt(n_layers / 36) if scale_residue else 1.0 50 | ), 51 | expansion_ratio=expansion_ratio, 52 | mask_and_zero_frameless=mask_and_zero_frameless, 53 | bias=bias, 54 | qk_layernorm=qk_layernorm, 55 | ffn_type=ffn_type, 56 | ) 57 | for i in range(n_layers) 58 | ] 59 | ) 60 | self.norm = nn.LayerNorm(d_model, bias=False) 61 | 62 | def forward( 63 | self, 64 | x: torch.Tensor, 65 | sequence_id: torch.Tensor | None = None, 66 | affine: Affine3D | None = None, 67 | affine_mask: torch.Tensor | None = None, 68 | chain_id: torch.Tensor | None = None, 69 | ) -> tuple[torch.Tensor, torch.Tensor]: 70 | """ 71 | Forward pass of the TransformerStack. 72 | 73 | Args: 74 | x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model). 75 | sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length). 76 | affine (Affine3D | None): The affine transformation tensor or None. 77 | affine_mask (torch.Tensor | None): The affine mask tensor or None. 78 | chain_id (torch.Tensor): The protein chain tensor of shape (batch_size, sequence_length). 79 | Only used in geometric attention. 80 | 81 | Returns: 82 | post_norm: The output tensor of shape (batch_size, sequence_length, d_model). 
83 | pre_norm: The embedding of shape (batch_size, sequence_length, d_model). 84 | """ 85 | *batch_dims, _ = x.shape 86 | if sequence_id is None: 87 | sequence_id = torch.ones( 88 | size=batch_dims, dtype=torch.int64, device=x.device 89 | ) 90 | if chain_id is None: 91 | chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device) 92 | for block in self.blocks: 93 | x = block(x, sequence_id, affine, affine_mask, chain_id) 94 | return self.norm(x), x 95 | -------------------------------------------------------------------------------- /src/esm/pretrained.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from esm.models.esm3 import ESM3 7 | from esm.models.function_decoder import FunctionTokenDecoder 8 | from esm.models.vqvae import ( 9 | StructureTokenDecoder, 10 | StructureTokenEncoder, 11 | ) 12 | from esm.utils.constants.esm3 import data_root 13 | from esm.utils.constants.models import ( 14 | ESM3_FUNCTION_DECODER_V0, 15 | ESM3_OPEN_SMALL, 16 | ESM3_STRUCTURE_DECODER_V0, 17 | ESM3_STRUCTURE_ENCODER_V0, 18 | ) 19 | 20 | ModelBuilder = Callable[[torch.device | str], nn.Module] 21 | 22 | 23 | def ESM3_sm_open_v0(device: torch.device | str = "cpu"): 24 | model = ( 25 | ESM3( 26 | d_model=1536, 27 | n_heads=24, 28 | v_heads=256, 29 | n_layers=48, 30 | structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0, 31 | structure_decoder_name=ESM3_STRUCTURE_DECODER_V0, 32 | function_decoder_name=ESM3_FUNCTION_DECODER_V0, 33 | ) 34 | .to(device) 35 | .eval() 36 | ) 37 | state_dict = torch.load( 38 | data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device 39 | ) 40 | model.load_state_dict(state_dict) 41 | return model 42 | 43 | 44 | def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"): 45 | model = ( 46 | StructureTokenEncoder( 47 | d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096 48 | ) 49 | 
.to(device) 50 | .eval() 51 | ) 52 | state_dict = torch.load( 53 | data_root() / "data/weights/esm3_structure_encoder_v0.pth", map_location=device 54 | ) 55 | model.load_state_dict(state_dict) 56 | return model 57 | 58 | 59 | def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"): 60 | model = ( 61 | StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).to(device).eval() 62 | ) 63 | state_dict = torch.load( 64 | data_root() / "data/weights/esm3_structure_decoder_v0.pth", map_location=device 65 | ) 66 | model.load_state_dict(state_dict) 67 | return model 68 | 69 | 70 | def ESM3_function_decoder_v0(device: torch.device | str = "cpu"): 71 | model = FunctionTokenDecoder().to(device).eval() 72 | state_dict = torch.load( 73 | data_root() / "data/weights/esm3_function_decoder_v0.pth", map_location=device 74 | ) 75 | model.load_state_dict(state_dict) 76 | return model 77 | 78 | 79 | LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = { 80 | ESM3_OPEN_SMALL: ESM3_sm_open_v0, 81 | ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0, 82 | ESM3_STRUCTURE_DECODER_V0: ESM3_structure_decoder_v0, 83 | ESM3_FUNCTION_DECODER_V0: ESM3_function_decoder_v0, 84 | } 85 | 86 | 87 | def load_local_model(model_name: str, device: torch.device | str = "cpu") -> nn.Module: 88 | if model_name not in LOCAL_MODEL_REGISTRY: 89 | raise ValueError(f"Model {model_name} not found in local model registry.") 90 | return LOCAL_MODEL_REGISTRY[model_name](device) 91 | 92 | 93 | # Register custom versions of ESM3 for use with the local inference API 94 | def register_local_model(model_name: str, model_builder: ModelBuilder) -> None: 95 | LOCAL_MODEL_REGISTRY[model_name] = model_builder 96 | -------------------------------------------------------------------------------- /src/esm/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Protocol 3 | 4 | from 
src.esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS 5 | from src.esm.utils.constants.models import ESM3_OPEN_SMALL 6 | 7 | from .function_tokenizer import InterProQuantizedTokenizer 8 | from .residue_tokenizer import ResidueAnnotationsTokenizer 9 | from .sasa_tokenizer import SASADiscretizingTokenizer 10 | from .sequence_tokenizer import EsmSequenceTokenizer 11 | from .ss_tokenizer import SecondaryStructureTokenizer 12 | from .structure_tokenizer import StructureTokenizer 13 | from .tokenizer_base import EsmTokenizerBase 14 | 15 | 16 | class TokenizerCollectionProtocol(Protocol): 17 | sequence: EsmSequenceTokenizer 18 | structure: StructureTokenizer 19 | secondary_structure: SecondaryStructureTokenizer 20 | sasa: SASADiscretizingTokenizer 21 | function: InterProQuantizedTokenizer 22 | residue_annotations: ResidueAnnotationsTokenizer 23 | 24 | 25 | @dataclass 26 | class TokenizerCollection: 27 | sequence: EsmSequenceTokenizer 28 | structure: StructureTokenizer 29 | secondary_structure: SecondaryStructureTokenizer 30 | sasa: SASADiscretizingTokenizer 31 | function: InterProQuantizedTokenizer 32 | residue_annotations: ResidueAnnotationsTokenizer 33 | 34 | 35 | def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection: 36 | if model == ESM3_OPEN_SMALL: 37 | return TokenizerCollection( 38 | sequence=EsmSequenceTokenizer(), 39 | structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS), 40 | secondary_structure=SecondaryStructureTokenizer(kind="ss8"), 41 | sasa=SASADiscretizingTokenizer(), 42 | function=InterProQuantizedTokenizer(), 43 | residue_annotations=ResidueAnnotationsTokenizer(), 44 | ) 45 | else: 46 | raise ValueError(f"Unknown model: {model}") 47 | 48 | 49 | def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]: 50 | if isinstance(tokenizer, EsmSequenceTokenizer): 51 | return [ 52 | tokenizer.mask_token_id, # type: ignore 53 | tokenizer.pad_token_id, # type: ignore 54 | tokenizer.cls_token_id, # type: 
ignore 55 | tokenizer.eos_token_id, # type: ignore 56 | ] 57 | else: 58 | return [ 59 | tokenizer.mask_token_id, 60 | tokenizer.pad_token_id, 61 | tokenizer.bos_token_id, 62 | tokenizer.eos_token_id, 63 | ] 64 | -------------------------------------------------------------------------------- /src/esm/tokenization/residue_tokenizer.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | import pandas as pd 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase 10 | from src.esm.utils.constants import esm3 as C 11 | 12 | Sample = dict[str, Any] 13 | 14 | 15 | class ResidueAnnotationsTokenizer(EsmTokenizerBase): 16 | def __init__( 17 | self, 18 | csv_path: str | None = None, 19 | max_annotations: int = 16, 20 | ): 21 | if csv_path is None: 22 | csv_path = str(C.data_root() / C.RESID_CSV) 23 | self.csv_path = csv_path 24 | self.max_annotations = max_annotations 25 | 26 | @cached_property 27 | def _description2label(self) -> dict[str, str]: 28 | with Path(self.csv_path).open() as f: # type: ignore 29 | df = pd.read_csv(f) 30 | return dict(zip(df.label, df.label_clean)) 31 | 32 | @cached_property 33 | def _labels(self) -> list[str]: 34 | with Path(self.csv_path).open() as f: # type: ignore 35 | df = pd.read_csv(f) 36 | labels = ( 37 | df.groupby("label_clean")["count"] 38 | .sum() 39 | .sort_values(ascending=False, kind="stable") # type: ignore 40 | .index.tolist() 41 | ) 42 | assert isinstance(labels, list) 43 | return labels # type: ignore 44 | 45 | def _description2id(self, description: str) -> int | None: 46 | label = self._description2label.get(description) 47 | return self._label2id.get(label) # type: ignore 48 | 49 | @cached_property 50 | def _label2id(self) -> dict[str, int]: 51 | offset = len(self.special_tokens) + 1 # +1 for "" 52 | return {label: offset + i 
for i, label in enumerate(self._labels)} 53 | 54 | @cached_property 55 | def special_tokens(self) -> list[str]: 56 | """List of special tokens which come before cluster toknes in vocab.""" 57 | return ["", "", ""] 58 | 59 | @cached_property 60 | def vocab(self): 61 | annotation_tokens = [f"" for _, id in self._label2id.items()] 62 | return self.special_tokens + [""] + annotation_tokens 63 | 64 | @cached_property 65 | def vocab_to_index(self) -> dict[str, int]: 66 | return {token: token_id for token_id, token in enumerate(self.vocab)} 67 | 68 | @cached_property 69 | def vocabulary(self) -> list[str]: 70 | """Full vocabulary.""" 71 | return [*self.special_tokens, "", *self._labels] 72 | 73 | def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor: 74 | """Determines where in the sequence are special tokens.""" 75 | return encoded[:, 0] < len(self.special_tokens) 76 | 77 | def tokenize( 78 | self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False 79 | ) -> list[str]: 80 | """ 81 | # interpro_site_starts 82 | # interpro_site_ends # should always == interpro_site_starts. but I haven't checked overall. 83 | # interpro_site_residues # the residue identity of the specfic residue that is annotated. good for a sanity check that parsing occurred correctly. 84 | # interpro_site_descriptions 85 | # ASSERT (i.e. 
drop if bad) 86 | # interpro_site_residues matches the residue at that position 87 | # all these lists ^ above are the same length 88 | """ 89 | seqlen = len(sequence) 90 | assert seqlen >= 0 91 | # None mean sequence is *not annotated* - so use full 92 | if sample is None: 93 | return [""] * seqlen 94 | 95 | if any( 96 | sample.get(field) is None 97 | for field in [ 98 | "interpro_site_descriptions", 99 | "interpro_site_starts", 100 | "interpro_site_ends", 101 | "interpro_site_residues", 102 | ] 103 | ): 104 | return [""] * seqlen 105 | 106 | num_annotations = len(sample["interpro_site_descriptions"]) 107 | if any( 108 | len(sample[field]) != num_annotations 109 | for field in [ 110 | "interpro_site_starts", 111 | "interpro_site_ends", 112 | "interpro_site_residues", 113 | ] 114 | ): 115 | # mismatched length. 116 | return [""] * seqlen 117 | 118 | positional_ids = [set() for _ in range(seqlen)] 119 | for description, start, end, residues in zip( 120 | sample["interpro_site_descriptions"], 121 | sample["interpro_site_starts"], 122 | sample["interpro_site_ends"], 123 | sample["interpro_site_residues"], 124 | ): 125 | try: 126 | start = int(start) 127 | end = int(end) 128 | except (TypeError, ValueError): 129 | continue 130 | 131 | # Start / End are 1-indexed [inclusive, inclusive]. 132 | if start <= 0 or end > seqlen or start > end: 133 | print(f"invalid start/end: ({start}, {end}), len: {seqlen}") 134 | continue 135 | 136 | if len(residues) != (end - start) + 1: 137 | print(f"bad reference residue: {residues}") 138 | continue 139 | 140 | token_id = self._description2id(description) 141 | if token_id is None: 142 | token_id = self.vocab_to_index[""] 143 | 144 | for i, residue in zip(range(start - 1, end), residues): 145 | # If there are any mismatching residues, skip the entire sample. 
146 | if sequence[i] != residue: 147 | if fail_on_mismatch: 148 | raise ValueError( 149 | f"Residue mismatch at position {i} (1-indexed): {sequence[i]} != {residue}" 150 | ) 151 | return [""] * seqlen 152 | 153 | positional_ids[i].add(token_id) 154 | 155 | tokens = [] 156 | for token_ids in positional_ids: 157 | if token_ids: 158 | token = "" 159 | else: 160 | token = "" 161 | tokens.append(token) 162 | return tokens 163 | 164 | def _token2ids(self, token: str) -> list[int]: 165 | if token.startswith(""): 166 | return [int(token_id) for token_id in token[4:-1].split(",")] 167 | else: 168 | token_id = self.vocab_to_index[token] 169 | return [token_id] 170 | 171 | def encode( 172 | self, tokens: list[str], add_special_tokens: bool = True 173 | ) -> torch.Tensor: 174 | token_ids = torch.full( 175 | size=(len(tokens), self.max_annotations), 176 | dtype=torch.int64, 177 | fill_value=self.vocab_to_index[""], 178 | ) 179 | for i, token in enumerate(tokens): 180 | ids = self._token2ids(token)[: self.max_annotations] 181 | token_ids[i, : len(ids)] = torch.tensor(ids) 182 | 183 | if add_special_tokens: 184 | token_ids = F.pad( 185 | token_ids, (0, 0, 1, 1), value=self.vocab_to_index[""] 186 | ) 187 | return token_ids 188 | 189 | def decode(self, encoded: torch.Tensor) -> list[str]: 190 | raise NotImplementedError( 191 | "Residue annotation decoding should be handled with util.decoding.decode_residue_annotations" 192 | ) 193 | 194 | @property 195 | def mask_token(self) -> str: 196 | return "" 197 | 198 | @property 199 | def mask_token_id(self) -> int: 200 | return self.vocab_to_index[self.mask_token] 201 | 202 | @property 203 | def bos_token(self) -> str: 204 | return "" 205 | 206 | @property 207 | def bos_token_id(self) -> int: 208 | return self.vocab_to_index[self.bos_token] 209 | 210 | @property 211 | def eos_token(self) -> str: 212 | return "" 213 | 214 | @property 215 | def eos_token_id(self) -> int: 216 | return self.vocab_to_index[self.eos_token] 217 | 218 | @property 
219 | def pad_token(self) -> str: 220 | return "" 221 | 222 | @property 223 | def pad_token_id(self) -> int: 224 | return self.vocab_to_index[self.pad_token] 225 | -------------------------------------------------------------------------------- /src/esm/tokenization/sasa_tokenizer.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | 3 | import torch 4 | 5 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase 6 | from src.esm.utils.constants import esm3 as C 7 | 8 | 9 | class SASADiscretizingTokenizer(EsmTokenizerBase): 10 | """Tokenizer for Solvent Accessible Surface Area (SASA).""" 11 | 12 | def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES): 13 | self._boundaries = sorted(boundaries) 14 | 15 | @cached_property 16 | def special_tokens(self) -> list[str]: 17 | return ["", "", ""] 18 | 19 | @cached_property 20 | def vocab(self) -> list[str]: 21 | """Discrete token vocabulary. 22 | 23 | Returns: 24 | token vocabulary with ranges represented as "". 
25 | """ 26 | boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"] 27 | range_tokens = [ 28 | f"<{low}-{high}>" 29 | for low, high in zip(boundary_strs[:-1], boundary_strs[1:]) 30 | ] 31 | return self.special_tokens + range_tokens 32 | 33 | @cached_property 34 | def midpoints(self) -> list[float]: 35 | """Midpoints of the SASA token ranges.""" 36 | boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2] 37 | midpoint_tokens = [ 38 | (float(high) + float(low)) / 2 39 | for low, high in zip(boundaries[:-1], boundaries[1:]) 40 | ] 41 | midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens 42 | return midpoint_tokens 43 | 44 | @cached_property 45 | def vocab_to_index(self) -> dict[str, int]: 46 | """Constructs token -> token id mapping.""" 47 | return {word: i for i, word in enumerate(self.vocab)} 48 | 49 | def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor: 50 | """Determines which positions are special tokens. 51 | 52 | Args: 53 | tokens: [length] 54 | Returns: 55 | [length] tensor, true where special tokens are located in the input. 56 | """ 57 | return tokens < len(self.special_tokens) 58 | 59 | def encode( 60 | self, values: list[float | str], add_special_tokens: bool = True 61 | ) -> torch.Tensor: 62 | """Encodes SASA values as discrete tokens. 63 | 64 | Args: 65 | values: list of either SASA values or individual tokens. For example 66 | [1.2, "", 10.3, , 0.] 67 | Returns: 68 | Token ids as tensor. Adds BOS and EOS special tokens. 
69 | """ 70 | ids = [] 71 | if add_special_tokens: 72 | ids.append(self.vocab_to_index[""]) # BOS 73 | for value in values: 74 | if isinstance(value, (float, int)): 75 | bucket = torch.bucketize(value, torch.tensor(self._boundaries)) 76 | token_id = len(self.special_tokens) + bucket 77 | elif isinstance(value, str): 78 | token_id = self.vocab_to_index[value] 79 | else: 80 | raise TypeError(value) 81 | ids.append(token_id) 82 | if add_special_tokens: 83 | ids.append(self.vocab_to_index[""]) # EOS 84 | 85 | return torch.tensor(ids, dtype=torch.int64) 86 | 87 | def decode_float(self, encoded: torch.Tensor) -> list[float]: 88 | """Decodes SASA token ids into float values.""" 89 | return [self.midpoints[token_id] for token_id in encoded] 90 | 91 | def decode(self, encoded: torch.Tensor) -> str: 92 | """Decodes SASA token ids.""" 93 | return ",".join(self.vocab[i] for i in encoded) 94 | 95 | def decode_list(self, encoded: torch.Tensor) -> list[str]: 96 | """Decodes SASA token ids.""" 97 | return [self.vocab[i] for i in encoded] 98 | 99 | @property 100 | def mask_token(self) -> str: 101 | return "" 102 | 103 | @property 104 | def mask_token_id(self) -> int: 105 | return self.vocab_to_index[self.mask_token] 106 | 107 | @property 108 | def bos_token(self) -> str: 109 | return "" 110 | 111 | @property 112 | def bos_token_id(self) -> int: 113 | return self.vocab_to_index[self.bos_token] 114 | 115 | @property 116 | def eos_token(self) -> str: 117 | return "" 118 | 119 | @property 120 | def eos_token_id(self) -> int: 121 | return self.vocab_to_index[self.eos_token] 122 | 123 | @property 124 | def pad_token(self) -> str: 125 | return "" 126 | 127 | @property 128 | def pad_token_id(self) -> int: 129 | return self.vocab_to_index[self.pad_token] 130 | -------------------------------------------------------------------------------- /src/esm/tokenization/sequence_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import 
Tokenizer 2 | from tokenizers.models import BPE 3 | from tokenizers.processors import TemplateProcessing 4 | from transformers import PreTrainedTokenizerFast 5 | 6 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase 7 | from src.esm.utils.constants import esm3 as C 8 | 9 | 10 | class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase): 11 | """ 12 | Constructs an ESM tokenizer. 13 | """ 14 | 15 | model_input_names = ["sequence_tokens", "attention_mask"] 16 | 17 | def __init__( 18 | self, 19 | unk_token="", 20 | cls_token="", 21 | pad_token="", 22 | mask_token="", 23 | eos_token="", 24 | chainbreak_token="|", 25 | **kwargs, 26 | ): 27 | all_tokens = C.SEQUENCE_VOCAB 28 | token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)} 29 | 30 | # a character-level tokenizer is the same as BPE with no token merges 31 | bpe = BPE(token_to_id, merges=[], unk_token=unk_token) 32 | tokenizer = Tokenizer(bpe) 33 | special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token] 34 | additional_special_tokens = [chainbreak_token] 35 | 36 | tokenizer.add_special_tokens( 37 | special_tokens, 38 | ) 39 | 40 | # This is where we configure the automatic addition of special tokens when we call 41 | # tokenizer(text, add_special_tokens=True). Note that you can also configure how two 42 | # sequences are merged if you want. 43 | tokenizer.post_processor = TemplateProcessing( # type: ignore 44 | single=" $A ", 45 | special_tokens=[ 46 | ("", tokenizer.token_to_id("")), 47 | ("", tokenizer.token_to_id("")), 48 | ], 49 | ) 50 | super().__init__( 51 | tokenizer_object=tokenizer, 52 | unk_token=unk_token, 53 | cls_token=cls_token, 54 | pad_token=pad_token, 55 | mask_token=mask_token, 56 | eos_token=eos_token, 57 | additional_special_tokens=additional_special_tokens, 58 | **kwargs, 59 | ) 60 | 61 | # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here. 
62 | @property 63 | def bos_token(self): 64 | return self.cls_token 65 | 66 | @property 67 | def bos_token_id(self): 68 | return self.cls_token_id 69 | -------------------------------------------------------------------------------- /src/esm/tokenization/ss_tokenizer.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Sequence 3 | 4 | import torch 5 | 6 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase 7 | from src.esm.utils.constants import esm3 as C 8 | 9 | 10 | class SecondaryStructureTokenizer(EsmTokenizerBase): 11 | """Tokenizer for secondary structure strings.""" 12 | 13 | def __init__(self, kind: str = "ss8"): 14 | assert kind in ("ss8", "ss3") 15 | self.kind = kind 16 | 17 | @property 18 | def special_tokens(self) -> list[str]: 19 | return ["", "", ""] 20 | 21 | @cached_property 22 | def vocab(self): 23 | """Tokenzier vocabulary list.""" 24 | match self.kind: 25 | case "ss8": 26 | nonspecial_tokens = list(C.SSE_8CLASS_VOCAB) # "GHITEBSC" 27 | case "ss3": 28 | nonspecial_tokens = list(C.SSE_3CLASS_VOCAB) # HEC 29 | case _: 30 | raise ValueError(self.kind) 31 | 32 | # The non-special tokens ids match amino acid tokens ids when possible. 33 | return [*self.special_tokens, *nonspecial_tokens] 34 | 35 | @cached_property 36 | def vocab_to_index(self) -> dict[str, int]: 37 | """Constructs token -> token id mapping.""" 38 | return {word: i for i, word in enumerate(self.vocab)} 39 | 40 | def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor: 41 | """Determines which positions are special tokens. 42 | 43 | Args: 44 | tokens: [length] 45 | Returns: 46 | [length] tensor, true where special tokens are located in the input. 
47 | """ 48 | return tokens < len(self.special_tokens) 49 | 50 | def encode( 51 | self, sequence: str | Sequence[str], add_special_tokens: bool = True 52 | ) -> torch.Tensor: 53 | """Encode secondary structure string 54 | 55 | Args: 56 | string: secondary structure string e.g. "GHHIT", or as token listk. 57 | Returns: 58 | [sequence_length] token ids representing. Will add /. 59 | """ 60 | ids = [] 61 | if add_special_tokens: 62 | ids.append(self.vocab_to_index[""]) # cls 63 | for char in sequence: 64 | ids.append(self.vocab_to_index[char]) 65 | if add_special_tokens: 66 | ids.append(self.vocab_to_index[""]) # eos 67 | return torch.tensor(ids, dtype=torch.int64) 68 | 69 | def decode(self, encoded: torch.Tensor) -> str: 70 | """Decodes token ids into secondary structure string. 71 | 72 | Args: 73 | encoded: [length] token id array. 74 | Returns 75 | Decoded secondary structure string. 76 | """ 77 | return "".join(self.vocab[i] for i in encoded) 78 | 79 | @property 80 | def mask_token(self) -> str: 81 | return "" 82 | 83 | @property 84 | def mask_token_id(self) -> int: 85 | return self.vocab_to_index[self.mask_token] 86 | 87 | @property 88 | def bos_token(self) -> str: 89 | return "" 90 | 91 | @property 92 | def bos_token_id(self) -> int: 93 | return self.vocab_to_index[self.bos_token] 94 | 95 | @property 96 | def eos_token(self) -> str: 97 | return "" 98 | 99 | @property 100 | def eos_token_id(self) -> int: 101 | return self.vocab_to_index[self.eos_token] 102 | 103 | @property 104 | def pad_token(self) -> str: 105 | return "" 106 | 107 | @property 108 | def pad_token_id(self) -> int: 109 | return self.vocab_to_index[self.pad_token] 110 | -------------------------------------------------------------------------------- /src/esm/tokenization/structure_tokenizer.py: -------------------------------------------------------------------------------- 1 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase 2 | 3 | 4 | class StructureTokenizer(EsmTokenizerBase): 5 
| """A convenince class for accessing special token ids of 6 | the StructureTokenEncoder and StructureTokenDecoder.""" 7 | 8 | def __init__(self, vq_vae_special_tokens: dict[str, int]): 9 | self.vq_vae_special_tokens = vq_vae_special_tokens 10 | 11 | def mask_token(self) -> str: 12 | raise NotImplementedError( 13 | "Structure tokens are defined on 3D coordinates, not strings." 14 | ) 15 | 16 | @property 17 | def mask_token_id(self) -> int: 18 | return self.vq_vae_special_tokens["MASK"] 19 | 20 | def bos_token(self) -> str: 21 | raise NotImplementedError( 22 | "Structure tokens are defined on 3D coordinates, not strings." 23 | ) 24 | 25 | @property 26 | def bos_token_id(self) -> int: 27 | return self.vq_vae_special_tokens["BOS"] 28 | 29 | def eos_token(self) -> str: 30 | raise NotImplementedError( 31 | "Structure tokens are defined on 3D coordinates, not strings." 32 | ) 33 | 34 | @property 35 | def eos_token_id(self) -> int: 36 | return self.vq_vae_special_tokens["EOS"] 37 | 38 | def pad_token(self) -> str: 39 | raise NotImplementedError( 40 | "Structure tokens are defined on 3D coordinates, not strings." 41 | ) 42 | 43 | @property 44 | def pad_token_id(self) -> int: 45 | return self.vq_vae_special_tokens["PAD"] 46 | 47 | @property 48 | def chainbreak_token_id(self) -> int: 49 | return self.vq_vae_special_tokens["CHAINBREAK"] 50 | 51 | def encode(self, *args, **kwargs): 52 | raise NotImplementedError( 53 | "The StructureTokenizer class is provided as a convenience for " 54 | "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" 55 | "Please use them instead." 56 | ) 57 | 58 | def decode(self, *args, **kwargs): 59 | raise NotImplementedError( 60 | "The StructureTokenizer class is provided as a convenience for " 61 | "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" 62 | "Please use them instead." 
63 | ) 64 | -------------------------------------------------------------------------------- /src/esm/tokenization/tokenizer_base.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | 4 | @runtime_checkable 5 | class EsmTokenizerBase(Protocol): 6 | def encode(self, *args, **kwargs): 7 | ... 8 | 9 | def decode(self, *args, **kwargs): 10 | ... 11 | 12 | @property 13 | def mask_token(self) -> str: 14 | ... 15 | 16 | @property 17 | def mask_token_id(self) -> int: 18 | ... 19 | 20 | @property 21 | def bos_token(self) -> str: 22 | ... 23 | 24 | @property 25 | def bos_token_id(self) -> int: 26 | ... 27 | 28 | @property 29 | def eos_token(self) -> str: 30 | ... 31 | 32 | @property 33 | def eos_token_id(self) -> int: 34 | ... 35 | 36 | @property 37 | def pad_token(self) -> str: 38 | ... 39 | 40 | @property 41 | def pad_token_id(self) -> int: 42 | ... 43 | -------------------------------------------------------------------------------- /src/esm/utils/constants/esm3.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | from pathlib import Path 3 | 4 | from huggingface_hub import snapshot_download 5 | 6 | SEQUENCE_BOS_TOKEN = 0 7 | SEQUENCE_PAD_TOKEN = 1 8 | SEQUENCE_EOS_TOKEN = 2 9 | SEQUENCE_CHAINBREAK_TOKEN = 31 10 | SEQUENCE_MASK_TOKEN = 32 11 | 12 | VQVAE_CODEBOOK_SIZE = 4096 13 | VQVAE_SPECIAL_TOKENS = { 14 | "MASK": VQVAE_CODEBOOK_SIZE, 15 | "EOS": VQVAE_CODEBOOK_SIZE + 1, 16 | "BOS": VQVAE_CODEBOOK_SIZE + 2, 17 | "PAD": VQVAE_CODEBOOK_SIZE + 3, 18 | "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4, 19 | } 20 | VQVAE_DIRECTION_LOSS_BINS = 16 21 | VQVAE_PAE_BINS = 64 22 | VQVAE_MAX_PAE_BIN = 31.0 23 | VQVAE_PLDDT_BINS = 50 24 | 25 | STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"] 26 | STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"] 27 | STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"] 28 | STRUCTURE_PAD_TOKEN = 
VQVAE_SPECIAL_TOKENS["PAD"] 29 | STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"] 30 | STRUCTURE_UNDEFINED_TOKEN = 955 31 | 32 | SASA_UNK_TOKEN = 2 33 | SASA_PAD_TOKEN = 0 34 | 35 | SS8_UNK_TOKEN = 2 36 | SS8_PAD_TOKEN = 0 37 | 38 | INTERPRO_PAD_TOKEN = 0 39 | 40 | RESIDUE_PAD_TOKEN = 0 41 | 42 | CHAIN_BREAK_STR = "|" 43 | 44 | SEQUENCE_BOS_STR = "" 45 | SEQUENCE_EOS_STR = "" 46 | 47 | MASK_STR_SHORT = "_" 48 | SEQUENCE_MASK_STR = "" 49 | SASA_MASK_STR = "" 50 | SS8_MASK_STR = "" 51 | 52 | # fmt: off 53 | SEQUENCE_VOCAB = [ 54 | "", "", "", "", 55 | "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", 56 | "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", 57 | "O", ".", "-", "|", 58 | "", 59 | ] 60 | # fmt: on 61 | 62 | SSE_8CLASS_VOCAB = "GHITEBSC" 63 | SSE_3CLASS_VOCAB = "HEC" 64 | SSE_8CLASS_TO_3CLASS_MAP = { 65 | "G": "H", 66 | "H": "H", 67 | "I": "H", 68 | "T": "C", 69 | "E": "E", 70 | "B": "E", 71 | "S": "C", 72 | "C": "C", 73 | } 74 | 75 | SASA_DISCRETIZATION_BOUNDARIES = [ 76 | 0.8, 77 | 4.0, 78 | 9.6, 79 | 16.4, 80 | 24.5, 81 | 32.9, 82 | 42.0, 83 | 51.5, 84 | 61.2, 85 | 70.9, 86 | 81.6, 87 | 93.3, 88 | 107.2, 89 | 125.4, 90 | 151.4, 91 | ] 92 | 93 | MAX_RESIDUE_ANNOTATIONS = 16 94 | 95 | 96 | TFIDF_VECTOR_SIZE = 58641 97 | 98 | 99 | @staticmethod 100 | @cache 101 | def data_root(): 102 | # Try a few default directories 103 | for path in [ 104 | "esm/data", 105 | "esm/data", 106 | ]: 107 | if (p := Path(path)).exists(): 108 | return p.parent 109 | # Try to download from hugginface if it doesn't exist 110 | path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1")) 111 | return path 112 | 113 | 114 | INTERPRO_ENTRY = "data/entry_list_safety_29026.list" 115 | INTERPRO_HIERARCHY = "data/ParentChildTreeFile.txt" 116 | INTERPRO2GO = "data/ParentChildTreeFile.txt" 117 | INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json" 118 | 119 | LSH_TABLE_PATHS = { 120 | "8bit": "data/hyperplanes_8bit_58641.npz", 121 | } 122 
| 123 | KEYWORDS_VOCABULARY = "data/keyword_vocabulary_safety_filtered_58641.txt" 124 | KEYWORDS_IDF = "data/keyword_idf_safety_filtered_58641.npy" 125 | 126 | RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv" 127 | INTERPRO2KEYWORDS = "data/interpro_29026_to_keywords_58641.csv" 128 | -------------------------------------------------------------------------------- /src/esm/utils/constants/models.py: -------------------------------------------------------------------------------- 1 | # Model names 2 | ESM3_OPEN_SMALL = "esm3_sm_open_v1" 3 | ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0" 4 | ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0" 5 | ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0" 6 | -------------------------------------------------------------------------------- /src/esm/utils/constants/physics.py: -------------------------------------------------------------------------------- 1 | BB_COORDINATES = [ 2 | [0.5256, 1.3612, 0.0000], 3 | [0.0000, 0.0000, 0.0000], 4 | [-1.5251, 0.0000, 0.0000], 5 | ] 6 | -------------------------------------------------------------------------------- /src/esm/utils/decoding.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import attr 4 | import torch 5 | 6 | from src.esm.models.function_decoder import FunctionTokenDecoder 7 | from src.esm.models.vqvae import StructureTokenDecoder 8 | from src.esm.sdk.api import ESMProtein, ESMProteinTensor 9 | from src.esm.tokenization import TokenizerCollectionProtocol 10 | from src.esm.tokenization.function_tokenizer import ( 11 | InterProQuantizedTokenizer, 12 | ) 13 | from src.esm.tokenization.residue_tokenizer import ( 14 | ResidueAnnotationsTokenizer, 15 | ) 16 | from src.esm.tokenization.sasa_tokenizer import ( 17 | SASADiscretizingTokenizer, 18 | ) 19 | from src.esm.tokenization.sequence_tokenizer import ( 20 | EsmSequenceTokenizer, 21 | ) 22 | from 
src.esm.tokenization.ss_tokenizer import ( 23 | SecondaryStructureTokenizer, 24 | ) 25 | from src.esm.tokenization.structure_tokenizer import ( 26 | StructureTokenizer, 27 | ) 28 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase 29 | from src.esm.utils.constants import esm3 as C 30 | from src.esm.utils.function.encode_decode import ( 31 | decode_function_tokens, 32 | decode_residue_annotation_tokens, 33 | ) 34 | from src.esm.utils.structure.protein_chain import ProteinChain 35 | from src.esm.utils.types import FunctionAnnotation 36 | 37 | 38 | def decode_protein_tensor( 39 | input: ESMProteinTensor, 40 | tokenizers: TokenizerCollectionProtocol, 41 | structure_token_decoder: StructureTokenDecoder, 42 | function_token_decoder: FunctionTokenDecoder, 43 | ) -> ESMProtein: 44 | input = attr.evolve(input) # Make a copy 45 | 46 | sequence = None 47 | secondary_structure = None 48 | sasa = None 49 | function_annotations = [] 50 | 51 | coordinates = None 52 | 53 | # If all pad tokens, set to None 54 | for track in attr.fields(ESMProteinTensor): 55 | tokens: torch.Tensor | None = getattr(input, track.name) 56 | if track.name == "coordinates": 57 | continue 58 | if tokens is not None: 59 | tokens = tokens[1:-1] # Remove BOS and EOS tokens 60 | tokens = tokens.flatten() # For multi-track tensors 61 | track_tokenizer = getattr(tokenizers, track.name) 62 | if torch.all(tokens == track_tokenizer.pad_token_id): 63 | setattr(input, track.name, None) 64 | 65 | if input.sequence is not None: 66 | sequence = decode_sequence(input.sequence, tokenizers.sequence) 67 | 68 | plddt, ptm = None, None 69 | if input.structure is not None: 70 | # Note: We give priority to the structure tokens over the coordinates when decoding 71 | coordinates, plddt, ptm = decode_structure( 72 | structure_tokens=input.structure, 73 | structure_decoder=structure_token_decoder, 74 | structure_tokenizer=tokenizers.structure, 75 | sequence=sequence, 76 | ) 77 | elif input.coordinates is not None: 
78 | coordinates = input.coordinates[1:-1, ...] 79 | 80 | if input.secondary_structure is not None: 81 | secondary_structure = decode_secondary_structure( 82 | input.secondary_structure, tokenizers.secondary_structure 83 | ) 84 | if input.sasa is not None: 85 | sasa = decode_sasa(input.sasa, tokenizers.sasa) 86 | if input.function is not None: 87 | function_track_annotations = decode_function_annotations( 88 | input.function, 89 | function_token_decoder=function_token_decoder, 90 | function_tokenizer=tokenizers.function, 91 | ) 92 | function_annotations.extend(function_track_annotations) 93 | if input.residue_annotations is not None: 94 | residue_annotations = decode_residue_annotations( 95 | input.residue_annotations, tokenizers.residue_annotations 96 | ) 97 | function_annotations.extend(residue_annotations) 98 | 99 | return ESMProtein( 100 | sequence=sequence, 101 | secondary_structure=secondary_structure, 102 | sasa=sasa, # type: ignore 103 | function_annotations=function_annotations if function_annotations else None, 104 | coordinates=coordinates, 105 | plddt=plddt, 106 | ptm=ptm, 107 | ) 108 | 109 | 110 | def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase): 111 | if tensor[0] != tok.bos_token_id: 112 | warnings.warn( 113 | f"{msg} does not start with BOS token, token is ignored. BOS={tok.bos_token_id} vs {tensor}" 114 | ) 115 | if tensor[-1] != tok.eos_token_id: 116 | warnings.warn( 117 | f"{msg} does not end with EOS token, token is ignored. 
EOS='{tok.eos_token_id}': {tensor}" 118 | ) 119 | 120 | 121 | def decode_sequence( 122 | sequence_tokens: torch.Tensor, 123 | sequence_tokenizer: EsmSequenceTokenizer, 124 | **kwargs, 125 | ) -> str: 126 | _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer) 127 | sequence = sequence_tokenizer.decode( 128 | sequence_tokens, 129 | **kwargs, 130 | ) 131 | sequence = sequence.replace(" ", "") 132 | sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT) 133 | sequence = sequence.replace(sequence_tokenizer.cls_token, "") 134 | sequence = sequence.replace(sequence_tokenizer.eos_token, "") 135 | 136 | return sequence 137 | 138 | 139 | def decode_structure( 140 | structure_tokens: torch.Tensor, 141 | structure_decoder: StructureTokenDecoder, 142 | structure_tokenizer: StructureTokenizer, 143 | sequence: str | None = None, 144 | ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: 145 | is_singleton = len(structure_tokens.size()) == 1 146 | if is_singleton: 147 | structure_tokens = structure_tokens.unsqueeze(0) 148 | else: 149 | raise ValueError( 150 | f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}" 151 | ) 152 | _bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer) 153 | 154 | decoder_output = structure_decoder.decode(structure_tokens) 155 | bb_coords: torch.Tensor = decoder_output["bb_pred"][ 156 | 0, 1:-1, ... 
157 | ] # Remove BOS and EOS tokens 158 | bb_coords = bb_coords.detach().cpu() 159 | 160 | if "plddt" in decoder_output: 161 | plddt = decoder_output["plddt"][0, 1:-1] 162 | plddt = plddt.detach().cpu() 163 | else: 164 | plddt = None 165 | 166 | if "ptm" in decoder_output: 167 | ptm = decoder_output["ptm"] 168 | else: 169 | ptm = None 170 | 171 | chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence) 172 | chain = chain.infer_oxygen() 173 | return torch.tensor(chain.atom37_positions), plddt, ptm 174 | 175 | 176 | def decode_secondary_structure( 177 | secondary_structure_tokens: torch.Tensor, 178 | ss_tokenizer: SecondaryStructureTokenizer, 179 | ) -> str: 180 | _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer) 181 | secondary_structure_tokens = secondary_structure_tokens[1:-1] 182 | secondary_structure = ss_tokenizer.decode( 183 | secondary_structure_tokens, 184 | ) 185 | return secondary_structure 186 | 187 | 188 | def decode_sasa( 189 | sasa_tokens: torch.Tensor, 190 | sasa_tokenizer: SASADiscretizingTokenizer, 191 | ) -> list[float]: 192 | _bos_eos_warn("SASA", sasa_tokens, sasa_tokenizer) 193 | sasa_tokens = sasa_tokens[1:-1] 194 | 195 | return sasa_tokenizer.decode_float(sasa_tokens) 196 | 197 | 198 | def decode_function_annotations( 199 | function_annotation_tokens: torch.Tensor, 200 | function_token_decoder: FunctionTokenDecoder, 201 | function_tokenizer: InterProQuantizedTokenizer, 202 | **kwargs, 203 | ) -> list[FunctionAnnotation]: 204 | # No need to check for BOS/EOS as function annotations are not affected 205 | 206 | function_annotations = decode_function_tokens( 207 | function_annotation_tokens, 208 | function_token_decoder=function_token_decoder, 209 | function_tokens_tokenizer=function_tokenizer, 210 | **kwargs, 211 | ) 212 | return function_annotations 213 | 214 | 215 | def decode_residue_annotations( 216 | residue_annotation_tokens: torch.Tensor, 217 | residue_annotation_decoder: 
ResidueAnnotationsTokenizer, 218 | ) -> list[FunctionAnnotation]: 219 | # No need to check for BOS/EOS as function annotations are not affected 220 | 221 | residue_annotations = decode_residue_annotation_tokens( 222 | residue_annotations_token_ids=residue_annotation_tokens, 223 | residue_annotations_tokenizer=residue_annotation_decoder, 224 | ) 225 | return residue_annotations 226 | -------------------------------------------------------------------------------- /src/esm/utils/encoding.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from src.esm.models.vqvae import StructureTokenEncoder 7 | from src.esm.tokenization.function_tokenizer import ( 8 | InterProQuantizedTokenizer as EsmFunctionTokenizer, 9 | ) 10 | from src.esm.tokenization.residue_tokenizer import ( 11 | ResidueAnnotationsTokenizer, 12 | ) 13 | from src.esm.tokenization.sasa_tokenizer import ( 14 | SASADiscretizingTokenizer, 15 | ) 16 | from src.esm.tokenization.sequence_tokenizer import ( 17 | EsmSequenceTokenizer, 18 | ) 19 | from src.esm.tokenization.ss_tokenizer import ( 20 | SecondaryStructureTokenizer, 21 | ) 22 | from src.esm.tokenization.structure_tokenizer import ( 23 | StructureTokenizer, 24 | ) 25 | from src.esm.utils.constants import esm3 as C 26 | from src.esm.utils.function.encode_decode import ( 27 | encode_function_annotations, 28 | ) 29 | from src.esm.utils.structure.protein_chain import ProteinChain 30 | from src.esm.utils.types import FunctionAnnotation 31 | 32 | 33 | # Raw Defaults 34 | def get_default_sequence(sequence_length: int) -> str: 35 | return C.MASK_STR_SHORT * sequence_length 36 | 37 | 38 | def get_default_secondary_structure(sequence_length: int) -> str: 39 | return C.MASK_STR_SHORT * sequence_length 40 | 41 | 42 | def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]: 43 | return [None] * sequence_length 44 | 45 | 46 | 
# Tokenization 47 | def tokenize_sequence( 48 | sequence: str, 49 | sequence_tokenizer: EsmSequenceTokenizer, 50 | add_special_tokens: bool = True, 51 | ) -> torch.Tensor: 52 | sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token) 53 | sequence_tokens = sequence_tokenizer.encode( 54 | sequence, add_special_tokens=add_special_tokens 55 | ) 56 | sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64) 57 | return sequence_tokens 58 | 59 | 60 | def tokenize_structure( 61 | coordinates: torch.Tensor, 62 | structure_encoder: StructureTokenEncoder, 63 | structure_tokenizer: StructureTokenizer, 64 | reference_sequence: str = "", 65 | add_special_tokens: bool = True, 66 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 67 | device = next(structure_encoder.parameters()).device 68 | chain = ProteinChain.from_atom37( 69 | coordinates, sequence=reference_sequence if reference_sequence else None 70 | ) 71 | 72 | # Setup padding 73 | if reference_sequence and len(reference_sequence) != coordinates.size(0): 74 | raise ValueError( 75 | f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})" 76 | ) 77 | 78 | left_pad = 0 79 | right_pad = 0 80 | 81 | if add_special_tokens: 82 | left_pad += 1 # Add space for BOS token 83 | right_pad += 1 # Add space for EOS token 84 | 85 | coordinates, plddt, residue_index = chain.to_structure_encoder_inputs() 86 | coordinates = coordinates.to(device) # (1, L, 37, 3) 87 | plddt = plddt.to(device) # (1, L) 88 | residue_index = residue_index.to(device) # (1, L) 89 | _, structure_tokens = structure_encoder.encode( 90 | coordinates, residue_index=residue_index 91 | ) 92 | coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore 93 | plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore 94 | structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore 95 | 96 | # Add space for BOS and EOS tokens 
97 | if add_special_tokens: 98 | coordinates = F.pad( 99 | coordinates, 100 | (0, 0, 0, 0, left_pad, right_pad), 101 | value=torch.inf, 102 | ) 103 | plddt = F.pad(plddt, (left_pad, right_pad), value=0) 104 | structure_tokens = F.pad( 105 | structure_tokens, 106 | (left_pad, right_pad), 107 | value=structure_tokenizer.pad_token_id, 108 | ) 109 | structure_tokens[0] = structure_tokenizer.bos_token_id 110 | structure_tokens[-1] = structure_tokenizer.eos_token_id 111 | return coordinates, plddt, structure_tokens 112 | 113 | 114 | def tokenize_secondary_structure( 115 | secondary_structure: str | Sequence[str], 116 | secondary_structure_tokenizer: SecondaryStructureTokenizer, 117 | add_special_tokens: bool = True, 118 | ) -> torch.Tensor: 119 | if isinstance(secondary_structure, str): 120 | # Ensure only one char per token 121 | secondary_structure = secondary_structure.replace( 122 | secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT 123 | ) 124 | 125 | # Input as list of chars 126 | secondary_structure = [char for char in secondary_structure] 127 | 128 | # Use tokenizer's mask token 129 | secondary_structure = [ 130 | secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char 131 | for char in secondary_structure 132 | ] 133 | 134 | secondary_structure_tokens = secondary_structure_tokenizer.encode( 135 | secondary_structure, add_special_tokens=add_special_tokens 136 | ) 137 | return secondary_structure_tokens 138 | 139 | 140 | def tokenize_sasa( 141 | sasa: Sequence[float | str | None], 142 | sasa_tokenizer: SASADiscretizingTokenizer, 143 | add_special_tokens: bool = True, 144 | ): 145 | sasa_tokens = sasa_tokenizer.encode( 146 | [sasa_tokenizer.mask_token if value is None else value for value in sasa], 147 | add_special_tokens=add_special_tokens, 148 | ) 149 | return sasa_tokens 150 | 151 | 152 | def tokenize_function_annotations( 153 | function_annotations: Sequence[FunctionAnnotation], 154 | reference_sequence: str, 155 | 
function_tokenizer: EsmFunctionTokenizer, 156 | residue_annotation_tokenizer: ResidueAnnotationsTokenizer, 157 | add_special_tokens: bool = True, 158 | ) -> tuple[torch.Tensor, torch.Tensor]: 159 | function_tokens, residue_annotation_tokens = encode_function_annotations( 160 | sequence=reference_sequence, 161 | function_annotations=function_annotations, 162 | function_tokens_tokenizer=function_tokenizer, 163 | residue_annotations_tokenizer=residue_annotation_tokenizer, 164 | add_special_tokens=add_special_tokens, 165 | ) 166 | return function_tokens, residue_annotation_tokens 167 | 168 | 169 | # Tokenized Defaults 170 | def get_default_sequence_tokens( 171 | sequence_length: int, 172 | sequence_tokenizer: EsmSequenceTokenizer, 173 | ) -> torch.Tensor: 174 | return tokenize_sequence( 175 | get_default_sequence(sequence_length), 176 | sequence_tokenizer, 177 | add_special_tokens=True, 178 | ) 179 | 180 | 181 | def get_default_structure_tokens( 182 | sequence_length: int, structure_tokenizer: StructureTokenizer 183 | ) -> torch.Tensor: 184 | structure_tokens = ( 185 | torch.ones( 186 | (sequence_length + 2,), 187 | dtype=torch.int64, 188 | ) 189 | * structure_tokenizer.pad_token_id 190 | ) 191 | # Always include BOS and EOS tokens 192 | structure_tokens[0] = structure_tokenizer.bos_token_id 193 | structure_tokens[-1] = structure_tokenizer.eos_token_id 194 | return structure_tokens 195 | 196 | 197 | def get_default_secondary_structure_tokens( 198 | sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer 199 | ) -> torch.Tensor: 200 | return tokenize_secondary_structure( 201 | get_default_secondary_structure(sequence_length), 202 | secondary_structure_tokenizer, 203 | add_special_tokens=True, 204 | ) 205 | 206 | 207 | def get_default_sasa_tokens( 208 | sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer 209 | ) -> torch.Tensor: 210 | return tokenize_sasa( 211 | get_default_sasa(sequence_length), sasa_tokenizer, 
add_special_tokens=True 212 | ) 213 | 214 | 215 | def get_default_function_tokens( 216 | sequence_length: int, function_tokenizer: EsmFunctionTokenizer 217 | ) -> torch.Tensor: 218 | function_tokens = ( 219 | torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64) 220 | * function_tokenizer.pad_token_id 221 | ) 222 | # Always include BOS and EOS tokens 223 | function_tokens[0] = function_tokenizer.bos_token_id 224 | function_tokens[-1] = function_tokenizer.eos_token_id 225 | return function_tokens 226 | 227 | 228 | def get_default_residue_annotation_tokens( 229 | sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer 230 | ) -> torch.Tensor: 231 | residue_annotation_tokens = ( 232 | torch.ones( 233 | (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), 234 | dtype=torch.int64, 235 | ) 236 | * residue_annotation_tokenizer.pad_token_id 237 | ) 238 | # Always include BOS and EOS tokens 239 | residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id 240 | residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id 241 | return residue_annotation_tokens 242 | -------------------------------------------------------------------------------- /src/esm/utils/function/encode_decode.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Sequence 3 | 4 | import torch 5 | 6 | from src.esm.models.function_decoder import ( 7 | FunctionTokenDecoder, 8 | _merge_annotations, 9 | ) 10 | from src.esm.tokenization.function_tokenizer import ( 11 | InterProQuantizedTokenizer, 12 | ) 13 | from src.esm.tokenization.residue_tokenizer import ( 14 | ResidueAnnotationsTokenizer, 15 | ) 16 | from src.esm.utils.constants import esm3 as C 17 | from src.esm.utils.types import FunctionAnnotation 18 | 19 | 20 | def encode_function_annotations( 21 | sequence: str, 22 | function_annotations: Sequence[FunctionAnnotation], 23 | function_tokens_tokenizer: 
InterProQuantizedTokenizer, 24 | residue_annotations_tokenizer: ResidueAnnotationsTokenizer, 25 | add_special_tokens: bool = True, 26 | ) -> tuple[torch.Tensor, torch.Tensor]: 27 | assert isinstance( 28 | residue_annotations_tokenizer, ResidueAnnotationsTokenizer 29 | ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer" 30 | 31 | # Split the user's annotations by type 32 | ft_annotations: list[FunctionAnnotation] = [] 33 | ra_annotations: list[FunctionAnnotation] = [] 34 | for fa in function_annotations: 35 | assert ( 36 | 1 <= fa.start <= fa.end <= len(sequence) 37 | ), f"Invalid (start, end) in function annotation {fa}. Indices 1-indexed and [inclusive, inclusive]" 38 | 39 | supported_label = False 40 | 41 | # Is it an InterPro label? 42 | if match := re.match(r"IPR\d+", fa.label): 43 | if match.group() in function_tokens_tokenizer.interpro_to_index: 44 | ft_annotations.append(fa) 45 | supported_label = True 46 | 47 | # Is it a function keyword? 48 | if fa.label in function_tokens_tokenizer._tfidf.vocab_to_index: 49 | ft_annotations.append(fa) 50 | supported_label = True 51 | 52 | # Is it a residue annotation? 
53 | if fa.label in residue_annotations_tokenizer._labels: 54 | ra_annotations.append(fa) 55 | supported_label = True 56 | 57 | if not supported_label: 58 | raise ValueError(f"Unknown label in FunctionAnnotation: {fa.label}") 59 | 60 | # Convert function token FunctionAnnotations -> Tensor 61 | function_tokens = function_tokens_tokenizer.tokenize( 62 | annotations=ft_annotations, 63 | seqlen=len(sequence), 64 | ) 65 | function_token_ids = function_tokens_tokenizer.encode( 66 | function_tokens, add_special_tokens=add_special_tokens 67 | ) 68 | 69 | # Convert residue annotation FunctionAnnotations -> Tensor 70 | if ra_annotations: 71 | descriptions, starts, ends = zip( 72 | *[(anot.label, anot.start, anot.end) for anot in ra_annotations] 73 | ) 74 | else: 75 | descriptions = starts = ends = None 76 | ra_tokens = residue_annotations_tokenizer.tokenize( 77 | { 78 | "interpro_site_descriptions": descriptions, 79 | "interpro_site_starts": starts, 80 | "interpro_site_ends": ends, 81 | }, 82 | sequence=sequence, 83 | fail_on_mismatch=True, 84 | ) 85 | residue_annotation_ids = residue_annotations_tokenizer.encode( 86 | ra_tokens, add_special_tokens=add_special_tokens 87 | ) 88 | 89 | return function_token_ids, residue_annotation_ids 90 | 91 | 92 | def decode_function_tokens( 93 | function_token_ids: torch.Tensor, 94 | function_token_decoder: FunctionTokenDecoder, 95 | function_tokens_tokenizer: InterProQuantizedTokenizer, 96 | decoder_annotation_threshold: float = 0.1, 97 | annotation_min_length: int | None = 5, 98 | annotation_gap_merge_max: int | None = 3, 99 | ) -> list[FunctionAnnotation]: 100 | """Decodes model prediction logits into function predictions. 101 | 102 | Merges function token and residue annotation predictions into a single 103 | set of FunctionAnnotation predictions. 104 | 105 | Args: 106 | function_token_ids: Tensor [length, depth] of 107 | function token ids. 
108 | residue_annotation_logits: Tensor [length, RA-vocab] of residue 109 | annotation binary classification logits. 110 | function_tokens_tokenizer: InterPro annotation tokenizer. 111 | residue_annotation_threshold: tokenizer of residue annotations. 112 | residue_annotation_threshold: predicted probability threshold for emitting 113 | a predicted residue annotation. 114 | Returns: 115 | Predicted function annotations merged from both predictions. 116 | """ 117 | assert ( 118 | function_token_ids.ndim == 2 119 | ), "function_token_ids must be of shape (length, depth)" 120 | 121 | annotations: list[FunctionAnnotation] = [] 122 | 123 | # Function Annotations from predicted function tokens. 124 | decoded = function_token_decoder.decode( 125 | function_token_ids, 126 | tokenizer=function_tokens_tokenizer, 127 | annotation_threshold=decoder_annotation_threshold, 128 | annotation_min_length=annotation_min_length, 129 | annotation_gap_merge_max=annotation_gap_merge_max, 130 | ) 131 | 132 | # Convert predicted InterPro annotation to FunctionAnnotation. 133 | annotations.extend(decoded["function_keywords"]) 134 | for annotation in decoded["interpro_annotations"]: 135 | annotation: FunctionAnnotation 136 | label = function_tokens_tokenizer.format_annotation(annotation) 137 | annotations.append( 138 | FunctionAnnotation(label=label, start=annotation.start, end=annotation.end) 139 | ) 140 | 141 | return annotations 142 | 143 | 144 | def decode_residue_annotation_tokens( 145 | residue_annotations_token_ids: torch.Tensor, 146 | residue_annotations_tokenizer: ResidueAnnotationsTokenizer, 147 | annotation_min_length: int | None = 5, 148 | annotation_gap_merge_max: int | None = 3, 149 | ) -> list[FunctionAnnotation]: 150 | """Decodes residue annotation tokens into FunctionAnnotations. 151 | 152 | Args: 153 | tokens: Tensor [length, MAX_RESIDUE_ANNOTATIONS] of residue annotation tokens. 154 | residue_annotations_tokenizer: Tokenizer of residue annotations. 
155 | threshold: predicted probability threshold for emitting a predicted residue 156 | annotation. 157 | Returns: 158 | Predicted residue annotations. 159 | """ 160 | assert ( 161 | residue_annotations_token_ids.ndim == 2 162 | ), "logits must be of shape (length, MAX_RESIDUE_ANNOTATIONS)" 163 | 164 | annotations: list[FunctionAnnotation] = [] 165 | 166 | for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS): 167 | token_ids = residue_annotations_token_ids[:, depth] 168 | for loc, vocab_index in torch.nonzero(token_ids).cpu().numpy(): 169 | label = residue_annotations_tokenizer.vocabulary[vocab_index] 170 | if label not in [*residue_annotations_tokenizer.special_tokens, ""]: 171 | annotation = FunctionAnnotation(label=label, start=loc, end=loc) 172 | annotations.append(annotation) 173 | 174 | annotations = _merge_annotations( 175 | annotations, 176 | merge_gap_max=annotation_gap_merge_max, 177 | ) 178 | 179 | # Drop very small annotations. 180 | if annotation_min_length is not None: 181 | annotations = [ 182 | annotation 183 | for annotation in annotations 184 | if annotation.end - annotation.start + 1 >= annotation_min_length 185 | ] 186 | 187 | return annotations 188 | -------------------------------------------------------------------------------- /src/esm/utils/function/interpro.py: -------------------------------------------------------------------------------- 1 | """Utilities for interacting with InterPro.""" 2 | 3 | import itertools 4 | import re 5 | from dataclasses import dataclass 6 | from enum import IntEnum, auto 7 | from functools import cached_property 8 | from pathlib import Path 9 | 10 | import networkx as nx 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from src.esm.utils.constants import esm3 as C 15 | 16 | 17 | def parse_go_terms(text: str) -> list[str]: 18 | """Parses GO terms from a string. 19 | 20 | Args: 21 | text: String containing GO terms. Example: "GO:0008309, GO:1902267" Note that GO 22 | terms have exactly 7 digits. 
23 | Returns: 24 | All GO terms found in the string. Example: ['GO:0008309', 'GO:1902267'] 25 | """ 26 | return re.findall(r"GO:(?:\d{7,})", text) 27 | 28 | 29 | def _parse_interpro2go(path: str) -> dict[str, list[str]]: 30 | """Parses InterPro2GO file into map. 31 | 32 | NOTE: this file has a very strange, non-standard format. 33 | 34 | Args: 35 | path: path to InterPro2GO file from: https://www.ebi.ac.uk/GOA/InterPro2GO 36 | Returns: 37 | Mapping from InterPro to list of associated GO terms. 38 | """ 39 | with Path(path).open("r") as f: 40 | text = f.read() 41 | df = pd.Series(text.split("\n"), name="line").to_frame() 42 | df = df[~df.line.str.startswith("!")] 43 | df["interpro_id"] = df.line.apply(lambda line: re.findall(r"IPR\d+", line)) 44 | df["go_ids"] = df.line.apply(parse_go_terms) 45 | df = df[df.go_ids.apply(len).gt(0) & df.interpro_id.apply(len).eq(1)] 46 | df["interpro_id"] = df["interpro_id"].apply(lambda xs: xs[0]) # type: ignore 47 | 48 | # Group all mappints together into a single map. 
49 | df = ( 50 | df.groupby("interpro_id")["go_ids"] # type: ignore 51 | .apply(lambda group: list(itertools.chain.from_iterable(group))) 52 | .reset_index() 53 | ) 54 | return dict(zip(df.interpro_id, df.go_ids)) # type: ignore 55 | 56 | 57 | class InterProEntryType(IntEnum): 58 | """InterPro types and representation counts: 59 | 60 | Family 21,942 61 | Domain 14,053 62 | Homologous_superfamily 3,446 63 | Conserved_site 728 64 | Repeat 374 65 | Active_site 133 66 | Binding_site 75 67 | PTM 17 68 | """ 69 | 70 | ACTIVE_SITE = 0 71 | BINDING_SITE = auto() 72 | CONSERVED_SITE = auto() 73 | DOMAIN = auto() 74 | FAMILY = auto() 75 | HOMOLOGOUS_SUPERFAMILY = auto() 76 | PTM = auto() 77 | REPEAT = auto() 78 | UNKNOWN = auto() 79 | 80 | 81 | @dataclass 82 | class InterProEntry: 83 | """Represents an InterPro entry.""" 84 | 85 | id: str # Example: IPR000006 86 | type: InterProEntryType 87 | name: str # Example: "Metallothionein, vertebrate" 88 | description: str | None = None 89 | 90 | 91 | @dataclass(frozen=True) 92 | class InterProRangeAnnotation: 93 | """Represents a InterPro annotation along a range of residues in a protein.""" 94 | 95 | interpro_accession: str 96 | start_idx: int 97 | end_idx: int 98 | 99 | 100 | class InterPro: 101 | """Convenience class interacting with InterPro ontology/data.""" 102 | 103 | def __init__( 104 | self, 105 | entries_path: str | None = None, 106 | hierarchy_path: str | None = None, 107 | interpro2go_path: str | None = None, 108 | ): 109 | """Constructs interface to query InterPro entries.""" 110 | default = lambda x, d: x if x is not None else d 111 | self.entries_path = default(entries_path, str(C.data_root() / C.INTERPRO_ENTRY)) 112 | self.hierarchy_graph_path = default( 113 | hierarchy_path, str(C.data_root() / C.INTERPRO_HIERARCHY) 114 | ) 115 | self.interpro2go_path = default( 116 | interpro2go_path, str(C.data_root() / C.INTERPRO2GO) 117 | ) 118 | 119 | @cached_property 120 | def interpro2go(self) -> dict[str, list[str]]: 121 | 
"""Reads the InterPro to GO term mapping.""" 122 | assert self.interpro2go_path is not None 123 | return _parse_interpro2go(self.interpro2go_path) 124 | 125 | @cached_property 126 | def entries_frame(self) -> pd.DataFrame: 127 | """Loads full InterPro entry set as a DataFrame. 128 | 129 | Colums are 130 | - "id": str interpro accession /id as 131 | - "type": InterProEntryType representing the type of annotation. 132 | - "name": Short name of the entry. 133 | """ 134 | with Path(self.entries_path).open("r") as f: 135 | df = pd.read_csv(f, sep="\t") 136 | assert all( 137 | col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"] 138 | ) 139 | df.rename( 140 | columns={ 141 | "ENTRY_AC": "id", 142 | "ENTRY_TYPE": "type", 143 | "ENTRY_NAME": "name", 144 | }, 145 | inplace=True, 146 | ) 147 | df["type"] = df.type.str.upper().apply( 148 | lambda type_name: InterProEntryType[type_name] 149 | ) 150 | return df 151 | 152 | @cached_property 153 | def entries(self) -> dict[str, InterProEntry]: 154 | """Returns all InterPro entries.""" 155 | return { 156 | row.id: InterProEntry( # type: ignore 157 | id=row.id, # type: ignore 158 | type=row.type, # type: ignore 159 | name=row.name, # type: ignore 160 | ) 161 | for row in self.entries_frame.itertuples() 162 | } 163 | 164 | def lookup_name(self, interpro_id: str) -> str | None: 165 | """Short name / title for an interpro id.""" 166 | if interpro_id not in self.entries: 167 | return None 168 | return self.entries[interpro_id].name 169 | 170 | def lookup_entry_type(self, interpro_id: str) -> InterProEntryType: 171 | """Looks up entry-type for an interpro id.""" 172 | if interpro_id in self.entries: 173 | return self.entries[interpro_id].type 174 | else: 175 | return InterProEntryType.UNKNOWN 176 | 177 | @cached_property 178 | def graph(self) -> nx.DiGraph: 179 | """Reads the InterPro hierarchy of InterPro.""" 180 | graph = nx.DiGraph() 181 | with Path(self.hierarchy_graph_path).open("r") as f: 182 | parents = [] 183 | 
for line in f: 184 | ipr = line.split("::", maxsplit=1)[0] 185 | ipr_strip = ipr.lstrip("-") 186 | level = (len(ipr) - len(ipr_strip)) // 2 187 | parents = parents[:level] 188 | graph.add_node(ipr_strip) 189 | if parents: 190 | graph.add_edge(ipr_strip, parents[-1]) 191 | parents.append(ipr_strip) 192 | return graph 193 | 194 | 195 | def parse_interpro_features( 196 | interpro_accessions: list[str], 197 | interpro_starts: list[int], 198 | interpro_ends: list[int], 199 | ) -> list[InterProRangeAnnotation]: 200 | """Parses raw InterPro ranges. 201 | 202 | Args: 203 | interpro_accessions: list of InterPro accessions 204 | interpro_starts: list of one-indexed inclusive residue locations where the 205 | annotation from `interpro_accesisons` begin. 206 | interpro_ends: list of one-indexed *inclusive* residue locations where the 207 | annotation from `interpro_accesisons` end. 208 | Returns: 209 | Collated InterProRangeAnnotations. NOTE that index conversion will convert range 210 | bounds to zero-indexed [inclusive, exclusive) start/end indices. 211 | """ 212 | assert len(interpro_accessions) == len(interpro_starts) == len(interpro_ends) 213 | 214 | # Residue locations from Uniprot/InterPro are [inclusive, inclusive] and 1-index. 215 | start_idcs = np.array(interpro_starts).astype(int) 216 | end_idcs = np.array(interpro_ends).astype(int) 217 | 218 | # We want to use Python's convention of [inclusive, exclusive) and 0-indexing. 219 | # Interpro residue indices are [inclusive, inclusive] and 1-indexing. 220 | # The conversion ends up being: 221 | # ```python 222 | # end_idcs += 1 # [inclusive, inclusive] -> [inclusive, exclusive) 223 | # start_idcs -= 1 # 1 -> 0 indexing 224 | # end_idcs -= 1 # 1 -> 0 indexing 225 | # ``` 226 | # Which simply results in: 227 | start_idcs -= 1 228 | 229 | ranges = [] 230 | for interpro_accession, start_idx, end_idx in zip( 231 | interpro_accessions, start_idcs, end_idcs 232 | ): 233 | # NOTE: Skip unintegrated Interpro labels, for now. 
234 | if interpro_accession == "-": 235 | continue 236 | 237 | ranges.append( 238 | InterProRangeAnnotation( 239 | interpro_accession=interpro_accession, 240 | start_idx=start_idx, 241 | end_idx=end_idx, 242 | ) 243 | ) 244 | 245 | return ranges 246 | -------------------------------------------------------------------------------- /src/esm/utils/function/lsh.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | from src.esm.utils.types import PathLike 6 | 7 | 8 | class LSHTable: 9 | def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): 10 | if hyperplanes is None: 11 | hyperplanes = np.random.randn(n_bits, dim) 12 | hyperplanes = hyperplanes / np.linalg.norm( 13 | hyperplanes, axis=-1, keepdims=True 14 | ) 15 | else: 16 | assert hyperplanes.shape == (n_bits, dim), ( 17 | hyperplanes.shape, 18 | (n_bits, dim), 19 | ) 20 | assert hyperplanes is not None 21 | self.hyperplanes: np.ndarray = hyperplanes 22 | self.values = 1 << np.arange(n_bits) 23 | 24 | def __call__(self, array, tokenize: bool = True): 25 | similarity = self.hyperplanes @ array.T 26 | bits = np.where(similarity >= 0, 1, 0) 27 | if tokenize: 28 | tokens = bits.T @ self.values 29 | return tokens 30 | else: 31 | return bits.T 32 | 33 | 34 | class LSHTokenized: 35 | def __init__( 36 | self, 37 | n_bits: int, 38 | dim: int, 39 | num_tables: int = 1, 40 | filepath: PathLike | None = None, 41 | allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes 42 | ): 43 | table_hyperplanes = None 44 | if filepath is not None: 45 | filepath = Path(filepath) 46 | if not filepath.exists(): 47 | raise FileNotFoundError(filepath) 48 | table_hyperplanes = np.load(filepath) # type: ignore 49 | for i in range(num_tables): 50 | assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" 51 | elif not allow_create_hyperplanes: 52 | raise 
RuntimeError( 53 | "Not allowed to create hyperplanes but no filepath provided" 54 | ) 55 | 56 | self.tables = [ 57 | LSHTable( 58 | n_bits, 59 | dim, 60 | table_hyperplanes[str(i)] if table_hyperplanes is not None else None, 61 | ) 62 | for i in range(num_tables) 63 | ] 64 | 65 | def write_hyperplanes(self, filepath: PathLike): 66 | hyperplanes: dict[str, np.ndarray] = { # type: ignore 67 | str(i): table.hyperplanes for i, table in enumerate(self.tables) 68 | } 69 | np.savez(filepath, **hyperplanes) 70 | 71 | def __call__(self, array): 72 | tokens = np.stack([table(array) for table in self.tables], 1) 73 | return tokens 74 | 75 | 76 | class LSHBitstream: 77 | def __init__( 78 | self, 79 | n_bits: int, 80 | dim: int, 81 | filepath: PathLike | None = None, 82 | allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes 83 | ): 84 | table_hyperplanes = None 85 | if filepath is not None: 86 | filepath = Path(filepath) 87 | if not filepath.exists(): 88 | raise FileNotFoundError(filepath) 89 | table_hyperplanes = np.load(filepath) 90 | elif not allow_create_hyperplanes: 91 | raise RuntimeError( 92 | "Not allowed to create hyperplanes but no filepath provided" 93 | ) 94 | 95 | self.table = LSHTable( 96 | n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None 97 | ) 98 | 99 | def write_hyperplanes(self, filepath: PathLike): 100 | np.save(filepath, self.table.hyperplanes) 101 | 102 | def __call__(self, array): 103 | return self.table(array, tokenize=False) 104 | -------------------------------------------------------------------------------- /src/esm/utils/function/tfidf.py: -------------------------------------------------------------------------------- 1 | """Term-Frequency / Inverse Document Frequency (TF-IDF) model.""" 2 | 3 | from collections import Counter 4 | from functools import cached_property 5 | 6 | import numpy as np 7 | from scipy import sparse 8 | 9 | 10 | class TFIDFModel: 11 | 
"""Term-Frequency / Inverse Document Frequency (TF-IDF) model. 12 | Mimics sklearn.feature_extraction.text.TfidfVectorizer with sublinear_tf=True 13 | """ 14 | 15 | def __init__(self, vocabulary_path: str, idf_path: str): 16 | with open(vocabulary_path, "r") as f: 17 | self.vocabulary = f.read().strip().split("\n") 18 | 19 | with open(idf_path, "rb") as f: 20 | self.idf_ = np.load(f) 21 | 22 | assert self.idf_.ndim == 1 23 | assert ( 24 | len(self.idf_) == len(self.vocabulary) 25 | ), f"IDF size must match vocabulary size, got {len(self.idf_)} and {len(self.vocabulary)}" 26 | 27 | @cached_property 28 | def vocab_to_index(self) -> dict[str, int]: 29 | return {term: index for index, term in enumerate(self.vocabulary)} 30 | 31 | def encode(self, terms: list[str]) -> sparse.csr_matrix: 32 | """Encodes terms as TF-IDF vectors. 33 | 34 | Args: 35 | terms: list of terms to encode. 36 | 37 | Returns: 38 | TF-IDF vector encoded as sparse matrix of shape (1, num_terms) 39 | """ 40 | counter = Counter(filter(self.vocabulary.__contains__, terms)) 41 | indices = [self.vocab_to_index[term] for term in counter] 42 | 43 | tf = np.array([count for term, count in counter.items()]) 44 | idf = np.take(self.idf_, indices) 45 | 46 | values = (1 + np.log(tf)) * idf 47 | values /= np.linalg.norm(values) 48 | 49 | return sparse.csr_matrix( 50 | (values, (np.zeros_like(indices), indices)), 51 | shape=(1, len(self.vocabulary)), 52 | ) 53 | 54 | def decode(self, vec: sparse.csr_matrix) -> list[str]: 55 | """Extract terms from TF-IDF.""" 56 | return [self.vocabulary[i] for i in vec.indices] 57 | -------------------------------------------------------------------------------- /src/esm/utils/generation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import attr 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from src.esm.sdk.api import ( 8 | ESM3InferenceClient, 9 | ESMProtein, 10 | ESMProteinTensor, 11 | 
GenerationConfig, 12 | SamplingConfig, 13 | SamplingTrackConfig, 14 | ) 15 | from src.esm.tokenization import ( 16 | EsmTokenizerBase, 17 | TokenizerCollectionProtocol, 18 | ) 19 | from src.esm.utils.constants import esm3 as C 20 | from src.esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY 21 | 22 | 23 | def iterative_sampling_raw( 24 | client: ESM3InferenceClient, 25 | input: ESMProtein, 26 | config: GenerationConfig, 27 | ): 28 | # Keep structure tokens 29 | input_tokens = client.encode(input) 30 | 31 | output_tokens = client.generate(input_tokens, config) 32 | 33 | raw_protein = client.decode(output_tokens) 34 | 35 | track_to_sample = config.track 36 | 37 | if track_to_sample not in ["function", "residue_annotations"]: 38 | # Function and residue annotation encoding/decoding is lossy 39 | # There is no guarantee that decoding encoded tokens will yield the same input 40 | raw_protein.function_annotations = input.function_annotations 41 | 42 | return raw_protein 43 | 44 | 45 | def iterative_sampling_tokens( 46 | client: ESM3InferenceClient, 47 | input_tokens: ESMProteinTensor, 48 | config: GenerationConfig, 49 | tokenizers: TokenizerCollectionProtocol, 50 | ) -> ESMProteinTensor: 51 | track_to_sample = config.track 52 | 53 | # Get all tracks that require sampling 54 | all_tracks = [ 55 | f.name for f in attr.fields(SamplingConfig) if "embedding" not in f.name 56 | ] 57 | 58 | sequence_length = len(input_tokens) 59 | device = input_tokens.device 60 | 61 | # Initialize schedule and masks 62 | decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule] 63 | sampled_tokens = attr.evolve(input_tokens) # Make a copy 64 | 65 | if config.condition_on_coordinates_only and input_tokens.coordinates is not None: 66 | sampled_tokens.structure = None 67 | 68 | sampling_mask = torch.ones( 69 | sequence_length, 70 | dtype=torch.bool, 71 | device=device, 72 | ) 73 | sampling_mask[0] = False 74 | sampling_mask[-1] = False 75 | 76 | get_tokenizer: Callable[[str], 
EsmTokenizerBase] = lambda s: getattr(tokenizers, s) 77 | if getattr(sampled_tokens, track_to_sample) is None: 78 | if track_to_sample == "function": 79 | dims = (sequence_length, tokenizers.function.depth) 80 | elif track_to_sample == "residue_annotations": 81 | dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS) 82 | else: 83 | dims = (sequence_length,) 84 | masked_tokens = torch.full( 85 | dims, 86 | get_tokenizer(track_to_sample).mask_token_id, 87 | dtype=torch.long, 88 | device=device, 89 | ) 90 | if track_to_sample == "sequence": 91 | masked_tokens[0] = tokenizers.sequence.cls_token_id # type: ignore 92 | masked_tokens[-1] = tokenizers.sequence.eos_token_id # type: ignore 93 | else: 94 | masked_tokens[0] = get_tokenizer(track_to_sample).bos_token_id 95 | masked_tokens[-1] = get_tokenizer(track_to_sample).eos_token_id 96 | 97 | setattr( 98 | sampled_tokens, 99 | track_to_sample, 100 | masked_tokens, 101 | ) 102 | else: 103 | is_mask: torch.Tensor = ( 104 | getattr(input_tokens, track_to_sample) 105 | == get_tokenizer(track_to_sample).mask_token_id 106 | ) 107 | if not is_mask.any().item(): 108 | raise ValueError(f"Cannot sample {config.track} when input has no masks.") 109 | sampling_mask = sampling_mask & is_mask 110 | 111 | # Decode 112 | 113 | def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: 114 | return x.clone() if x is not None else None 115 | 116 | L = sequence_length - 2 117 | positions_sampled = 0 118 | for t in tqdm(range(config.num_steps)): 119 | # Single step sampling at all positions 120 | track_sample_config = SamplingTrackConfig() 121 | track_sample_config.invalid_ids = config.invalid_ids 122 | track_sample_config.temperature = config.temperature 123 | track_sample_config.top_p = config.top_p 124 | sampling_config = SamplingConfig(**{track_to_sample: track_sample_config}) # type: ignore 125 | 126 | forward_and_sample_output = client.forward_and_sample( 127 | sampled_tokens, sampling_config 128 | ) 129 | new_samples = 
forward_and_sample_output.protein_tensor 130 | 131 | # Calculate number of tokens to sample 132 | perc_masked = decoding_schedule(torch.tensor((t + 1) / config.num_steps)) 133 | num_to_sample = int((1 - perc_masked) * L) - positions_sampled 134 | positions_sampled += num_to_sample 135 | 136 | # Select tokens based on lowest entropy 137 | if track_to_sample in ["function", "residue_annotations"]: 138 | # TODO: Implement iterative decoding for function and residue_annotations 139 | # TODO: Fix encode/decode of interpro tokens (not yet supported) 140 | sampled_tokens.function = maybe_clone(input_tokens.function) 141 | sampled_tokens.residue_annotations = maybe_clone( 142 | input_tokens.residue_annotations 143 | ) 144 | if track_to_sample in track_to_sample: 145 | raise NotImplementedError( 146 | f"Iterative decoding for {track_to_sample} is not supported yet." 147 | ) 148 | continue 149 | 150 | sampling_mask = sampling_mask & ( 151 | getattr(sampled_tokens, track_to_sample) 152 | == get_tokenizer(track_to_sample).mask_token_id 153 | ) 154 | 155 | track_entropy: torch.Tensor = getattr( 156 | forward_and_sample_output.entropy, track_to_sample 157 | ) 158 | track_entropy = track_entropy.masked_fill( 159 | ~sampling_mask, torch.finfo(track_entropy.dtype).max 160 | ) 161 | _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False) 162 | is_top_k = ~( 163 | torch.arange(sequence_length, device=device)[:, None] != indices[None, :] 164 | ).all(-1) 165 | tokens_to_sample = sampling_mask & is_top_k 166 | 167 | old_track_samples = getattr(sampled_tokens, track_to_sample) 168 | new_track_samples = getattr(new_samples, track_to_sample) 169 | 170 | new_track_samples = torch.where( 171 | tokens_to_sample, new_track_samples, old_track_samples 172 | ) 173 | 174 | setattr(sampled_tokens, track_to_sample, new_track_samples) 175 | 176 | # Do not update tracks that were not sampled (e.g. 
keep None instead of masks) 177 | for track in all_tracks: 178 | if track != track_to_sample: 179 | setattr( 180 | sampled_tokens, 181 | track, 182 | maybe_clone(getattr(input_tokens, track)), 183 | ) 184 | 185 | return sampled_tokens 186 | -------------------------------------------------------------------------------- /src/esm/utils/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import ContextManager, Sequence, TypeVar 3 | 4 | import numpy as np 5 | import torch 6 | 7 | MAX_SUPPORTED_DISTANCE = 1e6 8 | 9 | 10 | TSequence = TypeVar("TSequence", bound=Sequence) 11 | 12 | 13 | def slice_python_object_as_numpy( 14 | obj: TSequence, idx: int | list[int] | slice | np.ndarray 15 | ) -> TSequence: 16 | """ 17 | Slice a python object (like a list, string, or tuple) as if it was a numpy object. 18 | 19 | Example: 20 | >>> obj = "ABCDE" 21 | >>> slice_python_object_as_numpy(obj, [1, 3, 4]) 22 | "BDE" 23 | 24 | >>> obj = [1, 2, 3, 4, 5] 25 | >>> slice_python_object_as_numpy(obj, np.arange(5) < 3) 26 | [1, 2, 3] 27 | """ 28 | if isinstance(idx, int): 29 | idx = [idx] 30 | 31 | if isinstance(idx, np.ndarray) and idx.dtype == bool: 32 | sliced_obj = [obj[i] for i in np.where(idx)[0]] 33 | elif isinstance(idx, slice): 34 | sliced_obj = obj[idx] 35 | else: 36 | sliced_obj = [obj[i] for i in idx] 37 | 38 | match obj, sliced_obj: 39 | case str(), list(): 40 | sliced_obj = "".join(sliced_obj) 41 | case _: 42 | sliced_obj = obj.__class__(sliced_obj) # type: ignore 43 | 44 | return sliced_obj # type: ignore 45 | 46 | 47 | def rbf(values, v_min, v_max, n_bins=16): 48 | """ 49 | Returns RBF encodings in a new dimension at the end. 
50 | """ 51 | rbf_centers = torch.linspace( 52 | v_min, v_max, n_bins, device=values.device, dtype=values.dtype 53 | ) 54 | rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1]) 55 | rbf_std = (v_max - v_min) / n_bins 56 | z = (values.unsqueeze(-1) - rbf_centers) / rbf_std 57 | return torch.exp(-(z**2)) 58 | 59 | 60 | def batched_gather(data, inds, dim=0, no_batch_dims=0): 61 | ranges = [] 62 | for i, s in enumerate(data.shape[:no_batch_dims]): 63 | r = torch.arange(s) 64 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) 65 | ranges.append(r) 66 | 67 | remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] 68 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds 69 | ranges.extend(remaining_dims) 70 | return data[ranges] 71 | 72 | 73 | def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor: 74 | return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1) 75 | 76 | 77 | def knn_graph( 78 | coords: torch.Tensor, 79 | coord_mask: torch.Tensor, 80 | padding_mask: torch.Tensor, 81 | sequence_id: torch.Tensor, 82 | *, 83 | no_knn: int, 84 | ): 85 | L = coords.shape[-2] 86 | num_by_dist = min(no_knn, L) 87 | device = coords.device 88 | 89 | coords = coords.nan_to_num() 90 | coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None]) 91 | padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None] 92 | if sequence_id is not None: 93 | padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze( 94 | sequence_id, 2 95 | ) 96 | dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1) 97 | arange = torch.arange(L, device=device) 98 | seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs() 99 | # We only support up to a certain distance, above that, we use sequence distance 100 | # instead. 
This is so that when a large portion of the structure is masked out, 101 | # the edges are built according to sequence distance. 102 | max_dist = MAX_SUPPORTED_DISTANCE 103 | torch._assert_async((dists[~coord_mask] < max_dist).all()) 104 | struct_then_seq_dist = ( 105 | seq_dists.to(dists.dtype) 106 | .mul(1e2) 107 | .add(max_dist) 108 | .where(coord_mask, dists) 109 | .masked_fill(padding_pairwise_mask, torch.inf) 110 | ) 111 | dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False) 112 | # This is a L x L tensor, where we index by rows first, 113 | # and columns are the edges we should pick. 114 | chosen_edges = edges[..., :num_by_dist] 115 | chosen_mask = dists[..., :num_by_dist].isfinite() 116 | return chosen_edges, chosen_mask 117 | 118 | 119 | def stack_variable_length_tensors( 120 | sequences: Sequence[torch.Tensor], 121 | constant_value: int | float = 0, 122 | dtype: torch.dtype | None = None, 123 | ) -> torch.Tensor: 124 | """Automatically stack tensors together, padding variable lengths with the 125 | value in constant_value. Handles an arbitrary number of dimensions. 126 | 127 | Examples: 128 | >>> tensor1, tensor2 = torch.ones([2]), torch.ones([5]) 129 | >>> stack_variable_length_tensors(tensor1, tensor2) 130 | tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones. 
131 | 132 | >>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3]) 133 | >>> stack_variable_length_tensors(tensor1, tensor2) 134 | tensor of shape [2, 5, 4] 135 | """ 136 | batch_size = len(sequences) 137 | shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist() 138 | 139 | if dtype is None: 140 | dtype = sequences[0].dtype 141 | device = sequences[0].device 142 | 143 | array = torch.full(shape, constant_value, dtype=dtype, device=device) 144 | for arr, seq in zip(array, sequences): 145 | arrslice = tuple(slice(dim) for dim in seq.shape) 146 | arr[arrslice] = seq 147 | 148 | return array 149 | 150 | 151 | def unbinpack( 152 | tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float 153 | ): 154 | """ 155 | Args: 156 | tensor (Tensor): [B, L, ...] 157 | 158 | Returns: 159 | Tensor: [B_unbinpacked, L_unbinpack, ...] 160 | """ 161 | if sequence_id is None: 162 | return tensor 163 | 164 | unpacked_tensors = [] 165 | num_sequences = sequence_id.max(dim=-1).values + 1 166 | for batch_idx, (batch_seqid, batch_num_sequences) in enumerate( 167 | zip(sequence_id, num_sequences) 168 | ): 169 | for seqid in range(batch_num_sequences): 170 | mask = batch_seqid == seqid 171 | unpacked = tensor[batch_idx, mask] 172 | unpacked_tensors.append(unpacked) 173 | return stack_variable_length_tensors(unpacked_tensors, pad_value) 174 | 175 | 176 | def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: 177 | """ 178 | Returns an autocast context manager that disables downcasting by AMP. 179 | 180 | Args: 181 | device_type: The device type ('cpu' or 'cuda') 182 | 183 | Returns: 184 | An autocast context manager with the specified behavior. 
185 | """ 186 | if device_type == "cpu": 187 | return torch.amp.autocast(device_type, enabled=False) 188 | elif device_type == "cuda": 189 | return torch.amp.autocast(device_type, dtype=torch.float32) 190 | else: 191 | raise ValueError(f"Unsupported device type: {device_type}") 192 | 193 | 194 | def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]: 195 | """Merge overlapping ranges into sorted, non-overlapping segments. 196 | 197 | Args: 198 | ranges: collection of ranges to merge. 199 | merge_gap_max: optionally merge neighboring ranges that are separated by a gap 200 | no larger than this size. 201 | Returns: 202 | non-overlapping ranges merged from the inputs, sorted by position. 203 | """ 204 | ranges = sorted(ranges, key=lambda r: r.start) 205 | merge_gap_max = merge_gap_max if merge_gap_max is not None else 0 206 | assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}" 207 | 208 | merged = [] 209 | for r in ranges: 210 | if not merged: 211 | merged.append(r) 212 | else: 213 | last = merged[-1] 214 | if last.stop + merge_gap_max >= r.start: 215 | merged[-1] = range(last.start, max(last.stop, r.stop)) 216 | else: 217 | merged.append(r) 218 | return merged 219 | 220 | 221 | def list_nan_to_none(l: list) -> list: 222 | if l is None: 223 | return None # type: ignore 224 | elif isinstance(l, float): 225 | return None if math.isnan(l) else l # type: ignore 226 | elif isinstance(l, list): 227 | return [list_nan_to_none(x) for x in l] 228 | else: 229 | # Don't go into other structures. 
230 | return l 231 | 232 | 233 | def list_none_to_nan(l: list) -> list: 234 | if l is None: 235 | return math.nan # type: ignore 236 | elif isinstance(l, list): 237 | return [list_none_to_nan(x) for x in l] 238 | else: 239 | return l 240 | 241 | 242 | def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: 243 | if x is None: 244 | return None 245 | if convert_none_to_nan: 246 | x = list_none_to_nan(x) 247 | return torch.tensor(x) 248 | 249 | 250 | def maybe_list(x, convert_nan_to_none: bool = False) -> list | None: 251 | if x is None: 252 | return None 253 | x = x.tolist() 254 | if convert_nan_to_none: 255 | x = list_nan_to_none(x) 256 | return x 257 | -------------------------------------------------------------------------------- /src/esm/utils/noise_schedules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def cosine_schedule(t: torch.Tensor): 7 | # t is a tensor of size (batch_size,) with values between 0 and 1. 
This is the 8 | # schedule used in the MaskGIT paper 9 | return torch.cos(t * math.pi * 0.5) 10 | 11 | 12 | def cubic_schedule(t): 13 | return 1 - t**3 14 | 15 | 16 | def linear_schedule(t): 17 | return 1 - t 18 | 19 | 20 | def square_root_schedule(t): 21 | return 1 - torch.sqrt(t) 22 | 23 | 24 | def square_schedule(t): 25 | return 1 - t**2 26 | 27 | 28 | NOISE_SCHEDULE_REGISTRY = { 29 | "cosine": cosine_schedule, 30 | "linear": linear_schedule, 31 | "square_root_schedule": square_root_schedule, 32 | "cubic": cubic_schedule, 33 | "square": square_schedule, 34 | } 35 | -------------------------------------------------------------------------------- /src/esm/utils/residue_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This mapping is used when we need to store atom data in a format that requires 17 | # fixed atom data size for every residue (e.g. a numpy array). 
18 | atom_types = [ 19 | "N", 20 | "CA", 21 | "C", 22 | "CB", 23 | "O", 24 | "CG", 25 | "CG1", 26 | "CG2", 27 | "OG", 28 | "OG1", 29 | "SG", 30 | "CD", 31 | "CD1", 32 | "CD2", 33 | "ND1", 34 | "ND2", 35 | "OD1", 36 | "OD2", 37 | "SD", 38 | "CE", 39 | "CE1", 40 | "CE2", 41 | "CE3", 42 | "NE", 43 | "NE1", 44 | "NE2", 45 | "OE1", 46 | "OE2", 47 | "CH2", 48 | "NH1", 49 | "NH2", 50 | "OH", 51 | "CZ", 52 | "CZ2", 53 | "CZ3", 54 | "NZ", 55 | "OXT", 56 | ] 57 | atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} 58 | atom_type_num = len(atom_types) # := 37. 59 | 60 | restype_1to3 = { 61 | "A": "ALA", 62 | "R": "ARG", 63 | "N": "ASN", 64 | "D": "ASP", 65 | "C": "CYS", 66 | "Q": "GLN", 67 | "E": "GLU", 68 | "G": "GLY", 69 | "H": "HIS", 70 | "I": "ILE", 71 | "L": "LEU", 72 | "K": "LYS", 73 | "M": "MET", 74 | "F": "PHE", 75 | "P": "PRO", 76 | "S": "SER", 77 | "T": "THR", 78 | "W": "TRP", 79 | "Y": "TYR", 80 | "V": "VAL", 81 | } 82 | -------------------------------------------------------------------------------- /src/esm/utils/sampling.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from src.esm.sdk.api import ( 6 | SamplingConfig, 7 | SamplingTrackConfig, 8 | ) 9 | from src.esm.tokenization import ( 10 | TokenizerCollection, 11 | get_invalid_tokenizer_ids, 12 | ) 13 | from src.esm.tokenization.function_tokenizer import ( 14 | InterProQuantizedTokenizer, 15 | ) 16 | from src.esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS 17 | 18 | 19 | def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig: 20 | tracks = [f.name for f in attr.fields(SamplingConfig)] 21 | sampling_config = SamplingConfig() 22 | for current_track in tracks: 23 | setattr( 24 | sampling_config, 25 | current_track, 26 | SamplingTrackConfig( 27 | invalid_ids=get_invalid_tokenizer_ids( 28 | getattr(tokenizers, current_track) 29 | ), 30 | 
temperature=1.0, 31 | top_p=1.0, 32 | # TODO: Add different mask and padding tokens for all tracks 33 | # Some tracks have the same pad and mask, which causes ambiguity when sampling 34 | only_sample_masked_tokens=current_track 35 | not in ["secondary_structure", "sasa", "function"], 36 | ), 37 | ) 38 | return sampling_config 39 | 40 | 41 | def sample_logits( 42 | logits: torch.Tensor, 43 | temperature: float | torch.Tensor, 44 | top_p: float | torch.Tensor = 1.0, 45 | ): 46 | """Default sampling from logits. 47 | 48 | Args: 49 | logits is shape (..., vocab_size) 50 | temperature is broadcastable to (...) 51 | """ 52 | 53 | if top_p < 1.0: 54 | logits = top_p_logits(logits, top_p=top_p) 55 | 56 | temperature = _tensorize_like(temperature, logits) 57 | 58 | if torch.all(temperature == 0): 59 | ids = logits.argmax(-1) 60 | return ids 61 | 62 | assert not torch.any(temperature == 0), "Partial temperature 0 not supported." 63 | 64 | batch_dims = logits.size()[:-1] 65 | logits = logits.reshape(-1, logits.shape[-1]) 66 | 67 | # Sample from all logits 68 | probs = F.softmax(logits / temperature[..., None], dim=-1) 69 | ids = torch.multinomial(probs, 1).squeeze(1) 70 | 71 | ids = ids.reshape(*batch_dims) 72 | return ids 73 | 74 | 75 | def sample_function_logits( 76 | logits: torch.Tensor, 77 | tokenizer: InterProQuantizedTokenizer, 78 | top_p: float | torch.Tensor = 1.0, 79 | temperature: float | torch.Tensor = 1.0, 80 | p_none_threshold: float = 0.05, 81 | ) -> tuple[torch.Tensor, torch.Tensor]: 82 | [L, D, V] = logits.shape 83 | assert D == tokenizer.depth 84 | 85 | if top_p < 1.0: 86 | logits = top_p_logits(logits, top_p=top_p) 87 | 88 | temperature = torch.ones_like(logits[..., 0]) * temperature 89 | 90 | log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (L, D, V) 91 | 92 | # Choose which positions have no predicted function. 
93 | log_p_nones = log_p[..., tokenizer.vocab_to_index[""]] # (L, D) 94 | p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of predictions" 95 | where_none = p_none > p_none_threshold # (L, ) 96 | 97 | # Set probability of to 0 for all not-none positions 98 | none_index = tokenizer.vocab_to_index[""] 99 | log_p[~where_none, :, none_index] = -torch.inf 100 | 101 | ids = torch.argmax(log_p, dim=-1) # (L, D) 102 | ids[where_none, :] = tokenizer.vocab_to_index[""] 103 | 104 | return ids, log_p 105 | 106 | 107 | def sample_residue_annotation_logits( 108 | logits: torch.Tensor, annotation_threshold: float = 0.5 109 | ) -> tuple[torch.Tensor, torch.Tensor]: 110 | # Take top residue annotations 111 | top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[ 112 | ..., :MAX_RESIDUE_ANNOTATIONS 113 | ] # (L, MAX_R) 114 | top_residue_annotations_logprobs = torch.gather( 115 | F.logsigmoid(logits), -1, top_residue_annotations_idx 116 | ) # (L, MAX_R) 117 | top_residue_annotations_probs = top_residue_annotations_logprobs.exp() 118 | # Keep only positive predictions 119 | is_negative = top_residue_annotations_probs < annotation_threshold 120 | top_residue_annotations_idx[is_negative] = 0 121 | 122 | top_residue_annotations_logprobs = top_residue_annotations_logprobs 123 | 124 | return top_residue_annotations_idx, top_residue_annotations_logprobs 125 | 126 | 127 | def top_p_logits( 128 | logits: torch.Tensor, 129 | top_p: float | torch.Tensor, 130 | ) -> torch.Tensor: 131 | top_p = _tensorize_like(top_p, logits) 132 | 133 | batch_dims = logits.size()[:-1] 134 | logits = logits.reshape(-1, logits.shape[-1]) 135 | 136 | # Sort logits in descending order and extract the mask for the top_p 137 | sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) 138 | cumsum_logits = sorted_logits.softmax(-1).cumsum(-1) 139 | top_p_mask = cumsum_logits <= top_p[:, None] 140 | 141 | # Make sure at least one token is sampled 142 | top_p_mask[:, 0] = True 143 
| 144 | # Mask out the logits that are not in the top_p 145 | batch_indices_to_mask, _ = torch.where(~top_p_mask) 146 | vocab_indices_to_mask = sorted_indices[~top_p_mask] 147 | logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min 148 | 149 | return logits.reshape(*batch_dims, -1) 150 | 151 | 152 | def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor): 153 | if isinstance(value, (float, int)): 154 | value = torch.full_like(logits[..., 0], value, dtype=logits.dtype) 155 | return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1) 156 | -------------------------------------------------------------------------------- /src/esm/utils/structure/aligner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import replace 4 | from typing import TYPE_CHECKING 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from src.esm.utils.structure.protein_structure import ( 10 | compute_affine_and_rmsd, 11 | ) 12 | 13 | if TYPE_CHECKING: 14 | from src.esm.utils.structure.protein_chain import ProteinChain 15 | 16 | 17 | class Aligner: 18 | def __init__( 19 | self, 20 | mobile: ProteinChain, 21 | target: ProteinChain, 22 | only_use_backbone: bool = False, 23 | use_reflection: bool = False, 24 | ): 25 | """ 26 | Aligns a mobile protein chain against a target protein chain. 27 | 28 | Args: 29 | mobile (ProteinChain): Protein chain to be aligned. 30 | target (ProteinChain): Protein chain target. 31 | only_use_backbone (bool): Whether to only use backbone atoms. 32 | use_reflection (bool): Whether to align to target reflection. 
33 | """ 34 | # Check proteins must have same number of residues 35 | assert len(mobile) == len(target) 36 | 37 | # Determine overlapping atoms 38 | joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype( 39 | bool 40 | ) 41 | 42 | # Backbone atoms are first sites in atom37 representation 43 | if only_use_backbone: 44 | joint_atom37_mask[:, 3:] = False 45 | 46 | # Extract matching atom positions and convert to batched tensors 47 | mobile_atom_tensor = ( 48 | torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0) 49 | ) 50 | target_atom_tensor = ( 51 | torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0) 52 | ) 53 | joint_atom37_mask = ( 54 | torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0) 55 | ) 56 | 57 | # If using reflection flip target 58 | if use_reflection: 59 | target_atom_tensor = -target_atom_tensor 60 | 61 | # Compute alignment and rmsd 62 | affine3D, rmsd = compute_affine_and_rmsd( 63 | mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask 64 | ) 65 | self._affine3D = affine3D 66 | self._rmsd = rmsd.item() 67 | 68 | @property 69 | def rmsd(self): 70 | return self._rmsd 71 | 72 | def apply(self, mobile: ProteinChain) -> ProteinChain: 73 | """Apply alignment to a protein chain""" 74 | # Extract atom positions and convert to batched tensors 75 | mobile_atom_tensor = ( 76 | torch.from_numpy(mobile.atom37_positions[mobile.atom37_mask]) 77 | .type(torch.float32) 78 | .unsqueeze(0) 79 | ) 80 | 81 | # Transform atom arrays 82 | aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0) 83 | 84 | # Rebuild atom37 positions 85 | aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan) 86 | aligned_atom37_positions[mobile.atom37_mask] = aligned_atom_tensor 87 | 88 | return replace(mobile, atom37_positions=aligned_atom37_positions) 89 | -------------------------------------------------------------------------------- 
# =============================================================================
# /src/esm/utils/structure/lddt.py
# =============================================================================
import torch
from einops import rearrange

from src.esm.utils import residue_constants as RC


def compute_lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """
    Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
        Nstates:
            all_atom_pred_pos can contain multiple states in the first dimension which corresponds
            to outputs from different layers of a model (e.g. each IPA block). The return size will
            be [Nstates x Batch size] if this is included.
        Natoms:
            LDDT can be computed for all atoms or some atoms. The second to last dimension should
            contain the *FLATTENED* representation of L x Natoms. If you want to calculate for
            atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it
            will be of size L.

    Args:
        all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms) x 3]): Tensor of predicted positions
        all_atom_positions (Tensor[float], [B x (L * Natoms) x 3]): Tensor of true positions
        all_atom_mask (Tensor[float], [B x (L * Natoms) x 1]): Tensor of masks, indicating whether an
            atom exists. The trailing singleton dimension is required: the pair mask is built by
            multiplying the mask with its own transpose over the last two dimensions, and the
            atom count is read from ``all_atom_mask.shape[-2]``.
        cutoff (float): Only atom pairs closer than this in the TRUE structure are scored.
        eps (float): Added inside the square root and to the numerator and denominator of the
            final ratio for numerical stability.
        per_residue (bool): Whether to return per-residue or full-protein lddt.

    Returns:
        LDDT Tensor:
            if per_residue:
                Tensor[float], [(Nstates x) B x (L * Natoms)]
            else:
                Tensor[float], [(Nstates x) B]
    """
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    # Score pairs that are within the cutoff, where both atoms exist, and that
    # are not an atom paired with itself (the identity is removed).
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * rearrange(all_atom_mask, "... a b -> ... b a")
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    # Fraction of the four tolerance thresholds (0.5, 1, 2, 4) that each
    # pair's distance error satisfies.
    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score


def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """CA-only LDDT; see ``compute_lddt`` for argument meanings.

    If ``all_atom_pred_pos`` is not 3-D, all three inputs are assumed to
    carry a per-residue atom dimension ordered like ``RC.atom_order``. That
    dimension is the second-to-last one for positions and the last one for
    the mask. The CA slot is selected from each, and the mask keeps a
    trailing singleton dim as ``compute_lddt`` requires. A 3-D prediction is
    passed through unchanged.
    """
    ca_pos = RC.atom_order["CA"]
    if all_atom_pred_pos.dim() != 3:
        all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
        all_atom_positions = all_atom_positions[..., ca_pos, :]
        all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    return compute_lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
/src/esm/utils/structure/normalize_coordinates.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | 7 | from src.esm.utils import residue_constants as RC 8 | from src.esm.utils.structure.affine3d import Affine3D 9 | 10 | ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) 11 | 12 | 13 | def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D: 14 | N, CA, C = bb_positions.unbind(dim=-2) 15 | return Affine3D.from_graham_schmidt(C, CA, N) 16 | 17 | 18 | def index_by_atom_name( 19 | atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 20 | ) -> ArrayOrTensor: 21 | squeeze = False 22 | if isinstance(atom_names, str): 23 | atom_names = [atom_names] 24 | squeeze = True 25 | indices = [RC.atom_order[atom_name] for atom_name in atom_names] 26 | dim = dim % atom37.ndim 27 | index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) 28 | result = atom37[index] # type: ignore 29 | if squeeze: 30 | result = result.squeeze(dim) 31 | return result 32 | 33 | 34 | def get_protein_normalization_frame(coords: Tensor) -> Affine3D: 35 | """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates. 36 | Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame 37 | using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame. 
38 | 39 | Args: 40 | coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates 41 | 42 | Returns: 43 | Affine3D: tensor of Affine3D frame 44 | """ 45 | bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2) 46 | coord_mask = torch.all( 47 | torch.all(torch.isfinite(bb_coords), dim=-1), 48 | dim=-1, 49 | ) 50 | 51 | average_position_per_n_ca_c = bb_coords.masked_fill( 52 | ~coord_mask[..., None, None], 0 53 | ).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8) 54 | frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float()) 55 | 56 | return frame 57 | 58 | 59 | def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor: 60 | """Given a set of coordinates and a single frame, apply the frame to the coordinates. 61 | 62 | Args: 63 | coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates 64 | frame (Affine3D): Affine3D frame 65 | 66 | Returns: 67 | torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates 68 | """ 69 | coords_trans_rot = frame[..., None, None].invert().apply(coords) 70 | 71 | # only transform coordinates with frame that have a valid rotation 72 | valid_frame = frame.trans.norm(dim=-1) > 0 73 | 74 | is_inf = torch.isinf(coords) 75 | coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords) 76 | coords.masked_fill_(is_inf, torch.inf) 77 | 78 | return coords 79 | 80 | 81 | def normalize_coordinates(coords: Tensor) -> Tensor: 82 | return apply_frame_to_coords(coords, get_protein_normalization_frame(coords)) 83 | -------------------------------------------------------------------------------- /src/esm/utils/structure/predicted_aligned_error.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from src.esm.utils.structure.affine3d import Affine3D 5 | 6 | 7 | def masked_mean( 8 | mask: torch.Tensor, 9 | value: torch.Tensor, 10 | dim: int | None | tuple[int, ...] 
= None, 11 | eps=1e-10, 12 | ) -> torch.Tensor: 13 | """Compute the mean of `value` where only positions where `mask == true` are 14 | counted. 15 | """ 16 | mask = mask.expand(*value.shape) 17 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 18 | 19 | 20 | def _pae_bins( 21 | max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu") 22 | ): 23 | bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device) 24 | step = max_bin / (num_bins - 2) 25 | bin_centers = bins + step / 2 26 | bin_centers = torch.cat( 27 | [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 28 | ) 29 | return bin_centers 30 | 31 | 32 | def _compute_pae_masks(mask: torch.Tensor): 33 | square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool() 34 | return square_mask 35 | 36 | 37 | def compute_predicted_aligned_error( 38 | logits: torch.Tensor, 39 | aa_mask: torch.Tensor, 40 | sequence_id: torch.Tensor | None = None, 41 | max_bin: float = 31, 42 | ) -> torch.Tensor: 43 | bins = _pae_bins(max_bin, logits.shape[-1], logits.device) 44 | square_mask = _compute_pae_masks(aa_mask) 45 | min_v = torch.finfo(logits.dtype).min 46 | probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) 47 | 48 | return (probs * bins).sum(dim=-1) 49 | 50 | 51 | @torch.no_grad 52 | def compute_tm( 53 | logits: torch.Tensor, 54 | aa_mask: torch.Tensor, 55 | max_bin: float = 31.0, 56 | ): 57 | square_mask = _compute_pae_masks(aa_mask) 58 | seqlens = aa_mask.sum(-1, keepdim=True) 59 | bins = _pae_bins(max_bin, logits.shape[-1], logits.device) 60 | d0 = 1.24 * (seqlens.clamp_min(19) - 15) ** (1 / 3) - 1.8 61 | f_d = 1.0 / (1 + (bins / d0.unsqueeze(-1)) ** 2) 62 | 63 | min_v = torch.finfo(logits.dtype).min 64 | probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) 65 | # This is the sum over bins 66 | ptm = (probs * f_d.unsqueeze(-2)).sum(dim=-1) 67 | # This is the mean over residues j 68 | ptm = 
masked_mean(square_mask, ptm, dim=-1) 69 | # The we do a max over residues i 70 | return ptm.max(dim=-1).values 71 | 72 | 73 | def tm_loss( 74 | logits: torch.Tensor, 75 | pred_affine: torch.Tensor, 76 | targ_affine: torch.Tensor, 77 | targ_mask: torch.Tensor, 78 | tm_mask: torch.Tensor | None = None, 79 | sequence_id: torch.Tensor | None = None, 80 | max_bin: float = 31, 81 | ): 82 | pred = Affine3D.from_tensor(pred_affine) 83 | targ = Affine3D.from_tensor(targ_affine) 84 | 85 | def transform(affine: Affine3D): 86 | pts = affine.trans[..., None, :, :] 87 | return affine.invert()[..., None].apply(pts) 88 | 89 | with torch.no_grad(): 90 | sq_diff = (transform(pred) - transform(targ)).square().sum(dim=-1) 91 | 92 | num_bins = logits.shape[-1] 93 | sq_bins = torch.linspace( 94 | 0, max_bin, num_bins - 1, device=logits.device 95 | ).square() 96 | # Gets the bin id by using a sum. 97 | true_bins = (sq_diff[..., None] > sq_bins).sum(dim=-1).long() 98 | 99 | errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none") 100 | square_mask = _compute_pae_masks(targ_mask) 101 | loss = masked_mean(square_mask, errors, dim=(-1, -2)) 102 | 103 | if tm_mask is not None: 104 | loss = masked_mean(tm_mask, loss, dim=None) 105 | else: 106 | loss = loss.mean() 107 | 108 | return loss 109 | -------------------------------------------------------------------------------- /src/esm/utils/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | PathLike = Union[str, Path] 9 | PathOrBuffer = Union[PathLike, io.StringIO] 10 | 11 | 12 | @dataclass 13 | class FunctionAnnotation: 14 | """Represents an annotation of a protein's function over a range of residues. 
15 | 16 | Fields: 17 | label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs 18 | start (int): Start index of this annotation. 1-indexed, inclusive. 19 | end (int): End index of this annotation. 1-indexed, inclusive. 20 | """ 21 | 22 | label: str 23 | start: int 24 | end: int 25 | 26 | def to_tuple(self) -> tuple[str, int, int]: 27 | return self.label, self.start, self.end 28 | -------------------------------------------------------------------------------- /src/esmfold.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import esm 3 | import os 4 | import gc 5 | import argparse 6 | import biotite.structure.io as bsio 7 | import pandas as pd 8 | from tqdm import tqdm 9 | from Bio import SeqIO 10 | from transformers import AutoTokenizer, EsmForProteinFolding 11 | 12 | from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein 13 | from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 14 | 15 | def read_fasta(file_path, key): 16 | return str(getattr(SeqIO.read(file_path, 'fasta'), key)) 17 | 18 | def read_multi_fasta(file_path): 19 | """ 20 | params: 21 | file_path: path to a fasta file 22 | return: 23 | a dictionary of sequences 24 | """ 25 | sequences = {} 26 | current_sequence = '' 27 | with open(file_path, 'r') as file: 28 | for line in file: 29 | line = line.strip() 30 | if line.startswith('>'): 31 | if current_sequence: 32 | sequences[header] = current_sequence 33 | current_sequence = '' 34 | header = line 35 | else: 36 | current_sequence += line 37 | if current_sequence: 38 | sequences[header] = current_sequence 39 | return sequences 40 | 41 | def convert_outputs_to_pdb(outputs): 42 | final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs) 43 | outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()} 44 | final_atom_positions = final_atom_positions.cpu().numpy() 45 | final_atom_mask = 
outputs["atom37_atom_exists"] 46 | pdbs = [] 47 | for i in range(outputs["aatype"].shape[0]): 48 | aa = outputs["aatype"][i] 49 | pred_pos = final_atom_positions[i] 50 | mask = final_atom_mask[i] 51 | resid = outputs["residue_index"][i] + 1 52 | pred = OFProtein( 53 | aatype=aa, 54 | atom_positions=pred_pos, 55 | atom_mask=mask, 56 | residue_index=resid, 57 | b_factors=outputs["plddt"][i], 58 | chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None, 59 | ) 60 | pdbs.append(to_pdb(pred)) 61 | return pdbs 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--sequence", type=str, default=None) 66 | parser.add_argument("--fasta_file", type=str, default=None) 67 | parser.add_argument("--fasta_chunk_num", type=int, default=None) 68 | parser.add_argument("--fasta_chunk_id", type=int, default=None) 69 | parser.add_argument("--fasta_dir", type=str, default=None) 70 | parser.add_argument("--out_dir", type=str) 71 | parser.add_argument("--out_file", type=str, default="result.pdb") 72 | parser.add_argument("--out_info_file", type=str, default=None) 73 | parser.add_argument("--fold_chunk_size", type=int) 74 | args = parser.parse_args() 75 | 76 | # model = esm.pretrained.esmfold_v1() 77 | # model = model.eval().cuda() 78 | 79 | tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") 80 | model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True) 81 | 82 | model = model.cuda() 83 | # model.esm = model.esm.half() 84 | torch.backends.cuda.matmul.allow_tf32 = True 85 | # Optionally, uncomment to set a chunk size for axial attention. This can help reduce memory. 86 | # Lower sizes will have lower memory requirements at the cost of increased speed. 
87 | if args.fold_chunk_size is not None: 88 | model.trunk.set_chunk_size(args.fold_chunk_size) 89 | 90 | if args.fasta_file is not None: 91 | seq_dict = read_multi_fasta(args.fasta_file) 92 | os.makedirs(args.out_dir, exist_ok=True) 93 | names, sequences = list(seq_dict.keys()), list(seq_dict.values()) 94 | if args.fasta_chunk_num is not None: 95 | chunk_size = len(names) // args.fasta_chunk_num + 1 96 | start = args.fasta_chunk_id * chunk_size 97 | end = min((args.fasta_chunk_id + 1) * chunk_size, len(names)) 98 | names, sequences = names[start:end], sequences[start:end] 99 | 100 | out_info_dict = {"name": [], "plddt": []} 101 | bar = tqdm(zip(names, sequences)) 102 | for name, sequence in bar: 103 | bar.set_description(name) 104 | name = name[1:].split(" ")[0] 105 | out_file = os.path.join(args.out_dir, f"{name}.ef.pdb") 106 | if os.path.exists(out_file): 107 | out_info_dict["name"].append(name) 108 | struct = bsio.load_structure(out_file, extra_fields=["b_factor"]) 109 | out_info_dict["plddt"].append(struct.b_factor.mean()) 110 | continue 111 | 112 | # Multimer prediction can be done with chains separated by ':' 113 | try: 114 | tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda() 115 | with torch.no_grad(): 116 | output = model(tokenized_input) 117 | except: 118 | print(f"Failed to predict {name}") 119 | continue 120 | gc.collect() 121 | pdb = convert_outputs_to_pdb(output) 122 | with open(out_file, "w") as f: 123 | f.write("\n".join(pdb)) 124 | 125 | out_info_dict["name"].append(name) 126 | struct = bsio.load_structure(out_file, extra_fields=["b_factor"]) 127 | out_info_dict["plddt"].append(struct.b_factor.mean()) 128 | 129 | if args.out_info_file is not None: 130 | pd.DataFrame(out_info_dict).to_csv(args.out_info_file, index=False) 131 | 132 | if args.fasta_dir is not None: 133 | os.makedirs(args.out_dir, exist_ok=True) 134 | proteins = sorted(os.listdir(args.fasta_dir)) 135 | bar = tqdm(proteins) 136 | 
for p in bar: 137 | name = p[:-6] 138 | bar.set_description(name) 139 | out_file = os.path.join(args.out_dir, f"{name}.ef.pdb") 140 | if os.path.exists(out_file): 141 | continue 142 | bar.set_description(p) 143 | sequence = read_fasta(os.path.join(args.fasta_dir, p), "seq") 144 | tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda() 145 | # Multimer prediction can be done with chains separated by ':' 146 | 147 | with torch.no_grad(): 148 | output = model(tokenized_input) 149 | 150 | pdb = convert_outputs_to_pdb(output) 151 | with open(out_file, "w") as f: 152 | f.write("\n".join(pdb)) 153 | 154 | struct = bsio.load_structure(out_file, extra_fields=["b_factor"]) 155 | print(p, struct.b_factor.mean()) 156 | elif args.sequence is not None: 157 | sequence = args.sequence 158 | # Multimer prediction can be done with chains separated by ':' 159 | 160 | with torch.no_grad(): 161 | output = model.infer_pdb(sequence) 162 | 163 | with open(args.out_file, "w") as f: 164 | f.write(output) 165 | 166 | struct = bsio.load_structure(args.out_file, extra_fields=["b_factor"]) 167 | print(struct.b_factor.mean()) -------------------------------------------------------------------------------- /src/models/__pycache__/adapter.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/models/__pycache__/adapter.cpython-312.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/pooling.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/models/__pycache__/pooling.cpython-312.pyc -------------------------------------------------------------------------------- /src/models/pooling.py: 
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN


class MaskedConv1d(nn.Conv1d):
    """A masked 1-dimensional convolution layer.

    Takes the same arguments as torch.nn.Conv1D, except that the padding is set automatically.

    Shape:
        Input: (N, L, in_channels)
        input_mask: (N, L, 1), optional
        Output: (N, L, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param kernel_size: the kernel width
        :param stride: filter shift
        :param dilation: dilation factor
        :param groups: perform depth-wise convolutions
        :param bias: adds learnable bias to output
        """
        # "Same"-style padding for odd kernel sizes at stride 1.
        padding = dilation * (kernel_size - 1) // 2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )

    def forward(self, x, input_mask=None):
        """Zero masked positions (if a mask is given), then convolve over L."""
        if input_mask is not None:
            x = x * input_mask
        # Conv1d expects (N, C, L); the module's public layout is (N, L, C).
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class Attention1dPooling(nn.Module):
    """Softmax-attention pooling over the sequence dimension with a 1x1 scorer."""

    def __init__(self, hidden_size):
        super().__init__()
        self.layer = MaskedConv1d(hidden_size, 1, 1)

    def forward(self, x, input_mask=None):
        """Pool x (N, L, H) into (N, H).

        Positions where ``input_mask`` is zero get -inf scores before the
        softmax. Returns ``(pooled, attention)``, with attention shaped (N, L, 1).
        """
        batch_size = x.shape[0]
        attn = self.layer(x)
        attn = attn.view(batch_size, -1)
        if input_mask is not None:
            attn = attn.masked_fill_(
                ~input_mask.view(batch_size, -1).bool(), float("-inf")
            )
        attn = F.softmax(attn, dim=-1).view(batch_size, -1, 1)
        out = (attn * x).sum(dim=1)
        return out, attn


class Attention1dPoolingProjection(nn.Module):
    """Linear -> dropout -> ReLU -> linear classifier on pooled features."""

    def __init__(self, hidden_size, num_labels, dropout=0.25) -> None:
        super(Attention1dPoolingProjection, self).__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.final = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.final(x)
        return x


class Attention1dPoolingHead(nn.Module):
    """Outputs of the model with the attention1d"""

    def __init__(
        self, hidden_size: int, num_labels: int, dropout: float = 0.25, return_attentions: bool = False
    ):  # [batch x sequence(751) x embedding (1280)] --> [batch x embedding] --> [batch x 1]
        super(Attention1dPoolingHead, self).__init__()
        self.return_attentions = return_attentions
        self.attention1d = Attention1dPooling(hidden_size)
        self.attention1d_projection = Attention1dPoolingProjection(hidden_size, num_labels, dropout)

    def forward(self, x, input_mask=None):
        """Classify x (N, L, H) from attention-pooled features.

        ``input_mask`` is (N, L) or None (attend to every position). Returns
        logits, or ``(logits, attention)`` when ``return_attentions`` is set.
        """
        # The default None previously crashed on ``None.unsqueeze``.
        pooling_mask = None if input_mask is None else input_mask.unsqueeze(-1)
        x, attn_weights = self.attention1d(x, input_mask=pooling_mask)
        x = self.attention1d_projection(x)
        if self.return_attentions:
            return x, attn_weights
        else:
            return x


class MeanPooling(nn.Module):
    """Mean Pooling for sentence-level classification tasks."""

    def __init__(self):
        super().__init__()

    def forward(self, features, input_mask=None):
        """Average ``features`` (N, L, H) over L.

        With ``input_mask`` (N, L), features are weighted by the mask and
        divided by the mask total per row. A row whose mask sums to zero
        yields zeros instead of NaN.
        """
        if input_mask is not None:
            # Applying input_mask to zero out masked values
            masked_features = features * input_mask.unsqueeze(2)
            sum_features = torch.sum(masked_features, dim=1)
            token_counts = input_mask.sum(dim=1, keepdim=True)
            # Guard all-padding rows against 0 / 0; other rows are unaffected.
            token_counts = token_counts.masked_fill(token_counts == 0, 1)
            mean_pooled_features = sum_features / token_counts
        else:
            mean_pooled_features = torch.mean(features, dim=1)
        return mean_pooled_features


class MeanPoolingProjection(nn.Module):
    """Mean Pooling with a projection layer for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels, dropout=0.25):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, mean_pooled_features):
        """Dropout -> dense -> GELU -> dropout -> output projection."""
        x = self.dropout(mean_pooled_features)
        x = self.dense(x)
        x = ACT2FN['gelu'](x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class MeanPoolingHead(nn.Module):
    """Mean Pooling Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels, dropout=0.25):
        super().__init__()
        self.mean_pooling = MeanPooling()
        self.mean_pooling_projection = MeanPoolingProjection(hidden_size, num_labels, dropout)

    def forward(self, features, input_mask=None):
        mean_pooling_features = self.mean_pooling(features, input_mask=input_mask)
        x = self.mean_pooling_projection(mean_pooling_features)
        return x


class LightAttentionPoolingHead(nn.Module):
    """Light-attention head: conv features pooled by conv attention plus max-pooling.

    The attention-weighted sum and the max over the sequence are concatenated
    (2 * hidden_size), then passed through a 32-unit MLP and a linear output.
    """

    def __init__(self, hidden_size=1280, num_labels=11, dropout=0.25, kernel_size=9, conv_dropout: float = 0.25):
        super(LightAttentionPoolingHead, self).__init__()

        self.feature_convolution = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1,
                                             padding=kernel_size // 2)
        self.attention_convolution = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1,
                                               padding=kernel_size // 2)

        self.softmax = nn.Softmax(dim=-1)

        self.dropout = nn.Dropout(conv_dropout)

        self.linear = nn.Sequential(
            nn.Linear(2 * hidden_size, 32),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        self.output = nn.Linear(32, num_labels)

    def forward(self, x: torch.Tensor, mask, **kwargs) -> torch.Tensor:
        """
        Args:
            x: [batch_size, sequence_length, hidden_size] embedding tensor that should be classified
            mask: [batch_size, sequence_length] mask corresponding to the zero padding used for the shorter sequences in the batch. All values corresponding to padding are False and the rest is True.
            **kwargs: ignored.

        Returns:
            classification: [batch_size, num_labels] tensor with logits
        """
        x = x.permute(0, 2, 1)  # [batch_size, hidden_size, sequence_length]
        o = self.feature_convolution(x)  # [batch_size, hidden_size, sequence_length]
        o = self.dropout(o)  # [batch_size, hidden_size, sequence_length]
        attention = self.attention_convolution(x)  # [batch_size, hidden_size, sequence_length]

        # Mask out the padding (present because sequences in a batch have
        # different lengths) so it gets ~zero attention weight. ``== False``
        # (rather than ``~``) also accepts integer/float 0/1 masks.
        attention = attention.masked_fill(mask[:, None, :] == False, -1e9)  # noqa: E712

        # code used for extracting embeddings for UMAP visualizations
        # extraction = torch.sum(x * self.softmax(attention), dim=-1)
        # extraction = self.id0(extraction)

        o1 = torch.sum(o * self.softmax(attention), dim=-1)  # [batchsize, hidden_size]
        # NOTE(review): this max also covers padded positions (``o`` is not
        # masked); kept as-is because changing it would alter trained heads.
        o2, _ = torch.max(o, dim=-1)  # [batchsize, hidden_size]
        o = torch.cat([o1, o2], dim=-1)  # [batchsize, 2*hidden_size]
        o = self.linear(o)  # [batchsize, 32]
        return self.output(o)  # [batchsize, num_labels]

# =============================================================================
# /src/utils/__pycache__/data_utils.cpython-312.pyc
# https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/utils/__pycache__/data_utils.cpython-312.pyc
# --------------------------------------------------------------------------------
# /src/utils/__pycache__/loss_fn.cpython-312.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/utils/__pycache__/loss_fn.cpython-312.pyc
# --------------------------------------------------------------------------------
# /src/utils/__pycache__/metrics.cpython-312.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/utils/__pycache__/metrics.cpython-312.pyc
# --------------------------------------------------------------------------------
# /src/utils/data_utils.py:
# --------------------------------------------------------------------------------
import random
import biotite
import numpy as np
import torch.utils.data as data
from typing import List
from biotite.structure.residues import get_residues
from biotite.sequence import ProteinSequence
from biotite.structure.io import pdbx, pdb
from biotite.structure import filter_backbone
from biotite.structure import get_chains


def load_structure(fpath, chain=None):
    """Load a PDB or mmCIF file, keeping only backbone atoms of the selected chains.

    Args:
        fpath: Path (str or path-like) to a file ending in ``cif`` or ``pdb``;
            the parser is chosen from that suffix.
        chain: The chain id, a list of chain ids, or None to keep every chain.

    Returns:
        biotite.structure.AtomArray for model 1, filtered with
        ``filter_backbone`` and restricted to the requested chains.

    Raises:
        ValueError: If the file suffix is unsupported, no chains are found,
            or a requested chain is not present.
    """
    path_str = str(fpath)
    if path_str.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif path_str.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError(f'Unsupported structure file (expected .pdb or .cif): {fpath}')
    structure = structure[filter_backbone(structure)]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')
    # Vectorised membership test instead of a Python loop over every atom.
    structure = structure[np.isin(structure.chain_id, chain_ids)]
    return structure


def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """Collect the coordinates of the named atoms for every residue.

    Example for atoms argument: ["N", "CA", "C"]

    Returns:
        The result of ``apply_residue_wise``: per residue, a ``len(atoms) x 3``
        coordinate array, with NaN rows for atoms missing from that residue
        (presumably stacked to ``n_residues x len(atoms) x 3``).

    Raises:
        RuntimeError: If a residue contains the same atom name more than once.
    """
    def filterfn(s, axis=None):
        # filters[i, j] is True when atom i of this residue is named atoms[j].
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        counts = filters.sum(0)
        if np.any(counts > 1):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        # argmax returns 0 for absent atoms; blank those rows out.
        coords[counts == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)


def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """Return backbone N, CA, C coordinates for every residue.

    Args:
        structure: An instance of biotite AtomArray.

    Returns:
        coords: an L x 3 x 3 array of N, CA, C coordinates (NaN where an atom
        is missing). Only coordinates are returned; use ``extract_seq_from_pdb``
        to obtain the sequence.
    """
    return get_atom_coords_residuewise(["N", "CA", "C"], structure)


def extract_seq_from_pdb(pdb_file, chain=None):
    """Return the one-letter amino-acid sequence of a structure file.

    Args:
        pdb_file: Path to a ``.pdb`` or ``.cif`` file (loaded with ``load_structure``).
        chain: Chain id, list of chain ids, or None for all chains.

    Returns:
        seq: the sequence as a str, one letter per residue.
    """
    structure = load_structure(pdb_file, chain)
    residue_identities = get_residues(structure)[1]
    seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return seq


class BatchSampler(data.Sampler):
    """
    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    A batch's cost is ``max node count in the batch * batch size``; it never
    exceeds ``max_batch_nodes``. Elements larger than ``max_batch_nodes`` are
    dropped. Batches are formed (and, if requested, shuffled) once at
    construction, so repeated iteration yields the same order.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_batch_nodes: the maximum number of nodes in any batch,
                            including batches of a single element
    :param shuffle: if `True`, batches in shuffled order
    """
    def __init__(self, node_counts, max_batch_nodes=10000, shuffle=True):
        self.node_counts = node_counts
        self.idx = [i for i in range(len(node_counts)) if node_counts[i] <= max_batch_nodes]
        self.shuffle = shuffle
        self.max_batch_nodes = max_batch_nodes
        self._form_batches()

    def _form_batches(self):
        """Greedily pack ``self.idx`` (optionally shuffled in place) into ``self.batches``."""
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        idx = self.idx
        # Walk a position pointer instead of re-slicing the list (was O(n^2)).
        pos, total = 0, len(idx)
        while pos < total:
            batch = []
            max_n_node = 0
            while pos < total:
                n_node = self.node_counts[idx[pos]]
                if max(n_node, max_n_node) * (len(batch) + 1) > self.max_batch_nodes:
                    break
                max_n_node = max(max_n_node, n_node)
                batch.append(idx[pos])
                pos += 1
            self.batches.append(batch)

    def __len__(self):
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        yield from self.batches


def top_k_accuracy(labels, probas, ratio=0.3):
    """Fraction of positives among the ``ratio`` share of samples scored highest.

    Args:
        labels: Binary labels (1 marks a positive).
        probas: Scores aligned with ``labels``; higher means more likely positive.
        ratio: Share of samples treated as the top-k (default 0.3).

    Returns:
        The fraction of the top-k scored samples whose label equals 1.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    probas, labels = np.array(probas), np.array(labels)
    if len(labels) == 0:
        raise ValueError('top_k_accuracy requires at least one sample')
    # Keep k >= 1: with k == 0 the slice [-0:] selected every sample and the
    # final division was by zero.
    k = max(1, int(len(labels) * ratio))
    topk = probas.argsort()[-k:]
    correct = labels[topk] == 1
    return correct.sum() / k


def plot_roc_curve(y_true, y_pred, save_fig=None):
    """Plot a ROC curve annotated with its AUC and optionally save it.

    The curve is drawn on a dedicated figure that is always closed at the
    end, so the caller's current figure is left untouched. Without
    ``save_fig`` the plot is discarded.

    Args:
        y_true: Binary ground-truth labels.
        y_pred: Scores for the positive class.
        save_fig: Optional output path; saved at 300 dpi with a tight bounding box.
    """
    import matplotlib.pyplot as plt
    from sklearn import metrics
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred)
    roc_auc = metrics.roc_auc_score(y_true, y_pred)
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=f'Our (AUC = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend()
    if save_fig:
        fig.savefig(save_fig, dpi=300, bbox_inches='tight')
    plt.close(fig)

# --------------------------------------------------------------------------------
# /src/utils/loss_fn.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MultiClassFocalLossWithAlpha(nn.Module):
    """Multi-class focal loss with per-class weights.

    For each sample: ``-alpha[target] * (1 - p) ** gamma * log(p)``, where
    ``p`` is the softmax probability assigned to the target class.

    Args:
        num_classes: Number of classes; sizes the all-ones default ``alpha``.
        alpha: Per-class weights (sequence, array or tensor), or None for
            uniform weights of 1.
        gamma: Focusing exponent.
        reduction: 'mean' or 'sum'; any other value returns per-sample losses.
        device: Device the weight tensor is moved to at construction.
    """

    def __init__(self, num_classes, alpha=None, gamma=1, reduction='mean', device="cuda"):
        super().__init__()
        # Bug fix: the all-ones default used to be overwritten unconditionally
        # by torch.tensor(alpha), which raises when alpha is None.
        if alpha is None:
            alpha_tensor = torch.ones(num_classes, dtype=torch.float32)
        else:
            alpha_tensor = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha_tensor.to(device)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        """
        Args:
            pred: Logits of shape [batch_size, num_classes].
            target: Integer class indices of shape [batch_size].

        Returns:
            Scalar loss for 'mean'/'sum', otherwise a [batch_size] tensor.
        """
        # Follow the logits' device; a no-op when they already match.
        alpha = self.alpha.to(pred.device)[target]
        log_softmax = torch.log_softmax(pred, dim=1)
        logpt = torch.gather(log_softmax, dim=1, index=target.view(-1, 1))
        logpt = logpt.view(-1)
        ce_loss = -logpt
        pt = torch.exp(logpt)
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == "mean":
            return torch.mean(focal_loss)
        if self.reduction == "sum":
            return torch.sum(focal_loss)
        return focal_loss

# --------------------------------------------------------------------------------
# /src/utils/metrics.py:
# --------------------------------------------------------------------------------
import torch
from torchmetrics.classification import MultilabelAveragePrecision


def count_f1_max(pred, target):
    """
    F1 score with the optimal threshold, copied from TorchDrug.

    This function first enumerates all possible thresholds for deciding positive and negative
    samples, and then picks the threshold with the maximal F1 score.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Scalar tensor with the maximal F1 over all thresholds.
    """

    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    all_order = pred.flatten().argsort(descending=True)
    order = (
        order
        + torch.arange(order.shape[0], device=order.device).unsqueeze(1)
        * order.shape[1]
    )
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    all_precision = precision[all_order] - torch.where(
        is_start, torch.zeros_like(precision), precision[all_order - 1]
    )
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - torch.where(
        is_start, torch.zeros_like(recall), recall[all_order - 1]
    )
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()


class MultilabelF1Max(MultilabelAveragePrecision):
    """Metric reporting the optimal-threshold F1 computed by ``count_f1_max``.

    Reuses the parent's accumulated ``preds``/``target`` state (concatenated
    here) and overrides only ``compute``.
    """

    def compute(self):
        return count_f1_max(torch.cat(self.preds), torch.cat(self.target))
# --------------------------------------------------------------------------------