├── .gitignore
├── LICENSE
├── README.md
├── environment.yaml
├── eval.py
├── figs
└── architecture.png
├── infer.py
├── pdb2json.py
├── script
├── eval_case_study.sh
├── eval_example.pbs
└── run_example.pbs
├── src
├── data
│ ├── add_noise_to_backbone.py
│ ├── get_esm3_structure_seq.py
│ ├── get_ss_seq.py
│ └── processors
│ │ ├── descriptor_features.py
│ │ ├── foldseek.py
│ │ ├── protein_features.py
│ │ └── structure_features.py
├── esm
│ ├── __init__.py
│ ├── layers
│ │ ├── attention.py
│ │ ├── blocks.py
│ │ ├── codebook.py
│ │ ├── ffn.py
│ │ ├── geom_attention.py
│ │ ├── regression_head.py
│ │ ├── rotary.py
│ │ ├── structure_proj.py
│ │ └── transformer_stack.py
│ ├── models
│ │ ├── esm3.py
│ │ ├── function_decoder.py
│ │ └── vqvae.py
│ ├── pretrained.py
│ ├── sdk
│ │ └── api.py
│ ├── tokenization
│ │ ├── __init__.py
│ │ ├── function_tokenizer.py
│ │ ├── residue_tokenizer.py
│ │ ├── sasa_tokenizer.py
│ │ ├── sequence_tokenizer.py
│ │ ├── ss_tokenizer.py
│ │ ├── structure_tokenizer.py
│ │ └── tokenizer_base.py
│ └── utils
│ │ ├── constants
│ │ ├── esm3.py
│ │ ├── models.py
│ │ └── physics.py
│ │ ├── decoding.py
│ │ ├── encoding.py
│ │ ├── function
│ │ ├── encode_decode.py
│ │ ├── interpro.py
│ │ ├── lsh.py
│ │ └── tfidf.py
│ │ ├── generation.py
│ │ ├── misc.py
│ │ ├── noise_schedules.py
│ │ ├── residue_constants.py
│ │ ├── sampling.py
│ │ ├── structure
│ │ ├── affine3d.py
│ │ ├── aligner.py
│ │ ├── lddt.py
│ │ ├── normalize_coordinates.py
│ │ ├── predicted_aligned_error.py
│ │ ├── protein_chain.py
│ │ └── protein_structure.py
│ │ └── types.py
├── esmfold.py
├── models
│ ├── __pycache__
│ │ ├── adapter.cpython-312.pyc
│ │ └── pooling.cpython-312.pyc
│ ├── adapter.py
│ └── pooling.py
└── utils
│ ├── __pycache__
│ ├── data_utils.cpython-312.pyc
│ ├── loss_fn.cpython-312.pyc
│ └── metrics.cpython-312.pyc
│ ├── data_utils.py
│ ├── loss_fn.py
│ └── metrics.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | .idea/*
163 | wandb/
164 | ckpt/
165 | dataset/
166 | result/
167 | src/data/weights/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VenusVaccine
2 |
3 |
4 |
5 | [](https://github.com/ai4protein/VenusVaccine)
6 |
7 | [](https://www.python.org/)
8 | [](https://pytorch.org/)
9 | [](https://creativecommons.org/licenses/by-nc-nd/4.0/)
10 |
11 |
12 |
13 | ## 📋 Overview
14 |
15 | VenusVaccine is a deep learning-based immunogenicity prediction tool that classifies antigens as protective or non-protective. The project leverages advanced pre-trained protein language models and adapter architectures to predict immunogenicity from a multimodal encoding of antigens, including their sequences, structures, and physicochemical properties.
16 |
17 |
18 |

19 |
20 |
21 | ### 🌟 Key Features
22 |
23 | - 🔬 **Versatile Data Processing**
24 | - Support for multiple protein database formats
25 | - Efficient data preprocessing and feature extraction
26 | - Flexible data augmentation strategies
27 |
28 | - 🧬 **Protein Feature Extraction**
29 | - E-descriptor and Z-descriptor physicochemical features
30 | - Foldseek secondary structure prediction
31 | - ESM3 structure sequence encoding
32 |
33 | - 🤖 **Advanced Model Architecture**
34 | - Integration with pre-trained protein language models
35 | - Innovative adapter design
36 | - Support for multiple PLM types (ESM, BERT, Ankh, etc.)
37 |
38 | - 📊 **Comprehensive Training Framework**
39 | - Cross-validation support
40 | - Early stopping strategy
41 | - Wandb experiment tracking
42 | - Automated model evaluation
43 |
44 | - 🚀 **High-Performance Computing**
45 | - GPU acceleration support
46 | - Distributed training
47 | - Gradient accumulation optimization
48 |
49 | ## 🛠️ Installation Guide
50 |
51 | ### Requirements
52 |
53 | - Python 3.10+ (`environment.yaml` pins Python 3.10)
54 | - CUDA 11.0+ (for GPU training)
55 | - 8GB+ RAM
56 |
57 | ### Setup Steps
58 |
59 | 1. Clone the repository:
60 | ```bash
61 | git clone https://github.com/songleee/VenusVaccine.git
62 | cd VenusVaccine
63 | ```
64 |
65 | 2. Create a virtual environment:
66 | ```bash
67 | conda env create -f environment.yaml
68 | ```
69 |
70 | 3. Download data and checkpoints:
71 | Download the pre-trained model files, training data, and model evaluation results from [Google Drive](https://drive.google.com/drive/folders/1VLEGpFv7jFyWGChzxchxv-D99QUBlqOA?usp=sharing)
72 |
73 | Pre-trained model files should be placed in the `ckpt` directory:
74 | - `ckpt/Bacteria.pt`: Model for bacterial protective antigens
75 | - `ckpt/Virus.pt`: Model for viral protective antigens
76 | - `ckpt/Tumor.pt`: Model for tumor protective antigens
77 |
78 | 4. Download and install dependencies:
79 | - [Foldseek](https://github.com/steineggerlab/foldseek/releases/tag/10-941cd33)
80 | - [ESM3_encoder](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1/blob/main/data/weights/esm3_structure_encoder_v0.pth)
81 | ```bash
82 | wget https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1/resolve/main/data/weights/esm3_structure_encoder_v0.pth
83 | mkdir -p ./src/data/weights
84 | mv esm3_structure_encoder_v0.pth ./src/data/weights
85 | ```
86 |
87 | ## 📊 Data Processing
88 |
89 | ### Structure prediction with ESMFold
90 |
91 | ```bash
92 | # Predict single protein sequence
93 | python src/esmfold.py --sequence "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" --out_file output.pdb
94 |
95 | # Predict multiple proteins from FASTA file
96 | python src/esmfold.py --fasta_file proteins.fasta --out_dir pdb_structures --fold_chunk_size 128
97 |
98 | ```
99 |
100 | ### PDB to JSON Conversion
101 |
102 | First obtain a PDB file for each protein of interest (an experimental structure such as cryo-EM, or a structure predicted by AlphaFold2 or ESMFold) and put them in one directory, then use `pdb2json.py` to convert the PDB files into a feature-rich JSON Lines file:
103 |
104 | ```bash
105 | python pdb2json.py <pdb_dir> <output_json_file>
106 | ```
107 |
108 | This tool automatically extracts:
109 | - Amino acid sequence
110 | - ESM3 structure sequence
111 | - Foldseek secondary structure prediction
112 | - E-descriptor (5-dimensional) features
113 | - Z-descriptor (3-dimensional) features
114 |
115 | ## 🚀 Quick Start
116 |
117 | ### Basic Usage
118 |
119 | ```bash
120 | python infer.py -i input.json -t Bacteria
121 | ```
122 |
123 | ### Command Line Arguments
124 |
125 | ```bash
126 | python infer.py [-h] -i INPUT -t {Bacteria,Virus,Tumor} [--structure_seqs STRUCTURE_SEQS]
127 | [--max_seq_len MAX_SEQ_LEN] [--max_batch_token MAX_BATCH_TOKEN]
128 | [--num_workers NUM_WORKERS] [-o OUTPUT]
129 | ```
130 |
131 | Arguments:
132 | - `-i, --input`: Path to input JSON file (required)
133 | - `-t, --type`: Antigen type, which selects the model checkpoint (`ckpt/{type}.pt`); choose from: Bacteria, Virus, Tumor (required)
134 | - `--structure_seqs`: Types of structure sequences, comma-separated (default: e_descriptor,z_descriptor,foldseek_seq,esm3_structure_seq)
135 | - `--max_seq_len`: Maximum sequence length (default: 1024)
136 | - `--max_batch_token`: Maximum tokens per batch (default: 10000)
137 | - `--num_workers`: Number of data loading workers (default: 4)
138 | - `-o, --output`: Path to output CSV file (default: results_{type}.csv)
139 |
140 | ### Input Format
141 |
142 | The input should be a JSON Lines file (one JSON object per line; the example below is pretty-printed for readability). The required fields depend on the `--structure_seqs` argument:
143 |
144 | ```json
145 | {
146 | "name": "protein1",
147 | "aa_seq": "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
148 | "foldseek_seq": "HHHEEELLCCHHHHHHHHHHHHSTTHHHHHHHHHHHHHHHHHHHHHHHHEETTEEHHHHHH",
149 | "esm3_structure_seq": [1, 2, 3, ...],
150 | "e_descriptor": [[0.1, 0.2, 0.3, 0.4, 0.5], ...],
151 | "z_descriptor": [[0.1, 0.2, 0.3], ...]
152 | }
153 | ```
154 |
155 | Required fields:
156 | - `name`: Protein sequence identifier
157 | - `aa_seq`: Amino acid sequence
158 |
159 | Optional fields (depending on structure_seqs parameter):
160 | - `foldseek_seq`: Secondary structure sequence predicted by Foldseek
161 | - `esm3_structure_seq`: Structure sequence predicted by ESM3
162 | - `e_descriptor`: E-descriptor features (5-dimensional)
163 | - `z_descriptor`: Z-descriptor features (3-dimensional)
164 |
165 | ### Output Format
166 |
167 | The output is a CSV file containing:
168 | - `name`: Protein sequence identifier
169 | - `aa_seq`: Amino acid sequence
170 | - `pred_label`: Prediction label (0: non-protective antigen, 1: protective antigen)
171 | - `pred_proba`: Prediction probability of being a protective antigen
172 |
173 | ### Examples
174 |
175 | 1. Predict using all structural features:
176 | ```bash
177 | python infer.py -i proteins.json -t Bacteria
178 | ```
179 |
180 | 2. Use only specific structural features:
181 | ```bash
182 | python infer.py -i proteins.json -t Virus --structure_seqs "e_descriptor,z_descriptor"
183 | ```
184 |
185 | 3. Specify output file:
186 | ```bash
187 | python infer.py -i proteins.json -t Tumor -o predictions.csv
188 | ```
189 |
190 | 4. Adjust sequence length and batch size:
191 | ```bash
192 | python infer.py -i proteins.json -t Bacteria --max_seq_len 512 --max_batch_token 5000
193 | ```
194 |
195 | ## ⚠️ Important Notes
196 |
197 | 1. Ensure all required dependencies are installed
198 | 2. Make sure corresponding model files exist in the `ckpt` directory (`Bacteria.pt`, `Virus.pt`, or `Tumor.pt`)
199 | 3. If downloading the PLM checkpoints from Hugging Face fails (e.g., due to network issues), download them manually and make sure they are set up correctly
200 | 4. GPU is recommended for better inference performance
201 |
202 | ## 📝 Citation
203 |
204 | If you find this tool helpful, please cite our work:
205 | ```
206 | @inproceedings{
207 | li2025immunogenicity,
208 | title={Immunogenicity Prediction with Dual Attention Enables Vaccine Target Selection},
209 | author={Song Li and Yang Tan and Song Ke and Liang Hong and Bingxin Zhou},
210 | booktitle={The Thirteenth International Conference on Learning Representations},
211 | year={2025},
212 | url={https://openreview.net/forum?id=hWmwL9gizZ}
213 | }
214 | ```
215 |
216 | ## 📝 License
217 |
218 | This project is licensed under the terms of the [CC-BY-NC-ND-4.0](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) license.
219 |
220 | ## 📮 Contact
221 |
222 | - Project Maintainer: Song Li, Yang Tan
223 | - Email: songlee@sjtu.edu.cn
224 | - Issue Tracking: [Issue Page](https://github.com/songleee/VenusVaccine/issues)
225 |
226 | ---
227 |
228 |
229 | ⭐️ If you find this project helpful, please give it a star!
230 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: venusvaccine
2 | channels:
3 | - conda-forge
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - bzip2=1.0.8=h4bc722e_7
11 | - ca-certificates=2024.7.4=hbcca054_0
12 | - ld_impl_linux-64=2.40=hf3520f5_7
13 | - libffi=3.4.2=h7f98852_5
14 | - libgcc-ng=14.1.0=h77fa898_0
15 | - libgomp=14.1.0=h77fa898_0
16 | - libnsl=2.0.1=hd590300_0
17 | - libsqlite=3.46.0=hde9e2c9_0
18 | - libuuid=2.38.1=h0b41bf4_0
19 | - libxcrypt=4.4.36=hd590300_1
20 | - libzlib=1.3.1=h4ab18f5_1
21 | - ncurses=6.5=h59595ed_0
22 | - openssl=3.3.1=h4bc722e_2
23 | - pip=24.2=pyhd8ed1ab_0
24 | - python=3.10.14=hd12c33a_0_cpython
25 | - readline=8.2=h8228510_1
26 | - setuptools=72.1.0=pyhd8ed1ab_0
27 | - tk=8.6.13=noxft_h4845f30_101
28 | - wheel=0.44.0=pyhd8ed1ab_0
29 | - xz=5.2.6=h166bdaf_0
30 | - pip:
31 | - accelerate==0.33.0
32 | - aiohappyeyeballs==2.3.5
33 | - aiohttp==3.10.3
34 | - aiosignal==1.3.1
35 | - async-timeout==4.0.3
36 | - attrs==24.2.0
37 | - biotite==0.41.2
38 | - certifi==2024.7.4
39 | - charset-normalizer==3.3.2
40 | - click==8.1.7
41 | - datasets==2.20.0
42 | - dill==0.3.8
43 | - docker-pycreds==0.4.0
44 | - filelock==3.15.4
45 | - frozenlist==1.4.1
46 | - fsspec==2024.5.0
47 | - gitdb==4.0.11
48 | - gitpython==3.1.43
49 | - huggingface-hub==0.24.5
50 | - idna==3.7
51 | - jinja2==3.1.4
52 | - lightning-utilities==0.11.6
53 | - markupsafe==2.1.5
54 | - mpmath==1.3.0
55 | - msgpack==1.0.8
56 | - multidict==6.0.5
57 | - multiprocess==0.70.16
58 | - networkx==3.3
59 | - numpy==1.26.4
60 | - nvidia-cublas-cu12==12.1.3.1
61 | - nvidia-cuda-cupti-cu12==12.1.105
62 | - nvidia-cuda-nvrtc-cu12==12.1.105
63 | - nvidia-cuda-runtime-cu12==12.1.105
64 | - nvidia-cudnn-cu12==9.1.0.70
65 | - nvidia-cufft-cu12==11.0.2.54
66 | - nvidia-curand-cu12==10.3.2.106
67 | - nvidia-cusolver-cu12==11.4.5.107
68 | - nvidia-cusparse-cu12==12.1.0.106
69 | - nvidia-nccl-cu12==2.20.5
70 | - nvidia-nvjitlink-cu12==12.6.20
71 | - nvidia-nvtx-cu12==12.1.105
72 | - packaging==24.1
73 | - pandas==2.2.2
74 | - platformdirs==4.2.2
75 | - protobuf==5.27.3
76 | - psutil==6.0.0
77 | - pyarrow==17.0.0
78 | - pyarrow-hotfix==0.6
79 | - python-dateutil==2.9.0.post0
80 | - pytz==2024.1
81 | - pyyaml==6.0.2
82 | - regex==2024.7.24
83 | - requests==2.32.3
84 | - safetensors==0.4.4
85 | - sentry-sdk==2.12.0
86 | - setproctitle==1.3.3
87 | - six==1.16.0
88 | - smmap==5.0.1
89 | - sympy==1.13.2
90 | - tokenizers==0.19.1
91 | - torch==2.4.0
92 | - torchmetrics==1.4.1
93 | - tqdm==4.66.5
94 | - transformers==4.44.0
95 | - triton==3.0.0
96 | - typing-extensions==4.12.2
97 | - tzdata==2024.1
98 | - urllib3==2.2.2
99 | - wandb==0.17.6
100 | - xxhash==3.4.1
101 | - yarl==1.9.4
102 | prefix: /home/lisong/software/anaconda3/envs/venusvaccine
103 |
--------------------------------------------------------------------------------
/figs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/figs/architecture.png
--------------------------------------------------------------------------------
/pdb2json.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import pandas as pd
4 | from tqdm import tqdm
5 | import warnings
6 | from src.data.processors.descriptor_features import DescriptorFeatureProcessor
7 | from src.data.processors.structure_features import StructureFeatureProcessor
8 | from src.data.processors.foldseek import FoldseekProcessor
9 |
10 | warnings.filterwarnings("ignore")
11 |
12 |
def process_pdb_folder(pdb_dir, output_json_file):
    """Convert every ``.pdb`` file in ``pdb_dir`` into one JSON Lines record.

    Each record holds the file name without ``.pdb``, the amino-acid sequence
    and ESM3 structure sequence returned by
    ``StructureFeatureProcessor.get_esm3_structure_seq``, the Foldseek
    sequence looked up by that name, and E-/Z-descriptor features computed
    from the sequence.

    Args:
        pdb_dir: Directory scanned (non-recursively) for files ending in ``.pdb``.
        output_json_file: Destination path, written by pandas with
            ``orient="records", lines=True`` (one JSON object per line).
    """
    descriptor_processor = DescriptorFeatureProcessor()
    structure_processor = StructureFeatureProcessor()

    # Foldseek runs once over the whole directory; results are looked up per file below.
    foldseek_dict = FoldseekProcessor.run_foldseek_commands(pdb_dir)

    # Filter before iterating so the progress bar only counts PDB files, and
    # sort so the output order is reproducible (os.listdir order is arbitrary).
    pdb_files = sorted(f for f in os.listdir(pdb_dir) if f.endswith(".pdb"))

    results = []
    missing_foldseek = []
    for pdb_file in tqdm(pdb_files):
        pdb_path = os.path.join(pdb_dir, pdb_file)
        name = pdb_file[: -len(".pdb")]

        esm3_structure_seq, sequence = structure_processor.get_esm3_structure_seq(pdb_path)

        foldseek_seq = foldseek_dict.get(name)
        if foldseek_seq is None:
            # Still emitted (as null) for compatibility, but reported below.
            missing_foldseek.append(name)

        results.append({
            "name": name,
            "aa_seq": sequence,
            "esm3_structure_seq": esm3_structure_seq,
            "foldseek_seq": foldseek_seq,
            "e_descriptor": descriptor_processor.e_descriptor_embedding(sequence),
            "z_descriptor": descriptor_processor.z_descriptor_embedding(sequence),
        })

    if missing_foldseek:
        print(
            f"Warning: no Foldseek sequence for {len(missing_foldseek)} PDB file(s): "
            + ", ".join(missing_foldseek),
            file=sys.stderr,
        )

    pd.DataFrame(results).to_json(output_json_file, orient="records", lines=True)
    print("JSON file created successfully!")
50 |
if __name__ == "__main__":
    # Two positional arguments are required: the input PDB directory and the
    # output JSON Lines path. Usage errors go to stderr with exit status 1.
    if len(sys.argv) != 3:
        print("Usage: python pdb2json.py <pdb_dir> <output_json_file>", file=sys.stderr)
        sys.exit(1)

    pdb_dir, output_json_file = sys.argv[1], sys.argv[2]
    process_pdb_folder(pdb_dir, output_json_file)
59 |
--------------------------------------------------------------------------------
/script/eval_case_study.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate a trained checkpoint on the Helicobacter pylori case-study set.
# Run from the repository root: eval.py and the dataset/ and result_random/
# paths below are relative to it.
set -euo pipefail

# --------------------case study--------------------
# PLM choices (plm_group/plm_model):
# ElnaggarLab/ankh-large
# facebook/esm2_t33_650M_UR50D
# Rostlab/prot_bert
dataset=BacteriaBinary
pdb_type=ESMFold
seqs=ez_descriptor,foldseek_seq,esm3_structure_seq
seqs_type=full
plm_group=facebook
plm_model=esm2_t33_650M_UR50D
pooling_head=attention1d
lr=5e-4
num_labels=2

# GPU selection defaults to device 0 but can be overridden from the caller's environment.
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" python eval.py \
    --plm_model "${plm_group}/${plm_model}" \
    --dataset "$dataset" \
    --problem_type single_label_classification \
    --num_labels "$num_labels" \
    --pooling_method "$pooling_head" \
    --return_attentions \
    --test_file dataset/Case_1_Helicobacter_pylori/case.json \
    --test_result_dir "result_random/${plm_model}/case1" \
    --metrics auc,accuracy,precision,recall,f1,mcc \
    --structure_seqs "$seqs" \
    --max_batch_token 10000 \
    --ckpt_root result_random \
    --ckpt_dir "${plm_model}/${dataset}" \
    --model_name "${pdb_type}_${plm_model}_${pooling_head}_${lr}_${seqs_type}.pt"
30 |
--------------------------------------------------------------------------------
/script/eval_example.pbs:
--------------------------------------------------------------------------------
#PBS -q ai
#PBS -l walltime=72:00:00
#PBS -l ncpus=6
#PBS -l ngpus=1
#PBS -l host=ai1
#PBS -l mem=100gb
#PBS -N VenusVaccine
# Merge stderr into stdout. Pointing -o and -e at the same file lets one
# stream's copy overwrite the other when the job ends.
#PBS -j oe
#PBS -o out.log

cd "$PBS_O_WORKDIR" || exit 1
#module purge
#module load Anaconda3

# Site-specific locations. The defaults are the original author's paths;
# override them through the environment (e.g. `qsub -v CONDA_ROOT=...`).
CONDA_ROOT="${CONDA_ROOT:-/home/lisong/software/anaconda3}"
LOCAL_BIN="${LOCAL_BIN:-/home/lisong/local/bin}"
export PATH="$CONDA_ROOT/bin:$PATH"
export PATH="$LOCAL_BIN:$PATH"
export HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}"
source activate venusvaccine

dataset=BacteriaBinary
pdb_type=ESMFold
seqs=ez_descriptor,foldseek_seq,esm3_structure_seq
seqs_type=full
plm_group=Rostlab
plm_model=prot_bert
# Local copy of the PLM; defaults to <HF_CHECKPOINT_ROOT>/<plm_group>/<plm_model>.
HF_CHECKPOINT_ROOT="${HF_CHECKPOINT_ROOT:-/home/lisong/huggingface/checkpoints}"
checkpoint="${CHECKPOINT:-$HF_CHECKPOINT_ROOT/$plm_group/$plm_model}"

pooling_head=attention1d
lr=5e-4
num_labels=2
CUDA_VISIBLE_DEVICES=0 python eval.py \
    --plm_model "$checkpoint" \
    --dataset "$dataset" \
    --problem_type single_label_classification \
    --num_labels "$num_labels" \
    --pooling_method "$pooling_head" \
    --test_file "dataset/${dataset}/${pdb_type}/test.json" \
    --test_result_dir "result/${plm_model}/${dataset}/${seqs_type}" \
    --metrics auc,accuracy,precision,recall,f1,mcc \
    --structure_seqs "$seqs" \
    --max_batch_token 10000 \
    --ckpt_root result \
    --ckpt_dir "${plm_model}/${dataset}" \
    --model_name "${pdb_type}_${plm_model}_${pooling_head}_${lr}_${seqs_type}.pt"
--------------------------------------------------------------------------------
/script/run_example.pbs:
--------------------------------------------------------------------------------
#PBS -q ai
#PBS -l walltime=72:00:00
#PBS -l ncpus=6
#PBS -l ngpus=1
#PBS -l host=ai1
#PBS -l mem=100gb
#PBS -N VenusVaccine
# Merge stderr into stdout. Pointing -o and -e at the same file lets one
# stream's copy overwrite the other when the job ends.
#PBS -j oe
#PBS -o out.log

cd "$PBS_O_WORKDIR" || exit 1
#module purge
#module load Anaconda3

# Site-specific locations. The defaults are the original author's paths;
# override them through the environment (e.g. `qsub -v CONDA_ROOT=...`).
CONDA_ROOT="${CONDA_ROOT:-/home/lisong/software/anaconda3}"
LOCAL_BIN="${LOCAL_BIN:-/home/lisong/local/bin}"
export PATH="$CONDA_ROOT/bin:$PATH"
export PATH="$LOCAL_BIN:$PATH"
export HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}"
source activate venusvaccine

# PLM choices (plm_group/plm_model):
# ElnaggarLab/ankh-large
# facebook/esm2_t33_650M_UR50D
# Rostlab/prot_bert
dataset=BacteriaBinary
pdb_type=ESMFold
seqs=ez_descriptor,foldseek_seq,esm3_structure_seq
seqs_type=full
plm_group=Rostlab
plm_model=prot_bert
# Local copy of the PLM; defaults to <HF_CHECKPOINT_ROOT>/<plm_group>/<plm_model>.
HF_CHECKPOINT_ROOT="${HF_CHECKPOINT_ROOT:-/home/lisong/huggingface/checkpoints}"
checkpoint="${CHECKPOINT:-$HF_CHECKPOINT_ROOT/$plm_group/$plm_model}"

pooling_head=attention1d
lr=5e-4

# The command ends on the --model_name line. It must not end with a trailing
# backslash, which would continue the command into the comment block below.
CUDA_VISIBLE_DEVICES=0 python train.py \
    --plm_model "$checkpoint" \
    --num_attention_heads 8 \
    --pooling_method "$pooling_head" \
    --pooling_dropout 0.1 \
    --dataset_config "dataset/${dataset}/${dataset}_${pdb_type}.json" \
    --lr "$lr" \
    --num_workers 4 \
    --gradient_accumulation_steps 1 \
    --max_train_epochs 50 \
    --max_batch_token 40000 \
    --patience 5 \
    --structure_seqs "$seqs" \
    --ckpt_root result \
    --ckpt_dir "${plm_model}/${dataset}" \
    --model_name "${pdb_type}_${plm_model}_${pooling_head}_${lr}_${seqs_type}.pt"

# To log to Weights & Biases, append these flags to the train.py command above
# (adding a trailing backslash to its last line):
#     --wandb \
#     --wandb_entity your/wandb/name \
#     --wandb_project VenusVaccine \
#     --wandb_run_name "${dataset}_${pdb_type}_${plm_model}_${pooling_head}_${lr}_${seqs_type}"
53 |
--------------------------------------------------------------------------------
/src/data/add_noise_to_backbone.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import json
4 | from tqdm import tqdm
5 |
# 0-based slices of the x, y and z fields of a fixed-width PDB ATOM record
# (columns 31-38, 39-46 and 47-54). The perturbed values are written back to
# these same columns, so they must also be read from them; whitespace
# splitting misreads touching values such as "-100.123-200.456" or records
# without a chain ID.
_COORD_SLICES = (slice(30, 38), slice(38, 46), slice(46, 54))


def add_noise_to_line(line, variance, rng=np.random):
    """Return ``line`` with Gaussian noise added to its x/y/z coordinates.

    Only lines starting with ``"ATOM"`` are perturbed. Every other line, and
    any ATOM line whose coordinate columns do not parse as floats, is returned
    unchanged.

    Args:
        line: One line of a PDB file, without its trailing newline.
        variance: Passed to ``rng.normal`` as ``scale``, so despite its name it
            is the *standard deviation* of the noise, in the file's
            coordinate units.
        rng: Object providing ``normal(loc, scale, size)``; defaults to the
            global ``numpy.random`` state.
    """
    if not line.startswith("ATOM"):
        return line
    try:
        coords = np.array([float(line[s]) for s in _COORD_SLICES])
    except ValueError:
        return line
    new_coords = coords + rng.normal(0, variance, coords.shape)
    return (
        f"{line[:30]}{new_coords[0]:8.3f}{new_coords[1]:8.3f}"
        f"{new_coords[2]:8.3f}{line[54:]}"
    )


def add_noise_and_save(pdb_lines, variance, file_name):
    """Write ``pdb_lines`` to ``file_name`` with ATOM coordinates perturbed."""
    with open(file_name, "w") as file:
        for line in pdb_lines:
            file.write(add_noise_to_line(line, variance) + "\n")


def main(src_dir="alphafold_pdb", variances=(0.5,)):
    """Write a noisy copy of every file in ``src_dir`` for each noise level.

    Copies for level ``v`` go to ``f"{src_dir}_noise_{v}/"``. That directory
    is created up front; without it every write would fail.
    """
    for variance in variances:
        os.makedirs(f"{src_dir}_noise_{variance}", exist_ok=True)

    for pdb in tqdm(os.listdir(src_dir)):
        with open(os.path.join(src_dir, pdb)) as fh:
            pdb_lines = fh.read().splitlines()

        for variance in variances:
            file_name = f"{src_dir}_noise_{variance}/{pdb}"
            try:
                add_noise_and_save(pdb_lines, variance, file_name)
            except Exception as e:
                # Best effort: report the failing file and continue with the rest.
                print(e)
                print(pdb)


if __name__ == "__main__":
    main()
35 |
36 |
--------------------------------------------------------------------------------
/src/data/get_esm3_structure_seq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import sys
4 | sys.path.append(os.getcwd())
5 | import json
6 | import argparse
7 | import pandas as pd
8 | import numpy as np
9 | from tqdm import tqdm
10 | from biotite.structure.io.pdb import PDBFile
11 | from torch.nn import functional as F
12 | from src.esm.utils.structure.protein_chain import ProteinChain
13 | from src.esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
14 | from src.esm.tokenization.structure_tokenizer import StructureTokenizer
15 | from src.esm.models.vqvae import (
16 | StructureTokenDecoder,
17 | StructureTokenEncoder,
18 | )
19 | import torch._dynamo
20 | torch._dynamo.config.suppress_errors = True
21 |
# Size of the ESM3 VQ-VAE structure codebook. The special tokens below take
# consecutive ids immediately after the VQVAE_CODEBOOK_SIZE codebook entries.
# NOTE(review): this rebinds the VQVAE_SPECIAL_TOKENS imported from
# src.esm.utils.constants.esm3 at the top of the file; confirm both agree.
VQVAE_CODEBOOK_SIZE = 4096
VQVAE_SPECIAL_TOKENS = {
    token_name: VQVAE_CODEBOOK_SIZE + offset
    for offset, token_name in enumerate(("MASK", "EOS", "BOS", "PAD", "CHAINBREAK"))
}
30 |
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu", weights_path: str | None = None):
    """Build the ESM3 VQ-VAE structure encoder and load its pretrained weights.

    Args:
        device: Device the model is moved to and the checkpoint is mapped onto.
        weights_path: Path to ``esm3_structure_encoder_v0.pth``. Defaults to
            ``weights/esm3_structure_encoder_v0.pth`` next to this file
            (``src/data/weights/`` in this repository), independent of the
            current working directory.

    Returns:
        The ``StructureTokenEncoder`` in eval mode with the weights loaded.
    """
    if weights_path is None:
        # A CWD-relative "data/weights/..." only resolves when run from src/,
        # but the `src.*` imports in this file need the repo root on sys.path
        # (via os.getcwd()), so resolve relative to this file instead.
        weights_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "weights",
            "esm3_structure_encoder_v0.pth",
        )
    model = (
        StructureTokenEncoder(
            d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
        )
        .to(device)
        .eval()
    )
    # The checkpoint is fed straight to load_state_dict, so it only needs
    # tensors. weights_only=True refuses arbitrary pickled objects.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model
44 |
def _encode_pdb(encoder, pdb_path, device):
    """Return ``(sequence, structure_tokens)`` for one chain of ``pdb_path``.

    The chain used is ``np.unique(chain_id)[0]``. ``np.unique`` returns sorted
    values, so this is the alphabetically first chain ID, which is not
    necessarily the first chain in the file.
    """
    chain_ids = np.unique(PDBFile.read(pdb_path).get_structure().chain_id)
    chain = ProteinChain.from_pdb(pdb_path, chain_id=chain_ids[0])

    coords, _plddt, residue_index = chain.to_structure_encoder_inputs()
    # Inference only, so there is no need to build an autograd graph.
    with torch.no_grad():
        _, structure_tokens = encoder.encode(
            coords.to(device), residue_index=residue_index.to(device)
        )
    # Take the first row (presumably the size-1 batch dimension; TODO confirm).
    return chain.sequence, structure_tokens.cpu().numpy().tolist()[0]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdb_file", type=str, default=None)
    parser.add_argument("--pdb_dir", type=str, default=None)
    parser.add_argument("--out_file", type=str, default='structure_tokens.json')
    args = parser.parse_args()

    if args.pdb_file is None and args.pdb_dir is None:
        parser.error("one of --pdb_file or --pdb_dir is required")

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder = ESM3_structure_encoder_v0(device)

    results = []
    if args.pdb_file is not None:
        sequence, tokens = _encode_pdb(encoder, args.pdb_file, device)
        # A single file is recorded under the path as given, whereas the
        # --pdb_dir branch records bare file names.
        results.append({'name': args.pdb_file, 'aa_seq': sequence, 'esm3_structure_seq': tokens})
    else:
        # Sorted so the output order is reproducible.
        for pdb_file in tqdm(sorted(os.listdir(args.pdb_dir))):
            sequence, tokens = _encode_pdb(encoder, os.path.join(args.pdb_dir, pdb_file), device)
            results.append({'name': pdb_file, 'aa_seq': sequence, 'esm3_structure_seq': tokens})

    # One JSON object per line, no trailing newline.
    with open(args.out_file, "w") as f:
        f.write("\n".join(json.dumps(r) for r in results))
105 |
--------------------------------------------------------------------------------
/src/data/get_ss_seq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 | import argparse
5 | import json
6 | import pandas as pd
7 | from tqdm import tqdm
8 | from Bio import PDB
9 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
10 | from src.utils.data_utils import extract_seq_from_pdb
11 |
12 |
# Three-state secondary-structure alphabet: helix, strand, coil.
ss_alphabet = list("HEC")

# Reduction of DSSP secondary-structure codes to the three states above.
# "L" is included because generate_feature writes DSSP's "-" as "L".
ss_alphabet_dic = {
    **dict.fromkeys("HG", "H"),
    **dict.fromkeys("EB", "E"),
    **dict.fromkeys("ITSL-P", "C"),
}
20 |
def generate_feature(pdb_file):
    """Extract the amino-acid sequence and DSSP secondary structure of a PDB.

    DSSP is run on the first model of the parsed structure.

    Args:
        pdb_file: Path to the PDB file.

    Returns:
        ``(feature, None)`` on success. ``feature`` is a dict with keys
        ``name`` (file basename), ``aa_seq``, ``ss8_seq`` (DSSP codes with
        ``-`` written as ``L``) and ``ss3_seq`` (reduced via
        ``ss_alphabet_dic``).

        ``(pdb_file, error)`` on failure, where ``error`` is an exception or a
        message string. The caller collects these per file, so errors are
        returned rather than raised.
    """
    try:
        aa_seq = extract_seq_from_pdb(pdb_file)
        structure = PDB.PDBParser(QUIET=True).get_structure("protein", pdb_file)
        dssp = PDB.DSSP(structure[0], pdb_file)
        # Element 2 of each DSSP entry is the secondary-structure code.
        sec_structures = [dssp_res[2] for dssp_res in dssp]
    except Exception as e:
        return pdb_file, e

    sec_structure_str_8 = ''.join(sec_structures).replace('-', 'L')
    if len(aa_seq) != len(sec_structure_str_8):
        return pdb_file, f"aa_seq {len(aa_seq)} and sec_structure_str_8 {len(sec_structure_str_8)} length mismatch"

    # An unmapped code would otherwise raise KeyError here. That escapes the
    # (result, error) protocol and aborts the whole batch via future.result().
    unknown_codes = sorted(set(sec_structures).difference(ss_alphabet_dic))
    if unknown_codes:
        return pdb_file, f"unknown DSSP codes {unknown_codes}"
    sec_structure_str_3 = ''.join(ss_alphabet_dic[ss] for ss in sec_structures)

    final_feature = {
        "name": os.path.basename(pdb_file),
        "aa_seq": aa_seq,
        "ss8_seq": sec_structure_str_8,
        "ss3_seq": sec_structure_str_3,
    }
    return final_feature, None
51 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pdb_dir', type=str, help='pdb dir')
    parser.add_argument('--pdb_file', type=str, help='pdb file')

    # multi processing
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')

    # index pdb for large scale inference
    parser.add_argument("--pdb_index_file", default=None, type=str, help="pdb index file")
    parser.add_argument("--pdb_index_level", default=1, type=int, help="pdb index level")

    # save file
    parser.add_argument('--error_file', type=str, help='save error file')
    parser.add_argument('--out_file', type=str, required=True, help='save file')
    args = parser.parse_args()

    if args.pdb_dir is None and args.pdb_file is None:
        parser.error("one of --pdb_dir or --pdb_file is required")

    # dirname() is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(args.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if args.pdb_dir is not None:
        if args.pdb_index_file:
            # Indexed layout: <pdb_dir>/<pdb[:1]>/.../<pdb[:level]>/<pdb>.pdb
            with open(args.pdb_index_file) as index_f:
                pdbs = index_f.read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level + 1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb + ".pdb"))
        else:
            # Flat layout: every entry of pdb_dir.
            pdb_files = sorted(os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir))

        results, error_pdbs, error_messages = [], [], []
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            future_to_pdb = {executor.submit(generate_feature, pdb_file): pdb_file for pdb_file in pdb_files}
            with tqdm(total=len(pdb_files), desc="Processing pdb") as progress:
                for future in as_completed(future_to_pdb):
                    try:
                        result, message = future.result()
                    except Exception as exc:  # keep going if a worker raises unexpectedly
                        result, message = future_to_pdb[future], exc
                    if message is None:
                        results.append(result)
                    else:
                        error_pdbs.append(result)
                        error_messages.append(str(message))
                    progress.update(1)

        if error_pdbs:
            if args.error_file is None:
                # splitext only strips the final extension; split(".") broke on dotted dirs.
                args.error_file = os.path.splitext(args.out_file)[0] + "_error.csv"
            error_dir = os.path.dirname(args.error_file)
            if error_dir:
                os.makedirs(error_dir, exist_ok=True)
            error_info = {"error_pdbs": error_pdbs, "error_messages": error_messages}
            pd.DataFrame(error_info).to_csv(args.error_file, index=False)

        # One JSON object per line.
        with open(args.out_file, "w") as f:
            f.write("\n".join(json.dumps(r) for r in results))

    else:
        result, message = generate_feature(args.pdb_file)
        if message is not None:
            # Previously the pdb path itself was written as if it were a result.
            parser.exit(1, f"failed to process {args.pdb_file}: {message}\n")
        with open(args.out_file, "w") as f:
            json.dump(result, f)
117 |
--------------------------------------------------------------------------------
/src/data/processors/descriptor_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
class DescriptorFeatureProcessor:
    """Per-residue descriptor embeddings for amino-acid sequences.

    Two descriptor families are held as lists of lookup tables keyed by
    upper-case one-letter amino-acid code (the 20 standard residues):
    ``e_descriptors`` (five tables) and ``z_descriptors`` (three tables).
    Any other character (e.g. ``X``, gaps, lower-case letters) is embedded as
    an all-zero vector.
    """

    def __init__(self):
        self.e_descriptors = self._init_e_descriptors()
        self.z_descriptors = self._init_z_descriptors()

    def _init_e_descriptors(self):
        """Return the five E descriptor tables (residue -> value)."""
        e1 = {'A': 0.008, 'R': 0.171, 'N': 0.255, 'D': 0.303, 'C': -0.132, 'Q': 0.149, 'E': 0.221, 'G': 0.218,
              'H': 0.023, 'I': -0.353, 'L': -0.267, 'K': 0.243, 'M': -0.239, 'F': -0.329, 'P': 0.173, 'S': 0.199,
              'T': 0.068, 'W': -0.296, 'Y': -0.141, 'V': -0.274}
        e2 = {'A': 0.134, 'R': -0.361, 'N': 0.038, 'D': -0.057, 'C': 0.174, 'Q': -0.184, 'E': -0.28, 'G': 0.562,
              'H': -0.177, 'I': 0.071, 'L': 0.018, 'K': -0.339, 'M': -0.141, 'F': -0.023, 'P': 0.286, 'S': 0.238,
              'T': 0.147, 'W': -0.186, 'Y': -0.057, 'V': 0.136}
        e3 = {'A': -0.475, 'R': 0.107, 'N': 0.117, 'D': -0.014, 'C': 0.07, 'Q': -0.03, 'E': -0.315, 'G': -0.024,
              'H': 0.041, 'I': -0.088, 'L': -0.265, 'K': -0.044, 'M': -0.155, 'F': 0.072, 'P': 0.407, 'S': -0.015,
              'T': -0.015, 'W': 0.389, 'Y': 0.425, 'V': -0.187}
        e4 = {'A': -0.039, 'R': -0.258, 'N': 0.118, 'D': 0.225, 'C': 0.565, 'Q': 0.035, 'E': 0.157, 'G': 0.018,
              'H': 0.28, 'I': -0.195, 'L': -0.274, 'K': -0.325, 'M': 0.321, 'F': -0.002, 'P': -0.215, 'S': -0.068,
              'T': -0.132, 'W': 0.083, 'Y': -0.096, 'V': -0.196}
        e5 = {'A': 0.181, 'R': -0.364, 'N': -0.055, 'D': 0.156, 'C': -0.374, 'Q': -0.112, 'E': 0.303, 'G': 0.106,
              'H': -0.021, 'I': -0.107, 'L': 0.206, 'K': -0.027, 'M': 0.077, 'F': 0.208, 'P': 0.384, 'S': -0.196,
              'T': -0.274, 'W': 0.297, 'Y': -0.091, 'V': -0.299}
        return [e1, e2, e3, e4, e5]

    def _init_z_descriptors(self):
        """Return the three Z descriptor tables (residue -> value)."""
        z1 = {'A': 0.07, 'R': 2.88, 'N': 3.22, 'D': 3.64, 'C': 0.71, 'Q': 2.18, 'E': 3.08, 'G': 2.23, 'H': 2.41,
              'I': -4.44, 'L': -4.19, 'K': 2.84, 'M': -2.49, 'F': -4.92, 'P': -1.22, 'S': 1.96, 'T': 0.92, 'W': -4.75,
              'Y': -1.39, 'V': -2.69}
        z2 = {'A': -1.73, 'R': 2.52, 'N': 1.45, 'D': 1.13, 'C': -0.97, 'Q': 0.53, 'E': 0.39, 'G': -5.36, 'H': 1.74,
              'I': -1.68, 'L': -1.03, 'K': 1.41, 'M': -0.27, 'F': 1.30, 'P': 0.88, 'S': -1.63, 'T': -2.09, 'W': 3.65,
              'Y': 2.32, 'V': -2.53}
        z3 = {'A': 0.09, 'R': -3.44, 'N': 0.84, 'D': 2.36, 'C': 4.13, 'Q': -1.14, 'E': -0.07, 'G': 0.30, 'H': 1.11,
              'I': -1.03, 'L': -0.98, 'K': -3.14, 'M': -0.41, 'F': 0.45, 'P': 2.23, 'S': 0.57, 'T': -1.40, 'W': 0.85,
              'Y': 0.01, 'V': -1.29}
        return [z1, z2, z3]

    @staticmethod
    def _embed(descriptor_tables, seq):
        """Map each residue of ``seq`` to its values across ``descriptor_tables``.

        Residues missing from the first table get zeros of width
        ``len(descriptor_tables)``. Every returned row is a fresh list, so
        mutating one row never changes another.
        """
        lookup = {aa: [table[aa] for table in descriptor_tables] for aa in descriptor_tables[0]}
        zeros = [0.0] * len(descriptor_tables)
        return [list(lookup.get(aa, zeros)) for aa in seq]

    def e_descriptor_embedding(self, seq):
        """Return one 5-value E-descriptor row per residue of ``seq``."""
        return self._embed(self.e_descriptors, seq)

    def z_descriptor_embedding(self, seq):
        """Return one 3-value Z-descriptor row per residue of ``seq``."""
        return self._embed(self.z_descriptors, seq)
--------------------------------------------------------------------------------
/src/data/processors/foldseek.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | from Bio import SeqIO
4 |
class FoldseekProcessor:
    """Wrapper around the ``foldseek`` command-line tool."""

    @staticmethod
    def run_foldseek_commands(pdb_dir):
        """Run foldseek over ``pdb_dir`` and return ``{record_id: sequence}``.

        Builds a foldseek database, links its header file to the ``_ss``
        database, and converts the ``_ss`` database to FASTA, whose records are
        returned as a dict.

        All intermediate files live in a private temporary directory that is
        removed afterwards. The previous fixed ``temp`` / ``foldseek_seq.fasta``
        paths in the CWD collided between concurrent runs, and the cleanup step
        deleted any pre-existing ``temp`` directory.

        Raises:
            subprocess.CalledProcessError: if any foldseek command fails.
        """
        import tempfile

        with tempfile.TemporaryDirectory(prefix="foldseek_") as work_dir:
            db_path = os.path.join(work_dir, "db")
            fasta_file = os.path.join(work_dir, "foldseek_seq.fasta")

            subprocess.run(["foldseek", "createdb", pdb_dir, db_path], check=True)
            subprocess.run(["foldseek", "lndb", f"{db_path}_h", f"{db_path}_ss_h"], check=True)
            subprocess.run(["foldseek", "convert2fasta", f"{db_path}_ss", fasta_file], check=True)

            # Fully consume the parser before the directory is removed.
            return {record.id: str(record.seq) for record in SeqIO.parse(fasta_file, "fasta")}
--------------------------------------------------------------------------------
/src/data/processors/protein_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
class DescriptorFeatureProcessor:
    """Per-residue descriptor embeddings for amino-acid sequences.

    Two descriptor families are held as lists of lookup tables keyed by
    upper-case one-letter amino-acid code (the 20 standard residues):
    ``e_descriptors`` (five tables) and ``z_descriptors`` (three tables).
    Any other character (e.g. ``X``, gaps, lower-case letters) is embedded as
    an all-zero vector.
    """

    def __init__(self):
        self.e_descriptors = self._init_e_descriptors()
        self.z_descriptors = self._init_z_descriptors()

    def _init_e_descriptors(self):
        """Return the five E descriptor tables (residue -> value)."""
        e1 = {'A': 0.008, 'R': 0.171, 'N': 0.255, 'D': 0.303, 'C': -0.132, 'Q': 0.149, 'E': 0.221, 'G': 0.218,
              'H': 0.023, 'I': -0.353, 'L': -0.267, 'K': 0.243, 'M': -0.239, 'F': -0.329, 'P': 0.173, 'S': 0.199,
              'T': 0.068, 'W': -0.296, 'Y': -0.141, 'V': -0.274}
        e2 = {'A': 0.134, 'R': -0.361, 'N': 0.038, 'D': -0.057, 'C': 0.174, 'Q': -0.184, 'E': -0.28, 'G': 0.562,
              'H': -0.177, 'I': 0.071, 'L': 0.018, 'K': -0.339, 'M': -0.141, 'F': -0.023, 'P': 0.286, 'S': 0.238,
              'T': 0.147, 'W': -0.186, 'Y': -0.057, 'V': 0.136}
        e3 = {'A': -0.475, 'R': 0.107, 'N': 0.117, 'D': -0.014, 'C': 0.07, 'Q': -0.03, 'E': -0.315, 'G': -0.024,
              'H': 0.041, 'I': -0.088, 'L': -0.265, 'K': -0.044, 'M': -0.155, 'F': 0.072, 'P': 0.407, 'S': -0.015,
              'T': -0.015, 'W': 0.389, 'Y': 0.425, 'V': -0.187}
        e4 = {'A': -0.039, 'R': -0.258, 'N': 0.118, 'D': 0.225, 'C': 0.565, 'Q': 0.035, 'E': 0.157, 'G': 0.018,
              'H': 0.28, 'I': -0.195, 'L': -0.274, 'K': -0.325, 'M': 0.321, 'F': -0.002, 'P': -0.215, 'S': -0.068,
              'T': -0.132, 'W': 0.083, 'Y': -0.096, 'V': -0.196}
        e5 = {'A': 0.181, 'R': -0.364, 'N': -0.055, 'D': 0.156, 'C': -0.374, 'Q': -0.112, 'E': 0.303, 'G': 0.106,
              'H': -0.021, 'I': -0.107, 'L': 0.206, 'K': -0.027, 'M': 0.077, 'F': 0.208, 'P': 0.384, 'S': -0.196,
              'T': -0.274, 'W': 0.297, 'Y': -0.091, 'V': -0.299}
        return [e1, e2, e3, e4, e5]

    def _init_z_descriptors(self):
        """Return the three Z descriptor tables (residue -> value)."""
        z1 = {'A': 0.07, 'R': 2.88, 'N': 3.22, 'D': 3.64, 'C': 0.71, 'Q': 2.18, 'E': 3.08, 'G': 2.23, 'H': 2.41,
              'I': -4.44, 'L': -4.19, 'K': 2.84, 'M': -2.49, 'F': -4.92, 'P': -1.22, 'S': 1.96, 'T': 0.92, 'W': -4.75,
              'Y': -1.39, 'V': -2.69}
        z2 = {'A': -1.73, 'R': 2.52, 'N': 1.45, 'D': 1.13, 'C': -0.97, 'Q': 0.53, 'E': 0.39, 'G': -5.36, 'H': 1.74,
              'I': -1.68, 'L': -1.03, 'K': 1.41, 'M': -0.27, 'F': 1.30, 'P': 0.88, 'S': -1.63, 'T': -2.09, 'W': 3.65,
              'Y': 2.32, 'V': -2.53}
        z3 = {'A': 0.09, 'R': -3.44, 'N': 0.84, 'D': 2.36, 'C': 4.13, 'Q': -1.14, 'E': -0.07, 'G': 0.30, 'H': 1.11,
              'I': -1.03, 'L': -0.98, 'K': -3.14, 'M': -0.41, 'F': 0.45, 'P': 2.23, 'S': 0.57, 'T': -1.40, 'W': 0.85,
              'Y': 0.01, 'V': -1.29}
        return [z1, z2, z3]

    @staticmethod
    def _embed(descriptor_tables, seq):
        """Map each residue of ``seq`` to its values across ``descriptor_tables``.

        Residues missing from the first table get zeros of width
        ``len(descriptor_tables)``. Every returned row is a fresh list, so
        mutating one row never changes another.
        """
        lookup = {aa: [table[aa] for table in descriptor_tables] for aa in descriptor_tables[0]}
        zeros = [0.0] * len(descriptor_tables)
        return [list(lookup.get(aa, zeros)) for aa in seq]

    def e_descriptor_embedding(self, seq):
        """Return one 5-value E-descriptor row per residue of ``seq``."""
        return self._embed(self.e_descriptors, seq)

    def z_descriptor_embedding(self, seq):
        """Return one 3-value Z-descriptor row per residue of ``seq``."""
        return self._embed(self.z_descriptors, seq)
--------------------------------------------------------------------------------
/src/data/processors/structure_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from biotite.structure.io.pdb import PDBFile
4 | from src.esm.utils.structure.protein_chain import ProteinChain
5 | from src.esm.models.vqvae import StructureTokenEncoder
6 |
class StructureFeatureProcessor:
    """Turn a PDB chain into ESM3 structure tokens using the VQ-VAE encoder."""

    # Resolved relative to the current working directory unless absolute.
    DEFAULT_WEIGHTS_PATH = "src/data/weights/esm3_structure_encoder_v0.pth"

    def __init__(self, device="cpu", weights_path=DEFAULT_WEIGHTS_PATH):
        """
        Args:
            device: torch device the encoder and inputs are placed on.
            weights_path: path to the encoder ``state_dict`` checkpoint.
        """
        self.device = device
        self.weights_path = weights_path
        self.encoder = self._load_esm3_encoder()

    def _load_esm3_encoder(self):
        """Build the structure encoder, load its weights and put it in eval mode."""
        model = (
            StructureTokenEncoder(
                d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
            )
            .to(self.device)
            .eval()
        )
        # The checkpoint is a plain state_dict; weights_only=True refuses to
        # unpickle arbitrary objects from the file.
        state_dict = torch.load(
            self.weights_path, map_location=self.device, weights_only=True
        )
        model.load_state_dict(state_dict)
        return model

    def get_esm3_structure_seq(self, pdb_file):
        """Return ``(structure_tokens, sequence)`` for one chain of ``pdb_file``.

        The chain used is ``np.unique(chain_ids)[0]``, i.e. the smallest chain
        ID in sort order, which is not necessarily the first chain in the file.
        ``structure_tokens`` is a Python list taken from batch element 0 of the
        encoder output.
        """
        chain_ids = self._get_chain_ids(pdb_file)
        chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_ids[0])

        coords, _plddt, residue_index = chain.to_structure_encoder_inputs()
        coords = coords.to(self.device)
        residue_index = residue_index.to(self.device)

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            _, structure_tokens = self.encoder.encode(coords, residue_index=residue_index)
        return structure_tokens.cpu().numpy().tolist()[0], chain.sequence

    @staticmethod
    def _get_chain_ids(pdb_file):
        """Return the sorted unique chain IDs present in ``pdb_file``."""
        return np.unique(PDBFile.read(pdb_file).get_structure().chain_id)
--------------------------------------------------------------------------------
/src/esm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/esm/__init__.py
--------------------------------------------------------------------------------
/src/esm/layers/attention.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import einops
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from src.esm.layers.rotary import RotaryEmbedding
9 |
10 |
class MultiHeadAttention(nn.Module):
    """Self-attention with a fused LayerNorm+QKV projection, optional
    LayerNorm on queries/keys, and rotary position embeddings.

    Attention is restricted to token pairs that share the same ``seq_id``.
    Submodule attribute names are part of the checkpoint (state_dict) layout.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        bias: bool = False,
        qk_layernorm: bool = True,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads

        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        # Without qk_layernorm both normalizers are pass-throughs.
        self.q_ln = nn.LayerNorm(d_model, bias=bias) if qk_layernorm else nn.Identity()
        self.k_ln = nn.LayerNorm(d_model, bias=bias) if qk_layernorm else nn.Identity()
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
        # Rotary embedding works per head: (..., D) -> (..., H, Dh) -> rotate -> (..., D).
        head_shape = (self.n_heads, self.d_head)
        q_rot, k_rot = self.rotary(q.unflatten(-1, head_shape), k.unflatten(-1, head_shape))
        return q_rot.flatten(-2, -1), k_rot.flatten(-2, -1)

    def forward(self, x, seq_id):
        q, k, v = torch.chunk(self.layernorm_qkv(x), 3, dim=-1)
        q, k = self._apply_rotary(self.q_ln(q), self.k_ln(k))

        def split_heads(t):
            return einops.rearrange(t, "b s (h d) -> b h s d", h=self.n_heads)

        # Boolean mask, True where query and key share a seq_id (allowed to attend).
        same_sequence = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)

        context = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v), same_sequence
        )
        merged = einops.rearrange(context, "b h s d -> b s (h d)")
        return self.out_proj(merged)
71 |
--------------------------------------------------------------------------------
/src/esm/layers/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from src.esm.layers.attention import MultiHeadAttention
6 | from src.esm.layers.geom_attention import (
7 | GeometricReasoningOriginalImpl,
8 | )
9 | from src.esm.utils.structure.affine3d import Affine3D
10 |
11 |
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Return the SwiGLU hidden width: ``expansion_ratio * d_model`` rounded
    up to a multiple of 256.

    Computed as ``(x + 255) // 256 * 256``. For integral ``x`` this is the
    smallest multiple of 256 that is >= ``x``; a non-integral ``x`` has its
    fractional part effectively ignored (e.g. 256.5 -> 256).
    """
    multiple = 256
    expanded = expansion_ratio * d_model
    return int((expanded + multiple - 1) // multiple * multiple)
15 |
16 |
class SwiGLU(nn.Module):
    """SwiGLU gating as an ``nn.Module`` so it can sit inside ``nn.Sequential``.

    The last dimension is split into two halves ``(gate, value)`` with
    ``chunk``; the result is ``silu(gate) * value``. For an even-sized last
    dimension the output is therefore half as wide as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
30 |
31 |
def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    """LayerNorm -> Linear to 2*H -> SwiGLU -> Linear back to ``d_model``.

    ``H`` is ``swiglu_correction_fn(expansion_ratio, d_model)``. The first
    projection is twice as wide because SwiGLU halves the last dimension.
    """
    hidden = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, 2 * hidden, bias=bias),
        SwiGLU(),
        nn.Linear(hidden, d_model, bias=bias),
    )
41 |
42 |
def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    """LayerNorm -> Linear(d_model, H) -> GELU -> Linear(H, d_model).

    ``H`` is ``int(expansion_ratio * d_model)`` (truncated, no rounding).
    """
    width = int(expansion_ratio * d_model)
    layers = [
        nn.LayerNorm(d_model),
        nn.Linear(d_model, width, bias=bias),
        nn.GELU(),
        nn.Linear(width, d_model, bias=bias),
    ]
    return nn.Sequential(*layers)
51 |
52 |
53 | class UnifiedTransformerBlock(nn.Module):
54 | """
55 | A unified transformer block that can optionally incorporate geometric attention.
56 |
57 | This class defines a transformer block that can be configured to use geometric attention
58 | alongside the standard multi-head attention mechanism. It is designed to be a flexible
59 | component of transformer-based models, allowing for the integration of geometric reasoning.
60 |
61 | Parameters
62 | ----------
63 | d_model : int
64 | The dimensionality of the input and output features of the transformer block.
65 | n_heads : int
66 | The number of attention heads in the multi-head attention mechanism.
67 | n_layers : int
68 | The number of layers in the transformer block.
69 | use_geom_attn : bool, optional
70 | Whether to use geometric attention in addition to the standard multi-head attention. Defaults to False.
71 | v_heads : int, optional
72 | The number of heads to use for the geometric attention mechanism, if enabled. Must be specified if `use_geom_attn` is True.
73 | """
74 |
75 | def __init__(
76 | self,
77 | d_model: int,
78 | n_heads: int,
79 | use_geom_attn: bool = False,
80 | use_plain_attn: bool = True,
81 | v_heads: int | None = None,
82 | bias: bool = False,
83 | expansion_ratio: float = 4.0,
84 | residue_scaling_factor: float = 1,
85 | mask_and_zero_frameless: bool = False,
86 | qk_layernorm: bool = True,
87 | ffn_type: str = "swiglu", # swiglu | gelu
88 | ):
89 | super().__init__()
90 | self.use_plain_attn = use_plain_attn
91 | if self.use_plain_attn:
92 | self.attn = MultiHeadAttention(
93 | d_model, n_heads, bias, qk_layernorm=qk_layernorm
94 | )
95 | self.use_geom_attn = use_geom_attn
96 | if self.use_geom_attn:
97 | if v_heads is None:
98 | raise ValueError("v_heads must be specified when use_geom_attn is True")
99 | self.geom_attn = GeometricReasoningOriginalImpl(
100 | c_s=d_model,
101 | v_heads=v_heads,
102 | bias=bias,
103 | mask_and_zero_frameless=mask_and_zero_frameless,
104 | )
105 | if ffn_type == "swiglu":
106 | self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
107 | elif ffn_type == "gelu":
108 | self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
109 | else:
110 | raise ValueError(f"Unknown ffn_type: {ffn_type}")
111 | self.scaling_factor = residue_scaling_factor
112 |
113 | def forward(
114 | self,
115 | x: torch.Tensor,
116 | sequence_id: torch.Tensor,
117 | frames: Affine3D,
118 | frames_mask: torch.Tensor,
119 | chain_id: torch.Tensor,
120 | ) -> torch.Tensor:
121 | """
122 | Forward pass for the UnifiedTransformerBlock.
123 |
124 | Parameters
125 | ----------
126 | x : torch.Tensor[float]
127 | Input tensor to the transformer block, typically the output from the previous layer.
128 | sequence_id : torch.Tensor[int]
129 | Tensor containing sequence IDs for each element in the batch, used for attention masking.
130 | frames : Affine3D
131 | Affine3D containing geometric frame information for geometric attention.
132 | frames_mask : torch.Tensor[bool]
133 | Boolean mask tensor indicating valid frames for geometric attention.
134 | chain_id : torch.Tensor[int]
135 | Tensor containing chain IDs for each element, used for attention masking in geometric attention.
136 |
137 | Returns
138 | -------
139 | torch.Tensor[float]
140 | The output tensor after applying the transformer block operations.
141 | """
142 | if self.use_plain_attn:
143 | r1 = self.attn(x, sequence_id)
144 | x = x + r1 / self.scaling_factor
145 |
146 | if self.use_geom_attn:
147 | r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
148 | x = x + r2 / self.scaling_factor
149 |
150 | r3 = self.ffn(x) / self.scaling_factor
151 | x = x + r3
152 |
153 | return x
154 |
--------------------------------------------------------------------------------
/src/esm/layers/codebook.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.distributed as dist
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
class EMACodebook(nn.Module):
    """Vector-quantisation codebook with nearest-neighbour code assignment.

    Only the lookup / inference path is implemented. Calling ``forward`` in
    training mode with an unfrozen codebook raises ``NotImplementedError``
    because the EMA update is missing. (Previously this used ``assert False``,
    which ``python -O`` strips, so training silently skipped the update.)
    """

    def __init__(
        self,
        n_codes,
        embedding_dim,
        no_random_restart=True,
        restart_thres=1.0,
        ema_decay=0.99,
    ):
        super().__init__()
        # (n_codes, embedding_dim) code vectors plus EMA statistics buffers.
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres
        self.freeze_codebook = False
        self.ema_decay = ema_decay

    def reset_parameters(self):
        # For meta init
        pass

    def _tile(self, x):
        """Repeat rows of ``x`` (adding small noise) until there are >= n_codes rows."""
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        """Initialise codes from a random subset of the (tiled) inputs ``z``."""
        self._need_init = False
        flat_inputs = z.reshape(-1, self.embedding_dim)
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        # Keep all ranks on identical codes when running distributed.
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        """Quantise ``z`` of shape ``[b, t, embedding_dim]``.

        Returns:
            ``(embeddings_st, encoding_indices, commitment_loss)``:
            straight-through quantised vectors ``[b, t, embedding_dim]``,
            code indices ``[b, t]``, and ``0.25 * mse(z, codes)``.

        Raises:
            NotImplementedError: in training mode with an unfrozen codebook.
        """
        if self._need_init and self.training and not self.freeze_codebook:
            self._init_embeddings(z)
        # reshape (not view) so non-contiguous inputs are accepted.
        flat_inputs = z.reshape(-1, self.embedding_dim)
        # Squared Euclidean distance, expanded: |x|^2 - 2 x.e + |e|^2.
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )  # [b*t, n_codes]

        encoding_indices = torch.argmin(distances, dim=1)
        encoding_indices = encoding_indices.view(*z.shape[:2])  # [b, t]

        embeddings = F.embedding(encoding_indices, self.embeddings)  # [b, t, c]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
        if self.training and not self.freeze_codebook:
            raise NotImplementedError("EMA codebook update is not implemented")
        # Straight-through estimator: forward value = codes, gradient flows to z.
        embeddings_st = (embeddings - z).detach() + z

        return embeddings_st, encoding_indices, commitment_loss

    def dictionary_lookup(self, encodings):
        """Return the code vectors for integer indices ``encodings``."""
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

    def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor:
        """Return ``weights @ embeddings`` (weighted mixture of code vectors)."""
        return weights @ self.embeddings
89 |
--------------------------------------------------------------------------------
/src/esm/layers/ffn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torch import Tensor
4 |
5 | # NOT CURRENTLY USED
6 |
7 |
class SwiGLU(nn.Module):
    """Gated activation: split the last dim into halves (a, b) and return silu(a) * b."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        first_half, second_half = x.chunk(2, dim=-1)
        return F.silu(first_half).mul(second_half)
16 |
17 |
class FFN(nn.Module):
    """Feed-forward composition ``out_proj(activation(in_proj(x)))`` of three given modules."""

    def __init__(self, in_proj, activation, out_proj) -> None:
        super().__init__()
        # Registered in this order: in_proj, activation, out_proj.
        self.in_proj = in_proj
        self.activation = activation
        self.out_proj = out_proj

    def forward(self, x: Tensor) -> Tensor:
        return self.out_proj(self.activation(self.in_proj(x)))
30 |
--------------------------------------------------------------------------------
/src/esm/layers/geom_attention.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | import torch
4 | from einops import rearrange
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 |
class GeometricReasoningOriginalImpl(nn.Module):
    """Frame-aware (geometric) attention over per-residue 3-D vectors.

    Queries, keys and values are 3-vectors projected from ``s`` and moved into
    a common frame with ``affine``. Attention logits combine a direction
    (dot-product) term and a distance term with learned per-head weights. The
    aggregated vectors are rotated back with the inverse rotation and
    projected to ``c_s``.
    """

    def __init__(
        self,
        c_s: int,
        v_heads: int,
        num_vector_messages: int = 1,
        mask_and_zero_frameless: bool = True,
        divide_residual_by_depth: bool = False,
        bias: bool = False,
    ):
        """Approximate implementation:

        ATTN(A, v) := (softmax_j A_ij) v_j
        make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
        make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)

        v <- make_rot_vectors(x)
        q_dir, k_dir <- make_rot_vectors(x)
        q_dist, k_dist <- make_vectors(x)

        A_ij <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2
        x <- x + Linear(T(g->i) ATTN(A, v))

        NOTE(review): forward() actually uses the unsquared distance
        ||q_dist_i - k_dist_j|| / sqrt(3), with each term weighted by a
        softplus of a learned per-head scale; the formula above is a sketch.

        Args:
            c_s: width of the per-residue features ``s``.
            v_heads: number of geometric attention heads.
            num_vector_messages: value 3-vectors per head.
            mask_and_zero_frameless: zero the output of residues whose
                ``affine_mask`` is False.
            divide_residual_by_depth: accepted but not used by this class.
            bias: bias terms in ``s_norm``, ``proj`` and ``out_proj``.
        """
        super().__init__()
        self.c_s = c_s
        self.v_heads = v_heads
        self.num_vector_messages = num_vector_messages
        self.mask_and_zero_frameless = mask_and_zero_frameless

        self.s_norm = nn.LayerNorm(c_s, bias=bias)
        dim_proj = (
            4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
        )  # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
        self.proj = nn.Linear(c_s, dim_proj, bias=bias)
        channels_out = self.v_heads * 3 * self.num_vector_messages
        self.out_proj = nn.Linear(channels_out, c_s, bias=bias)

        # The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
        # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
        # very nearby or should there be shallower dropoff in attention weight?)
        # Both start at 0 (softplus(0) = ln 2) and are learned.
        self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
        self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))

    def forward(self, s, affine, affine_mask, sequence_id, chain_id):
        # Additive attention bias, shape (b, 1, s, s).
        # NOTE(review): the sequence_id term is 1.0 for same-sequence pairs and
        # 0.0 otherwise, so it acts as a soft +1 bias rather than a hard mask.
        # Only keys with affine_mask == False and cross-chain pairs receive
        # finfo.min below. Confirm this is intended before relying on bin packing.
        attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
        attn_bias = attn_bias.unsqueeze(1).float()
        # Block attention to keys whose frame is invalid.
        attn_bias = attn_bias.masked_fill(
            ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
        )
        # Block attention between residues on different chains.
        chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
        attn_bias = attn_bias.masked_fill(
            chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
        )

        ns = self.s_norm(s)
        # vec_rot: q_rot and k_rot (v_heads * 3 each) followed by the values
        # (v_heads * num_vector_messages * 3); vec_dist: q_dist and k_dist.
        vec_rot, vec_dist = self.proj(ns).split(
            [
                self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
                self.v_heads * 2 * 3,
            ],
            dim=-1,
        )

        # Rotate the queries and keys for the rotation term. We also rotate the values.
        # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
        # this in the future.
        query_rot, key_rot, value = (
            affine.rot[..., None]
            .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
            .split(
                [
                    self.v_heads,
                    self.v_heads,
                    self.v_heads * self.num_vector_messages,
                ],
                dim=-2,
            )
        )

        # Rotate and translate the queries and keys for the distance term
        # NOTE(thayes): a simple speedup would be to apply all rotations together, then
        # separately apply the translations.
        query_dist, key_dist = (
            affine[..., None]
            .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
            .chunk(2, dim=-2)
        )

        # Broadcast layouts: query_dist (b, h, s, 1, 3) vs key_dist
        # (b, h, 1, s, 3) give all query/key pairs; key_rot is pre-transposed
        # so query_rot @ key_rot is (b, h, s, s).
        query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
        key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
        query_rot = rearrange(query_rot, "b s h d -> b h s d")
        key_rot = rearrange(key_rot, "b s h d -> b h d s")
        value = rearrange(
            value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
        )

        # Both terms are scaled by 1/sqrt(3) (the vectors are 3-D).
        distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
        rotation_term = query_rot.matmul(key_rot) / sqrt(3)
        # softplus keeps the per-head weights positive; (h,) -> (h, 1, 1)
        # broadcasts over (b, h, s, s).
        distance_term_weight = rearrange(
            F.softplus(self.distance_scale_per_head), "h -> h 1 1"
        )
        rotation_term_weight = rearrange(
            F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
        )

        # Aligned directions raise the logit; larger distances lower it.
        attn_weight = (
            rotation_term * rotation_term_weight - distance_term * distance_term_weight
        )

        if attn_bias is not None:
            # we can re-use the attention bias from the transformer layers
            # NOTE(thayes): This attention bias is expected to handle two things:
            # 1. Masking attention on padding tokens
            # 2. Masking cross sequence attention in the case of bin packing
            # Keep only the trailing (s_q, s_k) block of the bias if it is
            # larger than the logits; a no-op when the sizes match.
            s_q = attn_weight.size(2)
            s_k = attn_weight.size(3)
            _s_q = max(0, attn_bias.size(2) - s_q)
            _s_k = max(0, attn_bias.size(3) - s_k)
            attn_bias = attn_bias[:, :, _s_q:, _s_k:]
            attn_weight = attn_weight + attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)

        attn_out = attn_weight.matmul(value)

        # Rotate the aggregated vectors back with the inverse of each
        # residue's rotation (T(g->i) in the docstring).
        attn_out = (
            affine.rot[..., None]
            .invert()
            .apply(
                rearrange(
                    attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
                )
            )
        )

        attn_out = rearrange(
            attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
        )
        if self.mask_and_zero_frameless:
            # Zero the output of residues without a valid frame.
            attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
        s = self.out_proj(attn_out)

        return s
152 |
--------------------------------------------------------------------------------
/src/esm/layers/regression_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def RegressionHead(
5 | d_model: int,
6 | output_dim: int,
7 | hidden_dim: int | None = None,
8 | ) -> nn.Module:
9 | """Single-hidden layer MLP for supervised output.
10 |
11 | Args:
12 | d_model: input dimension
13 | output_dim: dimensionality of the output.
14 | hidden_dim: optional dimension of hidden layer, defaults to d_model.
15 | Returns:
16 | output MLP module.
17 | """
18 | hidden_dim = hidden_dim if hidden_dim is not None else d_model
19 | return nn.Sequential(
20 | nn.Linear(d_model, hidden_dim),
21 | nn.GELU(),
22 | nn.LayerNorm(hidden_dim),
23 | nn.Linear(hidden_dim, output_dim),
24 | )
25 |
--------------------------------------------------------------------------------
/src/esm/layers/rotary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4 | # and OPT implementations in this library. It has been modified from its
5 | # original forms to accommodate minor architectural differences compared
6 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | # NOTE: this implementation is from LLaMA 2:
20 | # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
21 | # Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
22 |
23 | from typing import Tuple
24 |
25 | import torch
26 | from einops import rearrange, repeat
27 |
28 |
def rotate_half(x, interleaved=False):
    """Map each channel pair (a, b) of ``x`` to (-b, a).

    Non-interleaved (GPT-NeoX style): channel i pairs with channel i + D/2.
    Interleaved (GPT-J style): channels 2i and 2i + 1 form a pair.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-odd, even), dim=-1), "... d two -> ... (d two)", two=2
        )
    front, back = x.chunk(2, dim=-1)
    return torch.cat((-back, front), dim=-1)
38 |
39 |
def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
    """
    Rotate the first ``2 * cos.shape[-1]`` channels of ``x``; pass the rest through.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)

    ``cos``/``sin`` are truncated to ``x.size(1)`` positions and tiled to the
    full rotary width. ``_inplace`` is accepted but unused; a new tensor is
    always returned.
    """
    rotary_width = 2 * cos.shape[-1]
    assert rotary_width <= x.shape[-1]
    n_positions = x.size(1)

    # (s, d) -> (s, 1, 2d): broadcast over heads, same values for both halves.
    cos_tiled = repeat(cos[:n_positions], "s d -> s 1 (2 d)")
    sin_tiled = repeat(sin[:n_positions], "s d -> s 1 (2 d)")

    rotated_part = x[..., :rotary_width]
    passthrough = x[..., rotary_width:]
    rotated = rotated_part * cos_tiled + rotate_half(rotated_part, interleaved) * sin_tiled
    return torch.cat([rotated, passthrough], dim=-1)
59 |
60 |
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    If scale_base is not None, ``_update_cos_sin_cache`` builds XPos tables
    (Sun et al., https://arxiv.org/abs/2212.10554), but ``forward`` does not apply
    them and raises NotImplementedError.
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        scaling_factor=1.0,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: number of rotary features; one inverse frequency is created for every even
            index in [0, dim).
        base: base of the geometric frequency progression.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: XPos scale base; None disables XPos. Non-None values are not
            supported by ``forward``.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        scaling_factor: RotaryEmbedding extended with linear scaling (positions are
            divided by this factor).
        device: device for the frequency buffers.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        # Lazily-built cos/sin tables; see _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        # Generate and save the inverse frequency buffer (non trainable).
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)create the ``inv_freq`` buffer and the optional XPos ``scale`` buffer."""
        inv_freq = self._compute_inv_freq(self.device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device=None):
        """Return ``1 / base ** (2i / dim)`` for i in [0, dim/2), in fp32."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cos/sin tables unless the cached ones already serve this request.

        The tables have shape (seqlen, len(inv_freq)) and are cast to ``dtype``. When
        ``self.scale`` is set, XPos-scaled query tables (``_cos_cached``/``_sin_cached``)
        and key tables (``_cos_k_cached``/``_sin_k_cached``) are built.
        """
        # Rebuild the tables if the requested length exceeds the cached length,
        # if we're on a new device (possibly due to tracing for instance),
        # if the dtype changed, or if we're switching from inference mode to training.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                cos = torch.cos(freqs)
                sin = torch.sin(freqs)
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (cos * scale).to(dtype)
                self._sin_cached = (sin * scale).to(dtype)
                self._cos_k_cached = (cos / scale).to(dtype)
                self._sin_k_cached = (sin / scale).to(dtype)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch; the tables are indexed from this position.

        Returns:
            (q, k) with rotary embeddings applied.

        Raises:
            NotImplementedError: if XPos scaling is configured (``scale_base`` is not None).
        """
        if self.scale is not None:
            # Raise instead of asserting so the check survives ``python -O``; otherwise
            # forward would fall through and return None.
            raise NotImplementedError(
                "RotaryEmbedding.forward does not support XPos (scale_base is not None)."
            )
        self._update_cos_sin_cache(
            q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype
        )
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        cos = self._cos_cached[seqlen_offset:]
        sin = self._sin_cached[seqlen_offset:]
        # The trailing True is the `_inplace` flag; apply_rotary_emb_torch ignores it and
        # always returns new tensors.
        q_rot = apply_rotary_emb_torch(q, cos, sin, self.interleaved, True)
        k_rot = apply_rotary_emb_torch(k, cos, sin, self.interleaved, True)
        return q_rot, k_rot  # type: ignore
222 |
--------------------------------------------------------------------------------
/src/esm/layers/structure_proj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from src.esm.utils.constants.physics import (
5 | BB_COORDINATES,
6 | )
7 | from src.esm.utils.structure.affine3d import (
8 | Affine3D,
9 | RotationMatrix,
10 | )
11 |
12 |
class Dim6RotStructureHead(nn.Module):
    """Predict per-residue rigid-frame updates and backbone atom positions.

    AF2 normally parameterizes rotations with quaternions. Zhou et al. (CVPR 2019, "On the
    Continuity of Rotation Representations in Neural Networks") found other
    representations better behaved; the best one they report is Gram-Schmidt
    orthonormalization of two predicted vectors, which is what this head implements:
    https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf

    NOTE(review): ``norm_type`` and ``activation_fn`` do not change the architecture
    (LayerNorm and GELU are always used). ``predict_torsion_angles`` is stored but never
    read, and the 7 * 2 angle outputs are discarded in ``forward``.
    """

    def __init__(
        self,
        input_dim: int,
        trans_scale_factor: float = 10,
        norm_type: str = "layernorm",
        activation_fn: str = "esm_gelu",
        predict_torsion_angles: bool = True,
    ):
        super().__init__()
        # Submodules are registered in this order so parameter init order is unchanged.
        self.ffn1 = nn.Linear(input_dim, input_dim)
        self.activation_fn = nn.GELU()
        self.norm = nn.LayerNorm(input_dim)
        # 3 (translation) + 3 + 3 (two rotation vectors) + 7 * 2 (angle outputs).
        self.proj = nn.Linear(input_dim, 9 + 7 * 2)
        self.trans_scale_factor = trans_scale_factor
        self.predict_torsion_angles = predict_torsion_angles
        # Plain attribute (not a buffer); moved to the input's device on every forward.
        self.bb_local_coords = torch.tensor(BB_COORDINATES).float()

    def forward(self, x, affine, affine_mask, **kwargs):
        """Compose the predicted frame update onto ``affine`` and place backbone atoms.

        Args:
            x: per-residue features; the ``[None, None]`` indexing below assumes two
                leading dims, presumably (batch, length, input_dim) — TODO confirm.
            affine: current frames, or None to start from identity frames.
            affine_mask: passed to ``update.mask`` before composing.
            **kwargs: accepted and ignored.

        Returns:
            ``(rigids.tensor, pred_xyz)``: the composed frames as a tensor, and the local
            backbone coordinates mapped through the composed frames.
        """
        rigids = affine
        if rigids is None:
            rigids = Affine3D.identity(
                x.shape[:-1],
                dtype=x.dtype,
                device=x.device,
                requires_grad=self.training,
                rotation_type=RotationMatrix,
            )

        hidden = self.norm(self.activation_fn(self.ffn1(x)))
        trans, axis_a, axis_b, _angles = self.proj(hidden).split(
            [3, 3, 3, 7 * 2], dim=-1
        )
        trans = trans * self.trans_scale_factor
        # Normalize both direction vectors (epsilon guards against zero norm).
        axis_a = axis_a / (axis_a.norm(dim=-1, keepdim=True) + 1e-5)
        axis_b = axis_b / (axis_b.norm(dim=-1, keepdim=True) + 1e-5)
        update = Affine3D.from_graham_schmidt(axis_a + trans, trans, axis_b + trans)
        rigids = rigids.compose(update.mask(affine_mask))
        affine = rigids.tensor

        # Map every local backbone atom coordinate (not their mean) into the global frame
        # by applying the composed rigid transform.
        leading_shape = axis_a.shape[:-1]
        local_coords = (
            self.bb_local_coords[None, None, :, :]
            .expand(*leading_shape, 3, 3)
            .to(axis_a.device)
        )
        pred_xyz = rigids[..., None].apply(local_coords)

        return affine, pred_xyz
69 |
--------------------------------------------------------------------------------
/src/esm/layers/transformer_stack.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from src.esm.layers.blocks import UnifiedTransformerBlock
7 | from src.esm.utils.structure.affine3d import Affine3D
8 |
9 |
class TransformerStack(nn.Module):
    """
    A stack of UnifiedTransformerBlock layers used in the ESM-3 model, followed by a
    bias-free LayerNorm. Each block uses either geometric attention or standard
    multi-head attention.

    Args:
        d_model (int): Dimensionality of the input and output feature vectors.
        n_heads (int): Number of attention heads.
        v_heads (int | None): The number of voting heads (forwarded to every block).
        n_layers (int): Number of transformer blocks in the stack.
        n_layers_geom (int, optional): The first ``n_layers_geom`` blocks use geometric
            attention.
        scale_residue (bool, optional): If True, every block gets
            ``residue_scaling_factor = sqrt(n_layers / 36)``; otherwise 1.0.
        mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless
            positions in the input. Only applies in the geometric attention blocks, which
            are conditioned on the structure.
        bias (bool, optional): Forwarded to every block.
        qk_layernorm (bool, optional): Forwarded to every block.
        ffn_type (str, optional): "swiglu" or "gelu"; forwarded to every block.
        expansion_ratio (float, optional): Forwarded to every block.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int | None,
        n_layers: int,
        n_layers_geom: int = 1,
        scale_residue: bool = True,
        mask_and_zero_frameless: bool = False,
        bias: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
        expansion_ratio: float = 8 / 3,
    ):
        super().__init__()
        layers = []
        for layer_idx in range(n_layers):
            residue_scale = math.sqrt(n_layers / 36) if scale_residue else 1.0
            layers.append(
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    v_heads=v_heads,
                    use_geom_attn=layer_idx < n_layers_geom,
                    residue_scaling_factor=residue_scale,
                    expansion_ratio=expansion_ratio,
                    mask_and_zero_frameless=mask_and_zero_frameless,
                    bias=bias,
                    qk_layernorm=qk_layernorm,
                    ffn_type=ffn_type,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor | None = None,
        affine: Affine3D | None = None,
        affine_mask: torch.Tensor | None = None,
        chain_id: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run every block in order, then apply the final LayerNorm.

        Args:
            x (torch.Tensor): Input of shape (batch_size, sequence_length, d_model).
            sequence_id (torch.Tensor | None): Shape (batch_size, sequence_length);
                defaults to all ones (int64).
            affine (Affine3D | None): Frames passed unchanged to every block.
            affine_mask (torch.Tensor | None): Mask passed unchanged to every block.
            chain_id (torch.Tensor | None): Shape (batch_size, sequence_length); defaults
                to all ones (int64). Only used in geometric attention.

        Returns:
            post_norm: The normalized output, shape (batch_size, sequence_length, d_model).
            pre_norm: The last block's output before the final norm, same shape.
        """
        *batch_dims, _ = x.shape
        if sequence_id is None:
            sequence_id = torch.ones(
                size=batch_dims, dtype=torch.int64, device=x.device
            )
        if chain_id is None:
            chain_id = torch.ones(
                size=batch_dims, dtype=torch.int64, device=x.device
            )
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, sequence_id, affine, affine_mask, chain_id)
        return self.norm(hidden), hidden
95 |
--------------------------------------------------------------------------------
/src/esm/pretrained.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from esm.models.esm3 import ESM3
7 | from esm.models.function_decoder import FunctionTokenDecoder
8 | from esm.models.vqvae import (
9 | StructureTokenDecoder,
10 | StructureTokenEncoder,
11 | )
12 | from esm.utils.constants.esm3 import data_root
13 | from esm.utils.constants.models import (
14 | ESM3_FUNCTION_DECODER_V0,
15 | ESM3_OPEN_SMALL,
16 | ESM3_STRUCTURE_DECODER_V0,
17 | ESM3_STRUCTURE_ENCODER_V0,
18 | )
19 |
# Signature of the builder functions stored in LOCAL_MODEL_REGISTRY: each takes a
# device (or device string) and returns the constructed nn.Module.
ModelBuilder = Callable[[torch.device | str], nn.Module]
21 |
22 |
def _load_pretrained_weights(
    model: nn.Module, weights_file: str, device: torch.device | str
) -> nn.Module:
    """Load ``data_root() / weights_file`` into ``model`` (strict) and return the model."""
    checkpoint_path = data_root() / weights_file
    # NOTE(review): torch.load unpickles the file; if these checkpoints are plain state
    # dicts, consider passing weights_only=True.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model


def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
    """Build the open small ESM3 model in eval mode on ``device`` with its weights loaded."""
    esm3 = ESM3(
        d_model=1536,
        n_heads=24,
        v_heads=256,
        n_layers=48,
        structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0,
        structure_decoder_name=ESM3_STRUCTURE_DECODER_V0,
        function_decoder_name=ESM3_FUNCTION_DECODER_V0,
    )
    esm3 = esm3.to(device).eval()
    # The v0 builder reads the "v1" checkpoint file name.
    return _load_pretrained_weights(esm3, "data/weights/esm3_sm_open_v1.pth", device)


def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
    """Build the ESM3 structure token encoder (eval mode) with its weights loaded."""
    encoder = StructureTokenEncoder(
        d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
    )
    encoder = encoder.to(device).eval()
    return _load_pretrained_weights(
        encoder, "data/weights/esm3_structure_encoder_v0.pth", device
    )


def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
    """Build the ESM3 structure token decoder (eval mode) with its weights loaded."""
    decoder = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30)
    decoder = decoder.to(device).eval()
    return _load_pretrained_weights(
        decoder, "data/weights/esm3_structure_decoder_v0.pth", device
    )


def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
    """Build the ESM3 function token decoder (eval mode) with its weights loaded."""
    decoder = FunctionTokenDecoder().to(device).eval()
    return _load_pretrained_weights(
        decoder, "data/weights/esm3_function_decoder_v0.pth", device
    )
77 |
78 |
# Name -> builder lookup table consulted by load_local_model.
LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
    ESM3_OPEN_SMALL: ESM3_sm_open_v0,
    ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0,
    ESM3_STRUCTURE_DECODER_V0: ESM3_structure_decoder_v0,
    ESM3_FUNCTION_DECODER_V0: ESM3_function_decoder_v0,
}


def load_local_model(model_name: str, device: torch.device | str = "cpu") -> nn.Module:
    """Build the model registered under ``model_name`` on ``device``.

    Raises:
        ValueError: if no builder is registered under ``model_name``.
    """
    try:
        builder = LOCAL_MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"Model {model_name} not found in local model registry."
        ) from None
    return builder(device)


# Register custom versions of ESM3 for use with the local inference API.
def register_local_model(model_name: str, model_builder: ModelBuilder) -> None:
    """Add ``model_builder`` under ``model_name``, replacing any existing entry."""
    LOCAL_MODEL_REGISTRY[model_name] = model_builder
96 |
--------------------------------------------------------------------------------
/src/esm/tokenization/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Protocol
3 |
4 | from src.esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
5 | from src.esm.utils.constants.models import ESM3_OPEN_SMALL
6 |
7 | from .function_tokenizer import InterProQuantizedTokenizer
8 | from .residue_tokenizer import ResidueAnnotationsTokenizer
9 | from .sasa_tokenizer import SASADiscretizingTokenizer
10 | from .sequence_tokenizer import EsmSequenceTokenizer
11 | from .ss_tokenizer import SecondaryStructureTokenizer
12 | from .structure_tokenizer import StructureTokenizer
13 | from .tokenizer_base import EsmTokenizerBase
14 |
15 |
class TokenizerCollectionProtocol(Protocol):
    """Structural type for any object exposing one tokenizer per input kind.

    Matches the attribute set of ``TokenizerCollection``.
    """

    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
23 |
24 |
@dataclass
class TokenizerCollection:
    """Concrete bundle of tokenizers, as returned by ``get_model_tokenizers``."""

    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
33 |
34 |
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
    """Return freshly constructed tokenizers for ``model``.

    Only ``ESM3_OPEN_SMALL`` is supported; its secondary-structure tokenizer uses the
    8-class ("ss8") vocabulary.

    Raises:
        ValueError: for any other model name.
    """
    if model != ESM3_OPEN_SMALL:
        raise ValueError(f"Unknown model: {model}")
    return TokenizerCollection(
        sequence=EsmSequenceTokenizer(),
        structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS),
        secondary_structure=SecondaryStructureTokenizer(kind="ss8"),
        sasa=SASADiscretizingTokenizer(),
        function=InterProQuantizedTokenizer(),
        residue_annotations=ResidueAnnotationsTokenizer(),
    )
47 |
48 |
def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
    """Return the ids of the mask, pad, start and end tokens of ``tokenizer``.

    The sequence tokenizer reports its CLS id as the start token; every other tokenizer
    reports its BOS id.
    """
    start_attr = (
        "cls_token_id"
        if isinstance(tokenizer, EsmSequenceTokenizer)
        else "bos_token_id"
    )
    return [
        tokenizer.mask_token_id,  # type: ignore
        tokenizer.pad_token_id,  # type: ignore
        getattr(tokenizer, start_attr),
        tokenizer.eos_token_id,  # type: ignore
    ]
64 |
--------------------------------------------------------------------------------
/src/esm/tokenization/residue_tokenizer.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from pathlib import Path
3 | from typing import Any
4 |
5 | import pandas as pd
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase
10 | from src.esm.utils.constants import esm3 as C
11 |
# A raw annotation record with string keys; ResidueAnnotationsTokenizer.tokenize reads
# its "interpro_site_*" entries.
Sample = dict[str, Any]
13 |
14 |
class ResidueAnnotationsTokenizer(EsmTokenizerBase):
    """Tokenizer for per-residue InterPro site annotations.

    Labels come from the CSV at ``csv_path`` (columns ``label``, ``label_clean``,
    ``count``). Labels are ordered by descending total count, and their ids start after
    the special tokens plus one extra vocabulary entry. ``encode`` produces a
    (len(tokens), max_annotations) id tensor.

    NOTE(review): in this copy the token string literals (special tokens, the
    ``tokenize`` fallbacks, the ``f""`` annotation tokens and the ``startswith("")`` prefix)
    are all empty strings. They look like markup-stripped ``<...>`` tokens. As written,
    distinct tokens collide in ``vocab_to_index``, and ``_token2ids`` always takes its
    first branch. Verify these literals against the original source.
    """

    def __init__(
        self,
        csv_path: str | None = None,
        max_annotations: int = 16,
    ):
        if csv_path is None:
            csv_path = str(C.data_root() / C.RESID_CSV)
        self.csv_path = csv_path
        # Maximum number of annotation ids kept per residue in `encode`.
        self.max_annotations = max_annotations

    @cached_property
    def _description2label(self) -> dict[str, str]:
        """Map raw InterPro descriptions (``label``) to cleaned labels (``label_clean``)."""
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        return dict(zip(df.label, df.label_clean))

    @cached_property
    def _labels(self) -> list[str]:
        """Cleaned labels sorted by total ``count``, descending (stable on ties)."""
        with Path(self.csv_path).open() as f:  # type: ignore
            df = pd.read_csv(f)
        labels = (
            df.groupby("label_clean")["count"]
            .sum()
            .sort_values(ascending=False, kind="stable")  # type: ignore
            .index.tolist()
        )
        assert isinstance(labels, list)
        return labels  # type: ignore

    def _description2id(self, description: str) -> int | None:
        """Return the token id for a raw description, or None if it is unknown."""
        label = self._description2label.get(description)
        return self._label2id.get(label)  # type: ignore

    @cached_property
    def _label2id(self) -> dict[str, int]:
        # Label ids start after the special tokens plus the one extra vocab entry that
        # `vocab` inserts before the annotation tokens.
        offset = len(self.special_tokens) + 1  # +1 for ""
        return {label: offset + i for i, label in enumerate(self._labels)}

    @cached_property
    def special_tokens(self) -> list[str]:
        """List of special tokens which come before cluster tokens in vocab."""
        return ["", "", ""]

    @cached_property
    def vocab(self):
        annotation_tokens = [f"" for _, id in self._label2id.items()]
        return self.special_tokens + [""] + annotation_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        # With duplicate token strings, the last index wins.
        return {token: token_id for token_id, token in enumerate(self.vocab)}

    @cached_property
    def vocabulary(self) -> list[str]:
        """Full vocabulary."""
        return [*self.special_tokens, "", *self._labels]

    def get_special_tokens_mask(self, encoded: torch.Tensor) -> torch.Tensor:
        """Return True where the first annotation slot of a position is a special token."""
        return encoded[:, 0] < len(self.special_tokens)

    def tokenize(
        self, sample: Sample | None, sequence: str, fail_on_mismatch: bool = False
    ) -> list[str]:
        """Convert InterPro site annotations into one token string per residue.

        ``sample`` must provide these equal-length lists: ``interpro_site_descriptions``,
        ``interpro_site_starts``, ``interpro_site_ends`` and ``interpro_site_residues``.
        Starts/ends are 1-indexed and inclusive. ``interpro_site_residues`` holds the
        residues expected at the span and is checked against ``sequence``.

        Args:
            sample: annotation record, or None if the sequence is not annotated.
            sequence: the residue string; one token is produced per character.
            fail_on_mismatch: if True, raise on a residue mismatch instead of returning
                the fallback tokens.

        Returns:
            ``len(sequence)`` token strings. A missing sample, missing fields,
            mismatched list lengths or a residue mismatch yield a fallback token at every
            position. Sites with unparsable or out-of-range bounds, or a wrong residue
            count, are skipped with a printed message.

        Raises:
            ValueError: on a residue mismatch when ``fail_on_mismatch`` is True.
        """
        seqlen = len(sequence)
        assert seqlen >= 0
        # None means the sequence is *not annotated*.
        if sample is None:
            return [""] * seqlen

        if any(
            sample.get(field) is None
            for field in [
                "interpro_site_descriptions",
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            return [""] * seqlen

        num_annotations = len(sample["interpro_site_descriptions"])
        if any(
            len(sample[field]) != num_annotations
            for field in [
                "interpro_site_starts",
                "interpro_site_ends",
                "interpro_site_residues",
            ]
        ):
            # mismatched length.
            return [""] * seqlen

        positional_ids = [set() for _ in range(seqlen)]
        for description, start, end, residues in zip(
            sample["interpro_site_descriptions"],
            sample["interpro_site_starts"],
            sample["interpro_site_ends"],
            sample["interpro_site_residues"],
        ):
            try:
                start = int(start)
                end = int(end)
            except (TypeError, ValueError):
                continue

            # Start / End are 1-indexed [inclusive, inclusive].
            if start <= 0 or end > seqlen or start > end:
                print(f"invalid start/end: ({start}, {end}), len: {seqlen}")
                continue

            if len(residues) != (end - start) + 1:
                print(f"bad reference residue: {residues}")
                continue

            token_id = self._description2id(description)
            if token_id is None:
                token_id = self.vocab_to_index[""]

            # `i` is the 0-based index into `sequence`.
            for i, residue in zip(range(start - 1, end), residues):
                # If there are any mismatching residues, skip the entire sample.
                if sequence[i] != residue:
                    if fail_on_mismatch:
                        raise ValueError(
                            f"Residue mismatch at position {i + 1} (1-indexed): {sequence[i]} != {residue}"
                        )
                    return [""] * seqlen

                positional_ids[i].add(token_id)

        tokens = []
        for token_ids in positional_ids:
            # NOTE(review): both branches produce the same (empty) literal in this copy.
            if token_ids:
                token = ""
            else:
                token = ""
            tokens.append(token)
        return tokens

    def _token2ids(self, token: str) -> list[int]:
        """Parse a token string into its list of annotation ids."""
        if token.startswith(""):
            # Strips a 4-character prefix and 1-character suffix, then splits on commas.
            return [int(token_id) for token_id in token[4:-1].split(",")]
        else:
            token_id = self.vocab_to_index[token]
            return [token_id]

    def encode(
        self, tokens: list[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encode tokens into a (len(tokens), max_annotations) int64 tensor.

        Each row holds up to ``max_annotations`` ids (extras are truncated); unused slots
        keep the fill id. With ``add_special_tokens`` one extra row is padded at the start
        and at the end.
        """
        token_ids = torch.full(
            size=(len(tokens), self.max_annotations),
            dtype=torch.int64,
            fill_value=self.vocab_to_index[""],
        )
        for i, token in enumerate(tokens):
            ids = self._token2ids(token)[: self.max_annotations]
            token_ids[i, : len(ids)] = torch.tensor(ids)

        if add_special_tokens:
            # (0, 0) leaves the annotation axis alone; (1, 1) adds a row at each end.
            token_ids = F.pad(
                token_ids, (0, 0, 1, 1), value=self.vocab_to_index[""]
            )
        return token_ids

    def decode(self, encoded: torch.Tensor) -> list[str]:
        raise NotImplementedError(
            "Residue annotation decoding should be handled with util.decoding.decode_residue_annotations"
        )

    @property
    def mask_token(self) -> str:
        return ""

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return ""

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return ""

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return ""

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
225 |
--------------------------------------------------------------------------------
/src/esm/tokenization/sasa_tokenizer.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 |
3 | import torch
4 |
5 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase
6 | from src.esm.utils.constants import esm3 as C
7 |
8 |
class SASADiscretizingTokenizer(EsmTokenizerBase):
    """Tokenizer for Solvent Accessible Surface Area (SASA).

    Continuous SASA values are bucketed by ``boundaries`` (sorted ascending) into range
    tokens. Ids ``0 .. len(special_tokens) - 1`` are the special tokens, and the range
    tokens follow in ascending order.

    NOTE(review): in this copy the special-token literals are empty strings (they look
    like markup-stripped ``<...>`` tokens), so they collide in ``vocab_to_index``.
    Verify them against the original source.
    """

    def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
        # sorted() makes a copy, so the shared default list is never mutated.
        self._boundaries = sorted(boundaries)

    @cached_property
    def special_tokens(self) -> list[str]:
        return ["", "", ""]

    @cached_property
    def vocab(self) -> list[str]:
        """Discrete token vocabulary.

        Returns:
            The special tokens followed by one range token per bucket, formatted as
            "<low-high>". The first range starts at "0" and the last ends at "inf".
        """
        boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
        range_tokens = [
            f"<{low}-{high}>"
            for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
        ]
        return self.special_tokens + range_tokens

    @cached_property
    def midpoints(self) -> list[float]:
        """Midpoints of the SASA token ranges, indexed by token id (NaN for special tokens).

        The open-ended last bucket uses twice the largest boundary as its upper bound.
        """
        boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
        midpoint_tokens = [
            (float(high) + float(low)) / 2
            for low, high in zip(boundaries[:-1], boundaries[1:])
        ]
        # One NaN per special token keeps list positions aligned with token ids.
        return [float("nan")] * len(self.special_tokens) + midpoint_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: [length]
        Returns:
            [length] tensor, true where special tokens are located in the input.
        """
        return tokens < len(self.special_tokens)

    def encode(
        self, values: list[float | str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes SASA values as discrete tokens.

        Args:
            values: SASA values (int/float) and/or token strings taken from ``vocab``,
                e.g. ``[1.2, 10.3, 0.0]`` with token strings mixed in.
            add_special_tokens: if True, wrap the ids with BOS and EOS special tokens.
        Returns:
            int64 tensor of token ids.
        Raises:
            TypeError: for values that are neither numbers nor strings.
        """
        # Loop-invariant: build the boundary tensor once rather than once per value.
        boundaries = torch.tensor(self._boundaries)
        num_special = len(self.special_tokens)
        ids = []
        if add_special_tokens:
            ids.append(self.vocab_to_index[""])  # BOS
        for value in values:
            if isinstance(value, (float, int)):
                bucket = torch.bucketize(value, boundaries)
                token_id = num_special + int(bucket)
            elif isinstance(value, str):
                token_id = self.vocab_to_index[value]
            else:
                raise TypeError(value)
            ids.append(token_id)
        if add_special_tokens:
            ids.append(self.vocab_to_index[""])  # EOS

        return torch.tensor(ids, dtype=torch.int64)

    def decode_float(self, encoded: torch.Tensor) -> list[float]:
        """Decodes SASA token ids into float values (range midpoints)."""
        return [self.midpoints[token_id] for token_id in encoded]

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes SASA token ids into a comma-separated token string."""
        return ",".join(self.vocab[i] for i in encoded)

    def decode_list(self, encoded: torch.Tensor) -> list[str]:
        """Decodes SASA token ids into a list of token strings."""
        return [self.vocab[i] for i in encoded]

    @property
    def mask_token(self) -> str:
        return ""

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return ""

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return ""

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return ""

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
130 |
--------------------------------------------------------------------------------
/src/esm/tokenization/sequence_tokenizer.py:
--------------------------------------------------------------------------------
1 | from tokenizers import Tokenizer
2 | from tokenizers.models import BPE
3 | from tokenizers.processors import TemplateProcessing
4 | from transformers import PreTrainedTokenizerFast
5 |
6 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase
7 | from src.esm.utils.constants import esm3 as C
8 |
9 |
class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
    """
    Constructs an ESM tokenizer.

    A character-level tokenizer over ``C.SEQUENCE_VOCAB`` (token id = position in that
    vocabulary), built with the Hugging Face ``tokenizers`` library and wrapped as a
    ``PreTrainedTokenizerFast``. ``chainbreak_token`` is registered both as a special token
    and as an additional special token.

    NOTE(review): in this copy the default token literals (unk/cls/pad/mask/eos), the
    post-processor template and its special-token names are empty or whitespace-only
    strings. They look like markup-stripped ``<...>`` tokens; verify them against the
    original source.
    """

    # Replaces the base class's model_input_names.
    model_input_names = ["sequence_tokens", "attention_mask"]

    def __init__(
        self,
        unk_token="",
        cls_token="",
        pad_token="",
        mask_token="",
        eos_token="",
        chainbreak_token="|",
        **kwargs,
    ):
        all_tokens = C.SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token]
        additional_special_tokens = [chainbreak_token]

        tokenizer.add_special_tokens(
            special_tokens,
        )

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single=" $A ",
            special_tokens=[
                ("", tokenizer.token_to_id("")),
                ("", tokenizer.token_to_id("")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    # The BOS token and its id are aliased to the CLS token and its id.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id
69 |
--------------------------------------------------------------------------------
/src/esm/tokenization/ss_tokenizer.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Sequence
3 |
4 | import torch
5 |
6 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase
7 | from src.esm.utils.constants import esm3 as C
8 |
9 |
10 | class SecondaryStructureTokenizer(EsmTokenizerBase):
11 | """Tokenizer for secondary structure strings."""
12 |
13 | def __init__(self, kind: str = "ss8"):
14 | assert kind in ("ss8", "ss3")
15 | self.kind = kind
16 |
17 | @property
18 | def special_tokens(self) -> list[str]:
19 | return ["", "", ""]
20 |
21 | @cached_property
22 | def vocab(self):
23 | """Tokenzier vocabulary list."""
24 | match self.kind:
25 | case "ss8":
26 | nonspecial_tokens = list(C.SSE_8CLASS_VOCAB) # "GHITEBSC"
27 | case "ss3":
28 | nonspecial_tokens = list(C.SSE_3CLASS_VOCAB) # HEC
29 | case _:
30 | raise ValueError(self.kind)
31 |
32 | # The non-special tokens ids match amino acid tokens ids when possible.
33 | return [*self.special_tokens, *nonspecial_tokens]
34 |
35 | @cached_property
36 | def vocab_to_index(self) -> dict[str, int]:
37 | """Constructs token -> token id mapping."""
38 | return {word: i for i, word in enumerate(self.vocab)}
39 |
40 | def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
41 | """Determines which positions are special tokens.
42 |
43 | Args:
44 | tokens: [length]
45 | Returns:
46 | [length] tensor, true where special tokens are located in the input.
47 | """
48 | return tokens < len(self.special_tokens)
49 |
50 | def encode(
51 | self, sequence: str | Sequence[str], add_special_tokens: bool = True
52 | ) -> torch.Tensor:
53 | """Encode secondary structure string
54 |
55 | Args:
56 | string: secondary structure string e.g. "GHHIT", or as token listk.
57 | Returns:
58 | [sequence_length] token ids representing. Will add /.
59 | """
60 | ids = []
61 | if add_special_tokens:
62 | ids.append(self.vocab_to_index[""]) # cls
63 | for char in sequence:
64 | ids.append(self.vocab_to_index[char])
65 | if add_special_tokens:
66 | ids.append(self.vocab_to_index[""]) # eos
67 | return torch.tensor(ids, dtype=torch.int64)
68 |
69 | def decode(self, encoded: torch.Tensor) -> str:
70 | """Decodes token ids into secondary structure string.
71 |
72 | Args:
73 | encoded: [length] token id array.
74 | Returns
75 | Decoded secondary structure string.
76 | """
77 | return "".join(self.vocab[i] for i in encoded)
78 |
79 | @property
80 | def mask_token(self) -> str:
81 | return ""
82 |
83 | @property
84 | def mask_token_id(self) -> int:
85 | return self.vocab_to_index[self.mask_token]
86 |
87 | @property
88 | def bos_token(self) -> str:
89 | return ""
90 |
91 | @property
92 | def bos_token_id(self) -> int:
93 | return self.vocab_to_index[self.bos_token]
94 |
95 | @property
96 | def eos_token(self) -> str:
97 | return ""
98 |
99 | @property
100 | def eos_token_id(self) -> int:
101 | return self.vocab_to_index[self.eos_token]
102 |
103 | @property
104 | def pad_token(self) -> str:
105 | return ""
106 |
107 | @property
108 | def pad_token_id(self) -> int:
109 | return self.vocab_to_index[self.pad_token]
110 |
--------------------------------------------------------------------------------
/src/esm/tokenization/structure_tokenizer.py:
--------------------------------------------------------------------------------
1 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase
2 |
3 |
class StructureTokenizer(EsmTokenizerBase):
    """Read-only view of the structure VQ-VAE's special token ids.

    Structure tokens are produced and consumed by StructureTokenEncoder and
    StructureTokenDecoder, so this class only maps special-token names to ids.
    It has no string forms and cannot encode or decode.

    Args:
        vq_vae_special_tokens: mapping looked up with the keys "MASK", "BOS",
            "EOS", "PAD" and "CHAINBREAK".
    """

    _NO_STRING_FORM = "Structure tokens are defined on 3D coordinates, not strings."
    _USE_CODEC = (
        "The StructureTokenizer class is provided as a convenience for "
        "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n"
        "Please use them instead."
    )

    def __init__(self, vq_vae_special_tokens: dict[str, int]):
        self.vq_vae_special_tokens = vq_vae_special_tokens

    # Special token ids ------------------------------------------------------

    @property
    def mask_token_id(self) -> int:
        return self.vq_vae_special_tokens["MASK"]

    @property
    def bos_token_id(self) -> int:
        return self.vq_vae_special_tokens["BOS"]

    @property
    def eos_token_id(self) -> int:
        return self.vq_vae_special_tokens["EOS"]

    @property
    def pad_token_id(self) -> int:
        return self.vq_vae_special_tokens["PAD"]

    @property
    def chainbreak_token_id(self) -> int:
        return self.vq_vae_special_tokens["CHAINBREAK"]

    # String forms do not exist. NOTE(review): these are plain methods, not
    # properties as declared by the protocol, so merely accessing the attribute
    # does not raise; only calling it does.

    def mask_token(self) -> str:
        raise NotImplementedError(self._NO_STRING_FORM)

    def bos_token(self) -> str:
        raise NotImplementedError(self._NO_STRING_FORM)

    def eos_token(self) -> str:
        raise NotImplementedError(self._NO_STRING_FORM)

    def pad_token(self) -> str:
        raise NotImplementedError(self._NO_STRING_FORM)

    # Encoding/decoding belongs to the VQ-VAE encoder/decoder -----------------

    def encode(self, *args, **kwargs):
        raise NotImplementedError(self._USE_CODEC)

    def decode(self, *args, **kwargs):
        raise NotImplementedError(self._USE_CODEC)
64 |
--------------------------------------------------------------------------------
/src/esm/tokenization/tokenizer_base.py:
--------------------------------------------------------------------------------
1 | from typing import Protocol, runtime_checkable
2 |
3 |
@runtime_checkable
class EsmTokenizerBase(Protocol):
    """Structural interface shared by the ESM track tokenizers.

    An object satisfies this protocol when it exposes ``encode``/``decode``
    together with the mask/BOS/EOS/pad token strings and ids. Because the class
    is ``runtime_checkable``, ``isinstance`` only verifies that these members
    are present, not their types or signatures.
    """

    def encode(self, *args, **kwargs):
        """Turn the track's input representation into token ids."""

    def decode(self, *args, **kwargs):
        """Turn token ids back into the track's representation."""

    @property
    def mask_token(self) -> str:
        """String form of the mask token."""

    @property
    def mask_token_id(self) -> int:
        """Integer id of the mask token."""

    @property
    def bos_token(self) -> str:
        """String form of the beginning-of-sequence token."""

    @property
    def bos_token_id(self) -> int:
        """Integer id of the beginning-of-sequence token."""

    @property
    def eos_token(self) -> str:
        """String form of the end-of-sequence token."""

    @property
    def eos_token_id(self) -> int:
        """Integer id of the end-of-sequence token."""

    @property
    def pad_token(self) -> str:
        """String form of the padding token."""

    @property
    def pad_token_id(self) -> int:
        """Integer id of the padding token."""
43 |
--------------------------------------------------------------------------------
/src/esm/utils/constants/esm3.py:
--------------------------------------------------------------------------------
1 | from functools import cache
2 | from pathlib import Path
3 |
4 | from huggingface_hub import snapshot_download
5 |
# Special token ids of the sequence track (indices into SEQUENCE_VOCAB).
SEQUENCE_BOS_TOKEN = 0
SEQUENCE_PAD_TOKEN = 1
SEQUENCE_EOS_TOKEN = 2
SEQUENCE_CHAINBREAK_TOKEN = 31
SEQUENCE_MASK_TOKEN = 32

# Structure VQ-VAE: the special tokens occupy the ids directly after the
# codebook entries, in the order listed here.
VQVAE_CODEBOOK_SIZE = 4096
VQVAE_SPECIAL_TOKENS = {
    name: VQVAE_CODEBOOK_SIZE + offset
    for offset, name in enumerate(("MASK", "EOS", "BOS", "PAD", "CHAINBREAK"))
}
VQVAE_DIRECTION_LOSS_BINS = 16
VQVAE_PAE_BINS = 64
VQVAE_MAX_PAE_BIN = 31.0
VQVAE_PLDDT_BINS = 50

# Convenience aliases for the structure-track special ids.
STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"]
STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"]
STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"]
STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"]
STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"]
STRUCTURE_UNDEFINED_TOKEN = 955
31 |
SASA_UNK_TOKEN = 2
SASA_PAD_TOKEN = 0

SS8_UNK_TOKEN = 2
SS8_PAD_TOKEN = 0

INTERPRO_PAD_TOKEN = 0

RESIDUE_PAD_TOKEN = 0

CHAIN_BREAK_STR = "|"

# String forms of the sequence-track BOS/EOS tokens. These were empty strings,
# which break every str.replace / vocabulary lookup that uses them.
SEQUENCE_BOS_STR = "<cls>"
SEQUENCE_EOS_STR = "<eos>"

# Single-character mask placeholder used in user-facing strings, and the
# per-track mask token strings it is translated to.
MASK_STR_SHORT = "_"
SEQUENCE_MASK_STR = "<mask>"
SASA_MASK_STR = "<unk>"
SS8_MASK_STR = "<unk>"
51 |
# A token's id is its index in this list; the special entries sit at the ids
# given by SEQUENCE_BOS_TOKEN (0), SEQUENCE_PAD_TOKEN (1), SEQUENCE_EOS_TOKEN (2),
# SEQUENCE_CHAINBREAK_TOKEN (31) and SEQUENCE_MASK_TOKEN (32). Entries must be
# unique, otherwise token -> id mappings built from this list collide.
# fmt: off
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]
# fmt: on
61 |
SSE_8CLASS_VOCAB = "GHITEBSC"
SSE_3CLASS_VOCAB = "HEC"
# Collapse the 8-class alphabet to 3 classes: G/H/I -> H, E/B -> E, T/S/C -> C
# (read position-by-position against SSE_8CLASS_VOCAB).
SSE_8CLASS_TO_3CLASS_MAP = dict(zip(SSE_8CLASS_VOCAB, "HHHCEECC"))

# Ascending bin edges used to discretize SASA values.
SASA_DISCRETIZATION_BOUNDARIES = [
    0.8, 4.0, 9.6, 16.4, 24.5, 32.9, 42.0, 51.5,
    61.2, 70.9, 81.6, 93.3, 107.2, 125.4, 151.4,
]

MAX_RESIDUE_ANNOTATIONS = 16

TFIDF_VECTOR_SIZE = 58_641
97 |
98 |
@cache
def data_root():
    """Return the root directory holding the ESM3 data files.

    NOTE(review): presumably the relative ``data/...`` paths in this module are
    resolved against this root; confirm against the callers.

    A local ``esm/data`` directory (relative to the current working directory)
    takes precedence; in that case its parent, ``esm``, is returned. Otherwise
    the ``EvolutionaryScale/esm3-sm-open-v1`` snapshot is fetched with
    ``snapshot_download`` and its local path is returned. The result is cached
    for the life of the process (``data_root.cache_clear()`` resets it).

    This is a plain module-level function: the former ``@staticmethod``
    wrapper had no class to bind to and hid ``cache_clear``/``cache_info``.
    """
    # Try the local default directory first.
    for candidate in ("esm/data",):
        if (p := Path(candidate)).exists():
            return p.parent
    # Fall back to downloading from Hugging Face if no local copy exists.
    return Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
112 |
113 |
# Data file locations, relative to the data root.

# InterPro resources.
INTERPRO_ENTRY = "data/entry_list_safety_29026.list"
INTERPRO_HIERARCHY = "data/ParentChildTreeFile.txt"
# NOTE(review): INTERPRO2GO names the same file as INTERPRO_HIERARCHY; confirm
# this is intended rather than a missing InterPro->GO mapping file.
INTERPRO2GO = "data/ParentChildTreeFile.txt"
INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
INTERPRO2KEYWORDS = "data/interpro_29026_to_keywords_58641.csv"

# Function-keyword resources.
LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
KEYWORDS_VOCABULARY = "data/keyword_vocabulary_safety_filtered_58641.txt"
KEYWORDS_IDF = "data/keyword_idf_safety_filtered_58641.npy"

# Residue annotations.
RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv"
128 |
--------------------------------------------------------------------------------
/src/esm/utils/constants/models.py:
--------------------------------------------------------------------------------
1 | # Model names
# Main ESM3 model.
ESM3_OPEN_SMALL = "esm3_sm_open_v1"

# Auxiliary encoder/decoder models.
ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0"
ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0"
ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0"
6 |
--------------------------------------------------------------------------------
/src/esm/utils/constants/physics.py:
--------------------------------------------------------------------------------
# Reference coordinates for three backbone atoms (one [x, y, z] row each);
# the second atom sits at the origin.
# NOTE(review): presumably N, CA, C in Angstroms; confirm against the callers.
BB_COORDINATES = [
    [0.5256, 1.3612, 0.0],
    [0.0, 0.0, 0.0],
    [-1.5251, 0.0, 0.0],
]
6 |
--------------------------------------------------------------------------------
/src/esm/utils/decoding.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import attr
4 | import torch
5 |
6 | from src.esm.models.function_decoder import FunctionTokenDecoder
7 | from src.esm.models.vqvae import StructureTokenDecoder
8 | from src.esm.sdk.api import ESMProtein, ESMProteinTensor
9 | from src.esm.tokenization import TokenizerCollectionProtocol
10 | from src.esm.tokenization.function_tokenizer import (
11 | InterProQuantizedTokenizer,
12 | )
13 | from src.esm.tokenization.residue_tokenizer import (
14 | ResidueAnnotationsTokenizer,
15 | )
16 | from src.esm.tokenization.sasa_tokenizer import (
17 | SASADiscretizingTokenizer,
18 | )
19 | from src.esm.tokenization.sequence_tokenizer import (
20 | EsmSequenceTokenizer,
21 | )
22 | from src.esm.tokenization.ss_tokenizer import (
23 | SecondaryStructureTokenizer,
24 | )
25 | from src.esm.tokenization.structure_tokenizer import (
26 | StructureTokenizer,
27 | )
28 | from src.esm.tokenization.tokenizer_base import EsmTokenizerBase
29 | from src.esm.utils.constants import esm3 as C
30 | from src.esm.utils.function.encode_decode import (
31 | decode_function_tokens,
32 | decode_residue_annotation_tokens,
33 | )
34 | from src.esm.utils.structure.protein_chain import ProteinChain
35 | from src.esm.utils.types import FunctionAnnotation
36 |
37 |
def decode_protein_tensor(
    input: ESMProteinTensor,
    tokenizers: TokenizerCollectionProtocol,
    structure_token_decoder: StructureTokenDecoder,
    function_token_decoder: FunctionTokenDecoder,
) -> ESMProtein:
    """Decode every populated track of an ESMProteinTensor into an ESMProtein.

    The input is copied with ``attr.evolve`` before tracks are modified. Any
    token track whose interior (everything but the first and last position)
    consists only of that track's pad id is treated as absent. Structure
    tokens take priority over raw coordinates when both are present.

    Args:
        input: tokenized protein to decode.
        tokenizers: per-track tokenizers, looked up by track name.
        structure_token_decoder: decoder used for the structure track.
        function_token_decoder: decoder used for the function track.
    Returns:
        ESMProtein with the decoded tracks; ``function_annotations`` is None
        when neither function track produced any annotation.
    """
    input = attr.evolve(input)  # Work on a copy; tracks are cleared below.

    # A track holding nothing but padding between BOS and EOS counts as absent.
    for field in attr.fields(ESMProteinTensor):
        if field.name == "coordinates":
            continue
        track_tokens: torch.Tensor | None = getattr(input, field.name)
        if track_tokens is None:
            continue
        interior = track_tokens[1:-1].flatten()  # flatten handles multi-track tensors
        pad_id = getattr(tokenizers, field.name).pad_token_id
        if torch.all(interior == pad_id):
            setattr(input, field.name, None)

    sequence = (
        decode_sequence(input.sequence, tokenizers.sequence)
        if input.sequence is not None
        else None
    )

    coordinates = None
    plddt = ptm = None
    if input.structure is not None:
        coordinates, plddt, ptm = decode_structure(
            structure_tokens=input.structure,
            structure_decoder=structure_token_decoder,
            structure_tokenizer=tokenizers.structure,
            sequence=sequence,
        )
    elif input.coordinates is not None:
        coordinates = input.coordinates[1:-1, ...]  # drop BOS/EOS positions

    secondary_structure = None
    if input.secondary_structure is not None:
        secondary_structure = decode_secondary_structure(
            input.secondary_structure, tokenizers.secondary_structure
        )

    sasa = None if input.sasa is None else decode_sasa(input.sasa, tokenizers.sasa)

    function_annotations = []
    if input.function is not None:
        function_annotations += decode_function_annotations(
            input.function,
            function_token_decoder=function_token_decoder,
            function_tokenizer=tokenizers.function,
        )
    if input.residue_annotations is not None:
        function_annotations += decode_residue_annotations(
            input.residue_annotations, tokenizers.residue_annotations
        )

    return ESMProtein(
        sequence=sequence,
        secondary_structure=secondary_structure,
        sasa=sasa,  # type: ignore
        function_annotations=function_annotations or None,
        coordinates=coordinates,
        plddt=plddt,
        ptm=ptm,
    )
108 |
109 |
def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
    """Warn if ``tensor`` does not start with tok's BOS id or end with its EOS id.

    Only emits warnings; the tensor is not modified here.
    """
    first, last = tensor[0], tensor[-1]
    if first != tok.bos_token_id:
        warnings.warn(
            f"{msg} does not start with BOS token, token is ignored. "
            f"BOS={tok.bos_token_id} vs {tensor}"
        )
    if last != tok.eos_token_id:
        warnings.warn(
            f"{msg} does not end with EOS token, token is ignored. "
            f"EOS='{tok.eos_token_id}': {tensor}"
        )
119 |
120 |
def decode_sequence(
    sequence_tokens: torch.Tensor,
    sequence_tokenizer: EsmSequenceTokenizer,
    **kwargs,
) -> str:
    """Decode sequence token ids (including BOS/EOS) into a residue string.

    Warns if the BOS/EOS framing is missing. After decoding, spaces are
    removed, mask tokens become ``C.MASK_STR_SHORT``, and the CLS and EOS token
    strings are dropped.

    Args:
        sequence_tokens: token id tensor.
        sequence_tokenizer: tokenizer whose ``decode`` is used.
        **kwargs: forwarded to ``sequence_tokenizer.decode``.
    """
    _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
    decoded = sequence_tokenizer.decode(sequence_tokens, **kwargs)
    # Applied in order: spaces, then mask, then cls, then eos.
    for old, new in (
        (" ", ""),
        (sequence_tokenizer.mask_token, C.MASK_STR_SHORT),
        (sequence_tokenizer.cls_token, ""),
        (sequence_tokenizer.eos_token, ""),
    ):
        decoded = decoded.replace(old, new)
    return decoded
137 |
138 |
def decode_structure(
    structure_tokens: torch.Tensor,
    structure_decoder: StructureTokenDecoder,
    structure_tokenizer: StructureTokenizer,
    sequence: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Decode one structure token track into atom37 coordinates.

    Args:
        structure_tokens: 1-D token tensor including BOS/EOS.
        structure_decoder: VQ-VAE decoder; its output is read for the keys
            ``"bb_pred"`` and, if present, ``"plddt"`` and ``"ptm"``.
        structure_tokenizer: used only to check the BOS/EOS framing.
        sequence: optional residue string passed to ProteinChain.
    Returns:
        (atom37 positions, per-residue plddt or None, ptm or None); the
        BOS/EOS positions are removed from coordinates and plddt.
    Raises:
        ValueError: if ``structure_tokens`` is not 1-D.
    """
    if len(structure_tokens.size()) != 1:
        raise ValueError(
            f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}"
        )
    batched_tokens = structure_tokens.unsqueeze(0)
    _bos_eos_warn("Structure", batched_tokens[0], structure_tokenizer)

    decoder_output = structure_decoder.decode(batched_tokens)
    # Single batch element, BOS and EOS positions dropped.
    bb_coords: torch.Tensor = decoder_output["bb_pred"][0, 1:-1, ...].detach().cpu()

    plddt = (
        decoder_output["plddt"][0, 1:-1].detach().cpu()
        if "plddt" in decoder_output
        else None
    )
    ptm = decoder_output["ptm"] if "ptm" in decoder_output else None

    chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
    chain = chain.infer_oxygen()
    return torch.tensor(chain.atom37_positions), plddt, ptm
174 |
175 |
def decode_secondary_structure(
    secondary_structure_tokens: torch.Tensor,
    ss_tokenizer: SecondaryStructureTokenizer,
) -> str:
    """Decode a BOS/EOS-framed secondary structure token tensor to a string.

    The first and last positions are always dropped; a warning is emitted if
    they are not the tokenizer's BOS/EOS ids.
    """
    _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
    return ss_tokenizer.decode(secondary_structure_tokens[1:-1])
186 |
187 |
def decode_sasa(
    sasa_tokens: torch.Tensor,
    sasa_tokenizer: SASADiscretizingTokenizer,
) -> list[float]:
    """Decode a BOS/EOS-framed SASA token tensor with ``decode_float``.

    The first and last positions are always dropped; a warning is emitted if
    they are not the tokenizer's BOS/EOS ids.
    """
    _bos_eos_warn("SASA", sasa_tokens, sasa_tokenizer)
    return sasa_tokenizer.decode_float(sasa_tokens[1:-1])
196 |
197 |
def decode_function_annotations(
    function_annotation_tokens: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokenizer: InterProQuantizedTokenizer,
    **kwargs,
) -> list[FunctionAnnotation]:
    """Decode function tokens into FunctionAnnotations via decode_function_tokens.

    No BOS/EOS check is performed for this track.

    Args:
        function_annotation_tokens: function token id tensor.
        function_token_decoder: decoder forwarded to decode_function_tokens.
        function_tokenizer: forwarded as ``function_tokens_tokenizer``.
        **kwargs: forwarded to decode_function_tokens.
    """
    return decode_function_tokens(
        function_annotation_tokens,
        function_token_decoder=function_token_decoder,
        function_tokens_tokenizer=function_tokenizer,
        **kwargs,
    )
213 |
214 |
def decode_residue_annotations(
    residue_annotation_tokens: torch.Tensor,
    residue_annotation_decoder: ResidueAnnotationsTokenizer,
) -> list[FunctionAnnotation]:
    """Decode residue-annotation tokens via decode_residue_annotation_tokens.

    No BOS/EOS check is performed for this track.
    """
    return decode_residue_annotation_tokens(
        residue_annotations_token_ids=residue_annotation_tokens,
        residue_annotations_tokenizer=residue_annotation_decoder,
    )
226 |
--------------------------------------------------------------------------------
/src/esm/utils/encoding.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from src.esm.models.vqvae import StructureTokenEncoder
7 | from src.esm.tokenization.function_tokenizer import (
8 | InterProQuantizedTokenizer as EsmFunctionTokenizer,
9 | )
10 | from src.esm.tokenization.residue_tokenizer import (
11 | ResidueAnnotationsTokenizer,
12 | )
13 | from src.esm.tokenization.sasa_tokenizer import (
14 | SASADiscretizingTokenizer,
15 | )
16 | from src.esm.tokenization.sequence_tokenizer import (
17 | EsmSequenceTokenizer,
18 | )
19 | from src.esm.tokenization.ss_tokenizer import (
20 | SecondaryStructureTokenizer,
21 | )
22 | from src.esm.tokenization.structure_tokenizer import (
23 | StructureTokenizer,
24 | )
25 | from src.esm.utils.constants import esm3 as C
26 | from src.esm.utils.function.encode_decode import (
27 | encode_function_annotations,
28 | )
29 | from src.esm.utils.structure.protein_chain import ProteinChain
30 | from src.esm.utils.types import FunctionAnnotation
31 |
32 |
33 | # Raw Defaults
def get_default_sequence(sequence_length: int) -> str:
    """Return a fully masked sequence: one short mask character per residue."""
    mask_char = C.MASK_STR_SHORT
    return mask_char * sequence_length
36 |
37 |
def get_default_secondary_structure(sequence_length: int) -> str:
    """Return a fully masked secondary structure: one mask character per residue."""
    mask_char = C.MASK_STR_SHORT
    return mask_char * sequence_length
40 |
41 |
42 | def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]:
43 | return [None] * sequence_length
44 |
45 |
46 | # Tokenization
def tokenize_sequence(
    sequence: str,
    sequence_tokenizer: EsmSequenceTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    """Encode a residue string into an int64 token tensor.

    Short mask characters (``C.MASK_STR_SHORT``) are first replaced with the
    tokenizer's mask token string.
    """
    masked = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
    ids = sequence_tokenizer.encode(masked, add_special_tokens=add_special_tokens)
    return torch.tensor(ids, dtype=torch.int64)
58 |
59 |
def tokenize_structure(
    coordinates: torch.Tensor,
    structure_encoder: StructureTokenEncoder,
    structure_tokenizer: StructureTokenizer,
    reference_sequence: str = "",
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode atom37 coordinates into structure tokens.

    Args:
        coordinates: atom37 coordinates; ``size(0)`` is the residue count.
        structure_encoder: VQ-VAE encoder; its parameters' device is used.
        structure_tokenizer: supplies the pad/BOS/EOS ids.
        reference_sequence: optional residue string; must match the residue count.
        add_special_tokens: add one position at each end for BOS/EOS.
    Returns:
        (coordinates, plddt, structure_tokens) as produced from the encoder
        inputs. With special tokens, the extra end positions hold +inf
        coordinates, 0 plddt, and the BOS/EOS ids.
    Raises:
        ValueError: if ``reference_sequence`` length differs from the residue count.
    """
    device = next(structure_encoder.parameters()).device
    chain = ProteinChain.from_atom37(coordinates, sequence=reference_sequence or None)

    if reference_sequence and len(reference_sequence) != coordinates.size(0):
        raise ValueError(
            f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
        )

    encoder_coords, plddt, residue_index = chain.to_structure_encoder_inputs()
    encoder_coords = encoder_coords.to(device)  # (1, L, 37, 3)
    plddt = plddt.to(device)  # (1, L)
    residue_index = residue_index.to(device)  # (1, L)
    _, structure_tokens = structure_encoder.encode(
        encoder_coords, residue_index=residue_index
    )

    # Drop the leading batch dimension.
    coordinates = encoder_coords.squeeze(0)  # (L, 37, 3)
    plddt = plddt.squeeze(0)  # (L,)
    structure_tokens = structure_tokens.squeeze(0)  # (L,)

    if add_special_tokens:
        # One extra slot at each end of the residue axis for BOS and EOS.
        coordinates = F.pad(coordinates, (0, 0, 0, 0, 1, 1), value=torch.inf)
        plddt = F.pad(plddt, (1, 1), value=0)
        structure_tokens = F.pad(
            structure_tokens, (1, 1), value=structure_tokenizer.pad_token_id
        )
        structure_tokens[0] = structure_tokenizer.bos_token_id
        structure_tokens[-1] = structure_tokenizer.eos_token_id
    return coordinates, plddt, structure_tokens
112 |
113 |
def tokenize_secondary_structure(
    secondary_structure: str | Sequence[str],
    secondary_structure_tokenizer: SecondaryStructureTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    """Encode a secondary structure string or token sequence into ids.

    For string input, occurrences of the tokenizer's mask token string are
    first collapsed to ``C.MASK_STR_SHORT`` so each character is one token.
    Every ``C.MASK_STR_SHORT`` element is then mapped to the tokenizer's
    mask token before encoding.
    """
    if isinstance(secondary_structure, str):
        secondary_structure = secondary_structure.replace(
            secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
        )

    symbols = [
        secondary_structure_tokenizer.mask_token if symbol == C.MASK_STR_SHORT else symbol
        for symbol in secondary_structure
    ]
    return secondary_structure_tokenizer.encode(
        symbols, add_special_tokens=add_special_tokens
    )
138 |
139 |
def tokenize_sasa(
    sasa: Sequence[float | str | None],
    sasa_tokenizer: SASADiscretizingTokenizer,
    add_special_tokens: bool = True,
):
    """Encode per-residue SASA values; ``None`` entries become the mask token."""
    values = [sasa_tokenizer.mask_token if v is None else v for v in sasa]
    return sasa_tokenizer.encode(values, add_special_tokens=add_special_tokens)
150 |
151 |
def tokenize_function_annotations(
    function_annotations: Sequence[FunctionAnnotation],
    reference_sequence: str,
    function_tokenizer: EsmFunctionTokenizer,
    residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode annotations into (function token ids, residue annotation ids).

    Thin wrapper over encode_function_annotations.
    """
    return encode_function_annotations(
        sequence=reference_sequence,
        function_annotations=function_annotations,
        function_tokens_tokenizer=function_tokenizer,
        residue_annotations_tokenizer=residue_annotation_tokenizer,
        add_special_tokens=add_special_tokens,
    )
167 |
168 |
169 | # Tokenized Defaults
def get_default_sequence_tokens(
    sequence_length: int,
    sequence_tokenizer: EsmSequenceTokenizer,
) -> torch.Tensor:
    """Token ids for a fully masked sequence, with BOS/EOS."""
    masked_sequence = get_default_sequence(sequence_length)
    return tokenize_sequence(masked_sequence, sequence_tokenizer, add_special_tokens=True)
179 |
180 |
def get_default_structure_tokens(
    sequence_length: int, structure_tokenizer: StructureTokenizer
) -> torch.Tensor:
    """All-pad structure token track of ``sequence_length + 2`` positions.

    The first and last positions always hold the BOS and EOS ids.
    """
    structure_tokens = torch.full(
        (sequence_length + 2,), structure_tokenizer.pad_token_id, dtype=torch.int64
    )
    structure_tokens[0] = structure_tokenizer.bos_token_id
    structure_tokens[-1] = structure_tokenizer.eos_token_id
    return structure_tokens
195 |
196 |
def get_default_secondary_structure_tokens(
    sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
) -> torch.Tensor:
    """Token ids for a fully masked secondary structure, with BOS/EOS."""
    masked = get_default_secondary_structure(sequence_length)
    return tokenize_secondary_structure(
        masked, secondary_structure_tokenizer, add_special_tokens=True
    )
205 |
206 |
def get_default_sasa_tokens(
    sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
) -> torch.Tensor:
    """Token ids for an all-unknown SASA track, with BOS/EOS."""
    unknown_sasa = get_default_sasa(sequence_length)
    return tokenize_sasa(unknown_sasa, sasa_tokenizer, add_special_tokens=True)
213 |
214 |
def get_default_function_tokens(
    sequence_length: int, function_tokenizer: EsmFunctionTokenizer
) -> torch.Tensor:
    """All-pad function token track of shape (sequence_length + 2, depth).

    The first and last rows always hold the BOS and EOS ids.
    """
    function_tokens = torch.full(
        (sequence_length + 2, function_tokenizer.depth),
        function_tokenizer.pad_token_id,
        dtype=torch.int64,
    )
    function_tokens[0] = function_tokenizer.bos_token_id
    function_tokens[-1] = function_tokenizer.eos_token_id
    return function_tokens
226 |
227 |
def get_default_residue_annotation_tokens(
    sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
    """All-pad residue annotation track of shape (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS).

    The first and last rows always hold the BOS and EOS ids.
    """
    residue_annotation_tokens = torch.full(
        (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
        residue_annotation_tokenizer.pad_token_id,
        dtype=torch.int64,
    )
    residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
    residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
    return residue_annotation_tokens
242 |
--------------------------------------------------------------------------------
/src/esm/utils/function/encode_decode.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Sequence
3 |
4 | import torch
5 |
6 | from src.esm.models.function_decoder import (
7 | FunctionTokenDecoder,
8 | _merge_annotations,
9 | )
10 | from src.esm.tokenization.function_tokenizer import (
11 | InterProQuantizedTokenizer,
12 | )
13 | from src.esm.tokenization.residue_tokenizer import (
14 | ResidueAnnotationsTokenizer,
15 | )
16 | from src.esm.utils.constants import esm3 as C
17 | from src.esm.utils.types import FunctionAnnotation
18 |
19 |
def encode_function_annotations(
    sequence: str,
    function_annotations: Sequence[FunctionAnnotation],
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Route annotations to the function-token and residue-annotation tracks and encode both.

    An annotation goes to the function-token track if its label starts with a
    known InterPro id (``IPR\\d+``), and again if the label is a known function
    keyword. It goes to the residue-annotation track if the label is a known
    residue annotation label.

    Args:
        sequence: residue string; its length bounds the annotation ranges.
        function_annotations: annotations with 1-indexed inclusive start/end.
        function_tokens_tokenizer: InterPro/keyword tokenizer.
        residue_annotations_tokenizer: residue annotation tokenizer.
        add_special_tokens: forwarded to both tokenizers' ``encode``.
    Returns:
        (function token ids, residue annotation ids).
    Raises:
        AssertionError: on a wrong tokenizer type or an out-of-range annotation.
        ValueError: if a label matches none of the vocabularies.
    """
    assert isinstance(
        residue_annotations_tokenizer, ResidueAnnotationsTokenizer
    ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer"

    ft_annotations: list[FunctionAnnotation] = []
    ra_annotations: list[FunctionAnnotation] = []
    for annotation in function_annotations:
        assert (
            1 <= annotation.start <= annotation.end <= len(sequence)
        ), f"Invalid (start, end) in function annotation {annotation}. Indices 1-indexed and [inclusive, inclusive]"

        interpro_match = re.match(r"IPR\d+", annotation.label)
        is_interpro = (
            interpro_match is not None
            and interpro_match.group() in function_tokens_tokenizer.interpro_to_index
        )
        is_keyword = annotation.label in function_tokens_tokenizer._tfidf.vocab_to_index
        is_residue = annotation.label in residue_annotations_tokenizer._labels

        # NOTE(review): a label that is both a known InterPro id and a keyword
        # is appended to the function track twice; preserved as-is.
        if is_interpro:
            ft_annotations.append(annotation)
        if is_keyword:
            ft_annotations.append(annotation)
        if is_residue:
            ra_annotations.append(annotation)

        if not (is_interpro or is_keyword or is_residue):
            raise ValueError(f"Unknown label in FunctionAnnotation: {annotation.label}")

    # Function-token track: annotations -> tokens -> ids.
    function_token_ids = function_tokens_tokenizer.encode(
        function_tokens_tokenizer.tokenize(
            annotations=ft_annotations,
            seqlen=len(sequence),
        ),
        add_special_tokens=add_special_tokens,
    )

    # Residue-annotation track: columnar (labels, starts, ends), or all None.
    if ra_annotations:
        descriptions, starts, ends = zip(
            *((a.label, a.start, a.end) for a in ra_annotations)
        )
    else:
        descriptions = starts = ends = None
    ra_tokens = residue_annotations_tokenizer.tokenize(
        {
            "interpro_site_descriptions": descriptions,
            "interpro_site_starts": starts,
            "interpro_site_ends": ends,
        },
        sequence=sequence,
        fail_on_mismatch=True,
    )
    residue_annotation_ids = residue_annotations_tokenizer.encode(
        ra_tokens, add_special_tokens=add_special_tokens
    )

    return function_token_ids, residue_annotation_ids
90 |
91 |
def decode_function_tokens(
    function_token_ids: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    decoder_annotation_threshold: float = 0.1,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decode function token ids into FunctionAnnotations.

    Runs ``function_token_decoder.decode`` and merges its keyword predictions
    and its InterPro predictions into a single list.

    Args:
        function_token_ids: Tensor [length, depth] of function token ids.
        function_token_decoder: decoder that turns token ids into predictions.
        function_tokens_tokenizer: InterPro annotation tokenizer; passed to the
            decoder and used to format InterPro labels.
        decoder_annotation_threshold: forwarded to the decoder as
            ``annotation_threshold``.
        annotation_min_length: forwarded to the decoder unchanged.
        annotation_gap_merge_max: forwarded to the decoder unchanged.
    Returns:
        The decoder's ``"function_keywords"`` entries, followed by one
        FunctionAnnotation per ``"interpro_annotations"`` entry (label from
        ``format_annotation``, same start/end).
    Raises:
        AssertionError: if ``function_token_ids`` is not 2-D.
    """
    assert (
        function_token_ids.ndim == 2
    ), "function_token_ids must be of shape (length, depth)"

    decoded = function_token_decoder.decode(
        function_token_ids,
        tokenizer=function_tokens_tokenizer,
        annotation_threshold=decoder_annotation_threshold,
        annotation_min_length=annotation_min_length,
        annotation_gap_merge_max=annotation_gap_merge_max,
    )

    # Keyword predictions are passed through as-is.
    annotations: list[FunctionAnnotation] = list(decoded["function_keywords"])

    # Convert predicted InterPro annotations to FunctionAnnotation.
    for annotation in decoded["interpro_annotations"]:
        label = function_tokens_tokenizer.format_annotation(annotation)
        annotations.append(
            FunctionAnnotation(label=label, start=annotation.start, end=annotation.end)
        )

    return annotations
142 |
143 |
def decode_residue_annotation_tokens(
    residue_annotations_token_ids: torch.Tensor,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes residue annotation tokens into FunctionAnnotations.

    Every non-zero token at (position, depth) whose vocabulary label is neither a
    special token nor empty becomes a single-residue annotation. Annotations are
    then merged with ``_merge_annotations`` and, if ``annotation_min_length`` is
    set, spans covering fewer residues than that are dropped.

    Args:
        residue_annotations_token_ids: Tensor [length, MAX_RESIDUE_ANNOTATIONS] of
            residue annotation token ids. Zero-valued tokens are skipped.
        residue_annotations_tokenizer: Tokenizer of residue annotations; its
            ``vocabulary`` maps token ids to labels.
        annotation_min_length: minimum span length (in residues, inclusive bounds)
            of a returned annotation, or None to keep all.
        annotation_gap_merge_max: forwarded to ``_merge_annotations`` as
            ``merge_gap_max``.
    Returns:
        Predicted residue annotations.
    """
    assert (
        residue_annotations_token_ids.ndim == 2
    ), "residue_annotations_token_ids must be of shape (length, MAX_RESIDUE_ANNOTATIONS)"

    skip_labels = {*residue_annotations_tokenizer.special_tokens, ""}
    annotations: list[FunctionAnnotation] = []

    for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS):
        token_ids = residue_annotations_token_ids[:, depth]
        # torch.nonzero on a 1-D tensor returns shape (num_nonzero, 1): one
        # *position* per row, not (position, value) pairs. Flatten the positions
        # and look the token values up explicitly.
        positions = torch.nonzero(token_ids).squeeze(-1)
        for loc, vocab_index in zip(
            positions.tolist(), token_ids[positions].tolist()
        ):
            label = residue_annotations_tokenizer.vocabulary[vocab_index]
            if label not in skip_labels:
                annotations.append(FunctionAnnotation(label=label, start=loc, end=loc))

    annotations = _merge_annotations(
        annotations,
        merge_gap_max=annotation_gap_merge_max,
    )

    # Drop very small annotations.
    if annotation_min_length is not None:
        annotations = [
            annotation
            for annotation in annotations
            if annotation.end - annotation.start + 1 >= annotation_min_length
        ]

    return annotations
188 |
--------------------------------------------------------------------------------
/src/esm/utils/function/interpro.py:
--------------------------------------------------------------------------------
1 | """Utilities for interacting with InterPro."""
2 |
3 | import itertools
4 | import re
5 | from dataclasses import dataclass
6 | from enum import IntEnum, auto
7 | from functools import cached_property
8 | from pathlib import Path
9 |
10 | import networkx as nx
11 | import numpy as np
12 | import pandas as pd
13 |
14 | from src.esm.utils.constants import esm3 as C
15 |
16 |
def parse_go_terms(text: str) -> list[str]:
    """Extracts every GO term mentioned in ``text``.

    A GO term is matched as the literal prefix "GO:" followed by at least seven
    digits. GO identifiers have exactly 7 digits; longer digit runs are returned
    whole rather than truncated.

    Args:
        text: String containing GO terms. Example: "GO:0008309, GO:1902267".
    Returns:
        All GO terms in order of appearance. Example: ['GO:0008309', 'GO:1902267']
    """
    go_term_pattern = r"GO:\d{7,}"
    return re.findall(go_term_pattern, text)
27 |
28 |
def _parse_interpro2go(path: str) -> dict[str, list[str]]:
    """Parses an InterPro2GO file into a map.

    NOTE: this file has a very strange, non-standard format. Lines starting with
    "!" are comments. A remaining line contributes only if it mentions exactly one
    InterPro accession (an "IPR" prefix followed by digits) and at least one GO
    term. GO terms from all contributing lines of the same accession are
    concatenated in file order.

    Args:
        path: path to InterPro2GO file from: https://www.ebi.ac.uk/GOA/InterPro2GO
    Returns:
        Mapping from InterPro accession to its GO terms, with keys in sorted order.
    """
    with Path(path).open("r") as f:
        lines = f.read().split("\n")

    go_by_interpro: dict[str, list[str]] = {}
    for line in lines:
        if line.startswith("!"):
            continue
        interpro_ids = re.findall(r"IPR\d+", line)
        go_ids = parse_go_terms(line)
        if len(interpro_ids) != 1 or not go_ids:
            continue
        go_by_interpro.setdefault(interpro_ids[0], []).extend(go_ids)

    # Emit accessions in sorted order.
    return {ipr: go_by_interpro[ipr] for ipr in sorted(go_by_interpro)}
55 |
56 |
class InterProEntryType(IntEnum):
    """Category of an InterPro entry.

    Member names are the upper-cased InterPro type names. Representation counts
    of each type in InterPro:

        Family                  21,942
        Domain                  14,053
        Homologous_superfamily   3,446
        Conserved_site             728
        Repeat                     374
        Active_site                133
        Binding_site                75
        PTM                         17

    UNKNOWN is used for accessions whose type cannot be determined.
    """

    ACTIVE_SITE = 0
    BINDING_SITE = 1
    CONSERVED_SITE = 2
    DOMAIN = 3
    FAMILY = 4
    HOMOLOGOUS_SUPERFAMILY = 5
    PTM = 6
    REPEAT = 7
    UNKNOWN = 8
79 |
80 |
@dataclass
class InterProEntry:
    """A single InterPro entry.

    Fields, in positional order: ``id``, ``type``, ``name``, ``description``.
    """

    # InterPro accession, e.g. "IPR000006".
    id: str
    # Category of the entry.
    type: InterProEntryType
    # Short human-readable name, e.g. "Metallothionein, vertebrate".
    name: str
    # Optional long-form description; defaults to None.
    description: str | None = None
89 |
90 |
@dataclass(frozen=True)
class InterProRangeAnnotation:
    """An InterPro accession attached to a span of residues in a protein.

    Instances are immutable, and therefore hashable and usable in sets and as
    dict keys.
    """

    # InterPro accession, e.g. "IPR000006".
    interpro_accession: str
    # Start bound of the residue span.
    start_idx: int
    # End bound of the residue span.
    end_idx: int
98 |
99 |
class InterPro:
    """Convenience class interacting with InterPro ontology/data.

    Data files are read lazily. Nothing is opened until the corresponding cached
    property (``interpro2go``, ``entries_frame``/``entries``, ``graph``) is first
    accessed, and the parsed result is then cached on the instance.
    """

    def __init__(
        self,
        entries_path: str | None = None,
        hierarchy_path: str | None = None,
        interpro2go_path: str | None = None,
    ):
        """Constructs interface to query InterPro entries.

        Args:
            entries_path: tab-separated InterPro entries table. Defaults to
                ``C.data_root() / C.INTERPRO_ENTRY``.
            hierarchy_path: InterPro hierarchy file. Defaults to
                ``C.data_root() / C.INTERPRO_HIERARCHY``.
            interpro2go_path: InterPro2GO mapping file. Defaults to
                ``C.data_root() / C.INTERPRO2GO``.
        """
        # Resolve C.data_root() only for paths that were not supplied, so callers
        # passing explicit paths never touch the default data root.
        self.entries_path = (
            entries_path
            if entries_path is not None
            else str(C.data_root() / C.INTERPRO_ENTRY)
        )
        self.hierarchy_graph_path = (
            hierarchy_path
            if hierarchy_path is not None
            else str(C.data_root() / C.INTERPRO_HIERARCHY)
        )
        self.interpro2go_path = (
            interpro2go_path
            if interpro2go_path is not None
            else str(C.data_root() / C.INTERPRO2GO)
        )

    @cached_property
    def interpro2go(self) -> dict[str, list[str]]:
        """Reads the InterPro to GO term mapping."""
        assert self.interpro2go_path is not None
        return _parse_interpro2go(self.interpro2go_path)

    @cached_property
    def entries_frame(self) -> pd.DataFrame:
        """Loads full InterPro entry set as a DataFrame.

        Columns are:
        - "id": str InterPro accession (e.g. "IPR000006").
        - "type": InterProEntryType of the entry. Type names are matched
          case-insensitively against the enum member names; unrecognised names
          map to InterProEntryType.UNKNOWN.
        - "name": Short name of the entry.
        Any other columns of the file are kept unchanged.
        """
        with Path(self.entries_path).open("r") as f:
            df = pd.read_csv(f, sep="\t")
        assert all(
            col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
        )
        df.rename(
            columns={
                "ENTRY_AC": "id",
                "ENTRY_TYPE": "type",
                "ENTRY_NAME": "name",
            },
            inplace=True,
        )
        members = InterProEntryType.__members__
        df["type"] = df["type"].str.upper().apply(
            lambda type_name: members.get(type_name, InterProEntryType.UNKNOWN)
        )
        return df

    @cached_property
    def entries(self) -> dict[str, InterProEntry]:
        """Returns all InterPro entries keyed by accession."""
        return {
            row.id: InterProEntry(  # type: ignore
                id=row.id,  # type: ignore
                type=row.type,  # type: ignore
                name=row.name,  # type: ignore
            )
            for row in self.entries_frame.itertuples()
        }

    def lookup_name(self, interpro_id: str) -> str | None:
        """Short name / title for an InterPro id, or None if the id is unknown."""
        entry = self.entries.get(interpro_id)
        return entry.name if entry is not None else None

    def lookup_entry_type(self, interpro_id: str) -> InterProEntryType:
        """Looks up entry-type for an InterPro id (UNKNOWN if the id is unknown)."""
        entry = self.entries.get(interpro_id)
        return entry.type if entry is not None else InterProEntryType.UNKNOWN

    @cached_property
    def graph(self) -> nx.DiGraph:
        """Reads the InterPro hierarchy into a directed graph.

        Each line's accession is the text before the first "::". Its nesting depth
        is the number of leading dashes divided by two. Edges point from each
        child to its parent.
        """
        graph = nx.DiGraph()
        # parents[k] is the most recent accession seen at depth k.
        parents: list[str] = []
        with Path(self.hierarchy_graph_path).open("r") as f:
            for line in f:
                ipr = line.rstrip("\r\n").split("::", maxsplit=1)[0]
                ipr_strip = ipr.lstrip("-")
                if not ipr_strip.strip():
                    # A blank line would otherwise add a bogus node and reset the
                    # ancestor stack, corrupting the following entries' parents.
                    continue
                level = (len(ipr) - len(ipr_strip)) // 2
                parents = parents[:level]
                graph.add_node(ipr_strip)
                if parents:
                    graph.add_edge(ipr_strip, parents[-1])
                parents.append(ipr_strip)
        return graph
193 |
194 |
def parse_interpro_features(
    interpro_accessions: list[str],
    interpro_starts: list[int],
    interpro_ends: list[int],
) -> list[InterProRangeAnnotation]:
    """Parses raw InterPro ranges.

    Args:
        interpro_accessions: list of InterPro accessions. Entries equal to "-"
            (unintegrated labels) are skipped.
        interpro_starts: list of one-indexed inclusive residue locations where the
            annotations from `interpro_accessions` begin.
        interpro_ends: list of one-indexed *inclusive* residue locations where the
            annotations from `interpro_accessions` end.
    Returns:
        Collated InterProRangeAnnotations. NOTE that index conversion will convert
        range bounds to zero-indexed [inclusive, exclusive) start/end indices,
        stored as plain Python ints.
    """
    assert len(interpro_accessions) == len(interpro_starts) == len(interpro_ends)

    # We want to use Python's convention of [inclusive, exclusive) and 0-indexing.
    # Interpro residue indices are [inclusive, inclusive] and 1-indexing.
    # The conversion ends up being:
    # ```python
    # end_idcs += 1  # [inclusive, inclusive] -> [inclusive, exclusive)
    # start_idcs -= 1  # 1 -> 0 indexing
    # end_idcs -= 1  # 1 -> 0 indexing
    # ```
    # Which simply results in: start_idcs -= 1.
    # `.tolist()` turns numpy integers into Python ints, matching the `int`
    # field types of InterProRangeAnnotation.
    start_idcs = (np.asarray(interpro_starts).astype(int) - 1).tolist()
    end_idcs = np.asarray(interpro_ends).astype(int).tolist()

    return [
        InterProRangeAnnotation(
            interpro_accession=interpro_accession,
            start_idx=start_idx,
            end_idx=end_idx,
        )
        for interpro_accession, start_idx, end_idx in zip(
            interpro_accessions, start_idcs, end_idcs
        )
        # NOTE: Skip unintegrated Interpro labels, for now.
        if interpro_accession != "-"
    ]
246 |
--------------------------------------------------------------------------------
/src/esm/utils/function/lsh.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 |
5 | from src.esm.utils.types import PathLike
6 |
7 |
8 | class LSHTable:
9 | def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None):
10 | if hyperplanes is None:
11 | hyperplanes = np.random.randn(n_bits, dim)
12 | hyperplanes = hyperplanes / np.linalg.norm(
13 | hyperplanes, axis=-1, keepdims=True
14 | )
15 | else:
16 | assert hyperplanes.shape == (n_bits, dim), (
17 | hyperplanes.shape,
18 | (n_bits, dim),
19 | )
20 | assert hyperplanes is not None
21 | self.hyperplanes: np.ndarray = hyperplanes
22 | self.values = 1 << np.arange(n_bits)
23 |
24 | def __call__(self, array, tokenize: bool = True):
25 | similarity = self.hyperplanes @ array.T
26 | bits = np.where(similarity >= 0, 1, 0)
27 | if tokenize:
28 | tokens = bits.T @ self.values
29 | return tokens
30 | else:
31 | return bits.T
32 |
33 |
class LSHTokenized:
    """Multi-table LSH mapping each vector to one integer token per table.

    Hyperplanes come either from an ``.npz`` archive with one array per table,
    keyed "0", "1", ..., or, when ``allow_create_hyperplanes`` is set and no file
    is given, they are drawn at random.
    """

    def __init__(
        self,
        n_bits: int,
        dim: int,
        num_tables: int = 1,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,  # set this if you want the lsh to allow creation of hyperplanes
    ):
        table_hyperplanes: dict[str, np.ndarray] | None = None
        if filepath is not None:
            filepath = Path(filepath)
            if not filepath.exists():
                raise FileNotFoundError(filepath)
            # np.load on an .npz returns an NpzFile that keeps the file handle
            # open and reads arrays lazily. Materialise the tables we need and
            # close it deterministically.
            with np.load(filepath) as archive:  # type: ignore
                for i in range(num_tables):
                    assert str(i) in archive, f"Missing hyperplane for table {i}"
                table_hyperplanes = {
                    str(i): archive[str(i)] for i in range(num_tables)
                }
        elif not allow_create_hyperplanes:
            raise RuntimeError(
                "Not allowed to create hyperplanes but no filepath provided"
            )

        self.tables = [
            LSHTable(
                n_bits,
                dim,
                table_hyperplanes[str(i)] if table_hyperplanes is not None else None,
            )
            for i in range(num_tables)
        ]

    def write_hyperplanes(self, filepath: PathLike):
        """Saves every table's hyperplanes to an .npz archive keyed by table index."""
        hyperplanes: dict[str, np.ndarray] = {  # type: ignore
            str(i): table.hyperplanes for i, table in enumerate(self.tables)
        }
        np.savez(filepath, **hyperplanes)

    def __call__(self, array):
        """Returns the per-table tokens of ``array``, stacked along axis 1."""
        tokens = np.stack([table(array) for table in self.tables], 1)
        return tokens
74 |
75 |
class LSHBitstream:
    """Single-table LSH returning raw hash bits instead of packed integer tokens.

    Hyperplanes are loaded from a ``.npy`` file, or drawn at random when
    ``allow_create_hyperplanes`` is set and no file is given.
    """

    def __init__(
        self,
        n_bits: int,
        dim: int,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,  # set this if you want the lsh to allow creation of hyperplanes
    ):
        if filepath is None and not allow_create_hyperplanes:
            raise RuntimeError(
                "Not allowed to create hyperplanes but no filepath provided"
            )

        loaded = None
        if filepath is not None:
            path = Path(filepath)
            if not path.exists():
                raise FileNotFoundError(path)
            loaded = np.load(path)

        self.table = LSHTable(n_bits, dim, loaded)

    def write_hyperplanes(self, filepath: PathLike):
        """Saves the hyperplanes as a single .npy array."""
        np.save(filepath, self.table.hyperplanes)

    def __call__(self, array):
        """Returns the per-row hash bits of ``array``."""
        return self.table(array, tokenize=False)
104 |
--------------------------------------------------------------------------------
/src/esm/utils/function/tfidf.py:
--------------------------------------------------------------------------------
1 | """Term-Frequency / Inverse Document Frequency (TF-IDF) model."""
2 |
3 | from collections import Counter
4 | from functools import cached_property
5 |
6 | import numpy as np
7 | from scipy import sparse
8 |
9 |
class TFIDFModel:
    """Term-Frequency / Inverse Document Frequency (TF-IDF) model.

    Mimics sklearn.feature_extraction.text.TfidfVectorizer with sublinear_tf=True:
    each term is weighted ``(1 + log(tf)) * idf`` and the vector is L2-normalised.
    """

    def __init__(self, vocabulary_path: str, idf_path: str):
        """Loads the vocabulary and IDF weights.

        Args:
            vocabulary_path: text file with one term per line; a term's line
                number is its index.
            idf_path: NumPy file holding a 1-D array with one IDF weight per
                vocabulary term.
        """
        with open(vocabulary_path, "r") as f:
            self.vocabulary = f.read().strip().split("\n")

        with open(idf_path, "rb") as f:
            self.idf_ = np.load(f)

        assert self.idf_.ndim == 1
        assert (
            len(self.idf_) == len(self.vocabulary)
        ), f"IDF size must match vocabulary size, got {len(self.idf_)} and {len(self.vocabulary)}"

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Maps each vocabulary term to its column index."""
        return {term: index for index, term in enumerate(self.vocabulary)}

    def encode(self, terms: list[str]) -> sparse.csr_matrix:
        """Encodes terms as TF-IDF vectors.

        Terms not in the vocabulary are ignored. If no term is in the vocabulary
        (or all weights are zero) the result has no non-zero entries.

        Args:
            terms: list of terms to encode.

        Returns:
            TF-IDF vector encoded as sparse matrix of shape (1, num_terms)
        """
        vocab_to_index = self.vocab_to_index
        # Dict membership is O(1); scanning the vocabulary list was O(V) per term.
        counter = Counter(term for term in terms if term in vocab_to_index)
        indices = np.array([vocab_to_index[term] for term in counter], dtype=np.int64)
        tf = np.array(list(counter.values()), dtype=np.float64)

        values = (1 + np.log(tf)) * self.idf_[indices]
        norm = np.linalg.norm(values)
        if norm > 0:
            # Guard against 0/0 -> NaN for all-zero weight vectors.
            values = values / norm

        return sparse.csr_matrix(
            (values, (np.zeros_like(indices), indices)),
            shape=(1, len(self.vocabulary)),
        )

    def decode(self, vec: sparse.csr_matrix) -> list[str]:
        """Extract terms from TF-IDF (terms at the stored column indices)."""
        return [self.vocabulary[i] for i in vec.indices]
57 |
--------------------------------------------------------------------------------
/src/esm/utils/generation.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import attr
4 | import torch
5 | from tqdm import tqdm
6 |
7 | from src.esm.sdk.api import (
8 | ESM3InferenceClient,
9 | ESMProtein,
10 | ESMProteinTensor,
11 | GenerationConfig,
12 | SamplingConfig,
13 | SamplingTrackConfig,
14 | )
15 | from src.esm.tokenization import (
16 | EsmTokenizerBase,
17 | TokenizerCollectionProtocol,
18 | )
19 | from src.esm.utils.constants import esm3 as C
20 | from src.esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
21 |
22 |
def iterative_sampling_raw(
    client: ESM3InferenceClient,
    input: ESMProtein,
    config: GenerationConfig,
):
    """Encodes ``input``, generates with ``config``, and decodes the result.

    Unless the sampled track is "function" or "residue_annotations", the
    decoded protein's ``function_annotations`` are replaced by the input's.
    Function and residue annotation encoding/decoding is lossy, so decoding
    encoded tokens is not guaranteed to reproduce the input annotations.
    """
    encoded = client.encode(input)
    generated = client.generate(encoded, config)
    protein = client.decode(generated)

    sampled_annotation_track = config.track in ["function", "residue_annotations"]
    if not sampled_annotation_track:
        protein.function_annotations = input.function_annotations

    return protein
43 |
44 |
def iterative_sampling_tokens(
    client: ESM3InferenceClient,
    input_tokens: ESMProteinTensor,
    config: GenerationConfig,
    tokenizers: TokenizerCollectionProtocol,
) -> ESMProteinTensor:
    """Iteratively unmasks one track (``config.track``) of ``input_tokens``.

    At each of ``config.num_steps`` steps the model is run on the current tokens.
    Following the ``config.schedule`` noise schedule, the still-masked positions
    with the lowest predicted entropy are replaced by the model's samples. The
    first and last positions (BOS/EOS) are never sampled.

    Args:
        client: model client providing ``forward_and_sample``.
        input_tokens: tokenized protein. If the track to sample is None it is
            initialised fully masked; otherwise only its mask positions are
            sampled.
        config: generation settings (track, schedule, num_steps, temperature,
            top_p, invalid_ids, condition_on_coordinates_only).
        tokenizers: tokenizer collection used to look up mask/BOS/EOS ids.

    Returns:
        A copy of ``input_tokens`` with the sampled track filled in. Every other
        SamplingConfig track (embedding fields excluded) is reset to a clone of
        the input's track.

    Raises:
        ValueError: if the track to sample is present but contains no masks.
        NotImplementedError: if the track is "function" or "residue_annotations"
            and at least one decoding step is requested.
    """
    track_to_sample = config.track

    # Get all tracks that require sampling
    all_tracks = [
        f.name for f in attr.fields(SamplingConfig) if "embedding" not in f.name
    ]

    sequence_length = len(input_tokens)
    device = input_tokens.device

    # Initialize schedule and masks
    decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]
    sampled_tokens = attr.evolve(input_tokens)  # Make a (shallow) copy

    if config.condition_on_coordinates_only and input_tokens.coordinates is not None:
        sampled_tokens.structure = None

    # Positions eligible for sampling; the first (BOS) and last (EOS) never are.
    sampling_mask = torch.ones(
        sequence_length,
        dtype=torch.bool,
        device=device,
    )
    sampling_mask[0] = False
    sampling_mask[-1] = False

    def get_tokenizer(track: str) -> EsmTokenizerBase:
        return getattr(tokenizers, track)

    if getattr(sampled_tokens, track_to_sample) is None:
        # Track absent: start from an all-mask track framed by BOS/EOS tokens.
        if track_to_sample == "function":
            dims = (sequence_length, tokenizers.function.depth)
        elif track_to_sample == "residue_annotations":
            dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
        else:
            dims = (sequence_length,)
        masked_tokens = torch.full(
            dims,
            get_tokenizer(track_to_sample).mask_token_id,
            dtype=torch.long,
            device=device,
        )
        if track_to_sample == "sequence":
            masked_tokens[0] = tokenizers.sequence.cls_token_id  # type: ignore
            masked_tokens[-1] = tokenizers.sequence.eos_token_id  # type: ignore
        else:
            masked_tokens[0] = get_tokenizer(track_to_sample).bos_token_id
            masked_tokens[-1] = get_tokenizer(track_to_sample).eos_token_id

        setattr(sampled_tokens, track_to_sample, masked_tokens)
    else:
        # Track present: only its mask positions may be sampled.
        is_mask: torch.Tensor = (
            getattr(input_tokens, track_to_sample)
            == get_tokenizer(track_to_sample).mask_token_id
        )
        if not is_mask.any().item():
            raise ValueError(f"Cannot sample {config.track} when input has no masks.")
        sampling_mask = sampling_mask & is_mask

    # Decode

    def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
        return x.clone() if x is not None else None

    L = sequence_length - 2  # positions excluding BOS/EOS
    positions_sampled = 0
    for t in tqdm(range(config.num_steps)):
        if track_to_sample in ["function", "residue_annotations"]:
            # TODO: Implement iterative decoding for function and residue_annotations
            # TODO: Fix encode/decode of interpro tokens (not yet supported)
            # Unsupported: fail before spending a forward pass on the model.
            raise NotImplementedError(
                f"Iterative decoding for {track_to_sample} is not supported yet."
            )

        # Single step sampling at all positions
        track_sample_config = SamplingTrackConfig()
        track_sample_config.invalid_ids = config.invalid_ids
        track_sample_config.temperature = config.temperature
        track_sample_config.top_p = config.top_p
        sampling_config = SamplingConfig(**{track_to_sample: track_sample_config})  # type: ignore

        forward_and_sample_output = client.forward_and_sample(
            sampled_tokens, sampling_config
        )
        new_samples = forward_and_sample_output.protein_tensor

        # Calculate number of tokens to sample this step from the schedule.
        perc_masked = decoding_schedule(torch.tensor((t + 1) / config.num_steps))
        num_to_sample = int((1 - perc_masked) * L) - positions_sampled
        positions_sampled += num_to_sample

        # Only positions that are still masked may be updated.
        sampling_mask = sampling_mask & (
            getattr(sampled_tokens, track_to_sample)
            == get_tokenizer(track_to_sample).mask_token_id
        )

        # Select tokens based on lowest entropy among still-masked positions.
        track_entropy: torch.Tensor = getattr(
            forward_and_sample_output.entropy, track_to_sample
        )
        track_entropy = track_entropy.masked_fill(
            ~sampling_mask, torch.finfo(track_entropy.dtype).max
        )
        _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
        is_top_k = (
            torch.arange(sequence_length, device=device)[:, None] == indices[None, :]
        ).any(-1)
        tokens_to_sample = sampling_mask & is_top_k

        old_track_samples = getattr(sampled_tokens, track_to_sample)
        new_track_samples = getattr(new_samples, track_to_sample)

        new_track_samples = torch.where(
            tokens_to_sample, new_track_samples, old_track_samples
        )

        setattr(sampled_tokens, track_to_sample, new_track_samples)

    # Do not update tracks that were not sampled (e.g. keep None instead of masks)
    for track in all_tracks:
        if track != track_to_sample:
            setattr(
                sampled_tokens,
                track,
                maybe_clone(getattr(input_tokens, track)),
            )

    return sampled_tokens
186 |
--------------------------------------------------------------------------------
/src/esm/utils/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import ContextManager, Sequence, TypeVar
3 |
4 | import numpy as np
5 | import torch
6 |
7 | MAX_SUPPORTED_DISTANCE = 1e6
8 |
9 |
10 | TSequence = TypeVar("TSequence", bound=Sequence)
11 |
12 |
13 | def slice_python_object_as_numpy(
14 | obj: TSequence, idx: int | list[int] | slice | np.ndarray
15 | ) -> TSequence:
16 | """
17 | Slice a python object (like a list, string, or tuple) as if it was a numpy object.
18 |
19 | Example:
20 | >>> obj = "ABCDE"
21 | >>> slice_python_object_as_numpy(obj, [1, 3, 4])
22 | "BDE"
23 |
24 | >>> obj = [1, 2, 3, 4, 5]
25 | >>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
26 | [1, 2, 3]
27 | """
28 | if isinstance(idx, int):
29 | idx = [idx]
30 |
31 | if isinstance(idx, np.ndarray) and idx.dtype == bool:
32 | sliced_obj = [obj[i] for i in np.where(idx)[0]]
33 | elif isinstance(idx, slice):
34 | sliced_obj = obj[idx]
35 | else:
36 | sliced_obj = [obj[i] for i in idx]
37 |
38 | match obj, sliced_obj:
39 | case str(), list():
40 | sliced_obj = "".join(sliced_obj)
41 | case _:
42 | sliced_obj = obj.__class__(sliced_obj) # type: ignore
43 |
44 | return sliced_obj # type: ignore
45 |
46 |
def rbf(values, v_min, v_max, n_bins=16):
    """
    Returns RBF encodings in a new dimension at the end.

    ``n_bins`` Gaussian centres are spaced evenly on [v_min, v_max]. Each has a
    width of ``(v_max - v_min) / n_bins``. Output shape is
    ``values.shape + (n_bins,)``, with the dtype/device of ``values``.
    """
    centers = torch.linspace(
        v_min, v_max, n_bins, device=values.device, dtype=values.dtype
    )
    broadcast_shape = [1] * values.dim() + [-1]
    centers = centers.view(broadcast_shape)
    width = (v_max - v_min) / n_bins
    scaled_offset = (values[..., None] - centers) / width
    return torch.exp(-(scaled_offset**2))
58 |
59 |
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gathers ``inds`` along ``dim`` of ``data``, separately per batch element.

    The first ``no_batch_dims`` dims of ``data`` are batch dims. Each gets an
    ``arange`` index shaped to broadcast against ``inds``, and ``inds`` is placed
    at position ``dim``. All other dims are kept whole.

    Args:
        data: tensor to gather from.
        inds: integer index tensor whose leading dims line up with the batch dims.
        dim: dimension of ``data`` to gather along (negative values count from the
            end).
        no_batch_dims: number of leading batch dims shared by ``data`` and ``inds``.
    Returns:
        ``data`` indexed with the assembled advanced index.
    """
    # Build the batch aranges on the index device to avoid a host->device copy.
    device = inds.device if isinstance(inds, torch.Tensor) else None
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s, device=device)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: indexing with a *list* that mixes slices and tensors
    # relies on legacy "treat sequence as tuple" semantics.
    return data[tuple(ranges)]
71 |
72 |
def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """Gathers node features ``s`` at the neighbour indices in ``edges``.

    A broadcast axis is inserted at position -3 of ``s``. The gather runs along
    axis -2, with every dim of ``s`` except the last treated as a batch dim.
    (Presumably ``s`` is [..., L, D] and ``edges`` is [..., L, K] -- confirm.)
    """
    batch_dims = s.dim() - 1
    expanded = s.unsqueeze(-3)
    return batched_gather(expanded, edges, -2, no_batch_dims=batch_dims)
75 |
76 |
def knn_graph(
    coords: torch.Tensor,
    coord_mask: torch.Tensor,
    padding_mask: torch.Tensor,
    sequence_id: torch.Tensor | None,
    *,
    no_knn: int,
):
    """Builds a k-nearest-neighbour graph over positions.

    Neighbours are ranked by Euclidean distance between ``coords``. Pairs where
    either position lacks valid coordinates fall back to a sequence-separation
    distance that always ranks after every structural distance. Padding pairs,
    and pairs from different sequences when ``sequence_id`` is given, get an
    infinite distance.

    Args:
        coords: [..., L, C] point coordinates (NaNs are replaced by 0).
        coord_mask: [..., L] bool, True where coordinates are valid.
        padding_mask: [..., L] bool, True at padding positions.
        sequence_id: optional [..., L] id of the sequence each position belongs to.
        no_knn: number of neighbours per position (capped at L).
    Returns:
        (edges, edge_mask): [..., L, min(no_knn, L)] neighbour indices sorted by
        increasing distance, and a bool mask that is False where the chosen
        neighbour has infinite distance (padding / other sequence).
    """
    L = coords.shape[-2]
    num_by_dist = min(no_knn, L)
    device = coords.device

    coords = coords.nan_to_num()
    # True where at least one position of the pair lacks valid coordinates.
    invalid_coord_pair = ~(coord_mask[..., None, :] & coord_mask[..., :, None])
    # True for pairs that must never become edges.
    padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None]
    if sequence_id is not None:
        # Ellipsis indexing (like the masks above) supports any number of batch
        # dims, including unbatched [L] inputs.
        padding_pairwise_mask = padding_pairwise_mask | (
            sequence_id[..., None, :] != sequence_id[..., :, None]
        )
    dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1)
    arange = torch.arange(L, device=device)
    seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs()
    # We only support up to a certain distance, above that, we use sequence distance
    # instead. This is so that when a large portion of the structure is masked out,
    # the edges are built according to sequence distance.
    max_dist = MAX_SUPPORTED_DISTANCE
    torch._assert_async((dists[~invalid_coord_pair] < max_dist).all())
    struct_then_seq_dist = (
        seq_dists.to(dists.dtype)
        .mul(1e2)
        .add(max_dist)
        .where(invalid_coord_pair, dists)
        .masked_fill(padding_pairwise_mask, torch.inf)
    )
    dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False)
    # This is a L x L tensor, where we index by rows first,
    # and columns are the edges we should pick.
    chosen_edges = edges[..., :num_by_dist]
    chosen_mask = dists[..., :num_by_dist].isfinite()
    return chosen_edges, chosen_mask
117 |
118 |
119 | def stack_variable_length_tensors(
120 | sequences: Sequence[torch.Tensor],
121 | constant_value: int | float = 0,
122 | dtype: torch.dtype | None = None,
123 | ) -> torch.Tensor:
124 | """Automatically stack tensors together, padding variable lengths with the
125 | value in constant_value. Handles an arbitrary number of dimensions.
126 |
127 | Examples:
128 | >>> tensor1, tensor2 = torch.ones([2]), torch.ones([5])
129 | >>> stack_variable_length_tensors(tensor1, tensor2)
130 | tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones.
131 |
132 | >>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3])
133 | >>> stack_variable_length_tensors(tensor1, tensor2)
134 | tensor of shape [2, 5, 4]
135 | """
136 | batch_size = len(sequences)
137 | shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
138 |
139 | if dtype is None:
140 | dtype = sequences[0].dtype
141 | device = sequences[0].device
142 |
143 | array = torch.full(shape, constant_value, dtype=dtype, device=device)
144 | for arr, seq in zip(array, sequences):
145 | arrslice = tuple(slice(dim) for dim in seq.shape)
146 | arr[arrslice] = seq
147 |
148 | return array
149 |
150 |
def unbinpack(
    tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
):
    """Splits bin-packed rows back into one padded row per packed sequence.

    Args:
        tensor (Tensor): [B, L, ...]
        sequence_id: [B, L] id of the sequence each position belongs to, with ids
            counted from 0 within each row; or None.
        pad_value: fill value for positions past each sequence's length.

    Returns:
        Tensor: [B_unbinpacked, L_unbinpack, ...]. Sequences appear row by row
        in increasing id order. ``tensor`` is returned as-is when
        ``sequence_id`` is None.
    """
    if sequence_id is None:
        return tensor

    sequences_per_row = sequence_id.max(dim=-1).values + 1
    pieces = [
        tensor[row, sequence_id[row] == seq]
        for row, count in enumerate(sequences_per_row)
        for seq in range(count)
    ]
    return stack_variable_length_tensors(pieces, pad_value)
174 |
175 |
176 | def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]:
177 | """
178 | Returns an autocast context manager that disables downcasting by AMP.
179 |
180 | Args:
181 | device_type: The device type ('cpu' or 'cuda')
182 |
183 | Returns:
184 | An autocast context manager with the specified behavior.
185 | """
186 | if device_type == "cpu":
187 | return torch.amp.autocast(device_type, enabled=False)
188 | elif device_type == "cuda":
189 | return torch.amp.autocast(device_type, dtype=torch.float32)
190 | else:
191 | raise ValueError(f"Unsupported device type: {device_type}")
192 |
193 |
194 | def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]:
195 | """Merge overlapping ranges into sorted, non-overlapping segments.
196 |
197 | Args:
198 | ranges: collection of ranges to merge.
199 | merge_gap_max: optionally merge neighboring ranges that are separated by a gap
200 | no larger than this size.
201 | Returns:
202 | non-overlapping ranges merged from the inputs, sorted by position.
203 | """
204 | ranges = sorted(ranges, key=lambda r: r.start)
205 | merge_gap_max = merge_gap_max if merge_gap_max is not None else 0
206 | assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}"
207 |
208 | merged = []
209 | for r in ranges:
210 | if not merged:
211 | merged.append(r)
212 | else:
213 | last = merged[-1]
214 | if last.stop + merge_gap_max >= r.start:
215 | merged[-1] = range(last.start, max(last.stop, r.stop))
216 | else:
217 | merged.append(r)
218 | return merged
219 |
220 |
221 | def list_nan_to_none(l: list) -> list:
222 | if l is None:
223 | return None # type: ignore
224 | elif isinstance(l, float):
225 | return None if math.isnan(l) else l # type: ignore
226 | elif isinstance(l, list):
227 | return [list_nan_to_none(x) for x in l]
228 | else:
229 | # Don't go into other structures.
230 | return l
231 |
232 |
def list_none_to_nan(l: list) -> list:
    """Recursively replaces None with ``math.nan`` inside (nested) lists.

    Only ``list`` containers are descended into; other values pass through.
    """
    if isinstance(l, list):
        return [list_none_to_nan(item) for item in l]
    return math.nan if l is None else l  # type: ignore
240 |
241 |
242 | def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
243 | if x is None:
244 | return None
245 | if convert_none_to_nan:
246 | x = list_none_to_nan(x)
247 | return torch.tensor(x)
248 |
249 |
250 | def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
251 | if x is None:
252 | return None
253 | x = x.tolist()
254 | if convert_nan_to_none:
255 | x = list_nan_to_none(x)
256 | return x
257 |
--------------------------------------------------------------------------------
/src/esm/utils/noise_schedules.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
def cosine_schedule(t: torch.Tensor):
    """Cosine masking schedule (the one used in the MaskGIT paper).

    Maps progress ``t`` (values in [0, 1]) to the fraction still masked:
    1 at t=0, 0 at t=1. ``t`` may have any shape, e.g. (batch_size,).
    """
    fraction_masked = torch.cos(t * math.pi * 0.5)
    return fraction_masked
10 |
11 |
def cubic_schedule(t):
    """Cubic masking schedule: fraction masked is ``1 - t**3``."""
    fraction_masked = 1 - t**3
    return fraction_masked
14 |
15 |
def linear_schedule(t):
    """Linear masking schedule: fraction masked is ``1 - t``."""
    fraction_masked = 1 - t
    return fraction_masked
18 |
19 |
def square_root_schedule(t):
    """Square-root masking schedule: fraction masked is ``1 - sqrt(t)``.

    ``t`` must be a tensor (uses ``torch.sqrt``).
    """
    fraction_masked = 1 - torch.sqrt(t)
    return fraction_masked
22 |
23 |
def square_schedule(t):
    """Quadratic masking schedule: fraction masked is ``1 - t**2``."""
    fraction_masked = 1 - t**2
    return fraction_masked
26 |
27 |
# Maps schedule names (e.g. GenerationConfig.schedule) to schedule functions.
NOISE_SCHEDULE_REGISTRY = {
    "cosine": cosine_schedule,
    "linear": linear_schedule,
    # Legacy key kept for backward compatibility.
    "square_root_schedule": square_root_schedule,
    # Alias consistent with the naming of the other keys.
    "square_root": square_root_schedule,
    "cubic": cubic_schedule,
    "square": square_schedule,
}
35 |
--------------------------------------------------------------------------------
/src/esm/utils/residue_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
# Canonical ordering of the atom names used by the fixed-size "atom37"
# representation. Every residue gets one slot per name, so per-atom data can
# be stored in a format that needs a fixed atom count per residue (e.g. a
# numpy array).
atom_types = (
    "N CA C CB O CG CG1 CG2 OG OG1 SG CD CD1 CD2 ND1 ND2 OD1 OD2 SD CE CE1 CE2 "
    "CE3 NE NE1 NE2 OE1 OE2 CH2 NH1 NH2 OH CZ CZ2 CZ3 NZ OXT"
).split()
# Reverse lookup: atom name -> slot index within the atom37 layout.
atom_order = dict(zip(atom_types, range(len(atom_types))))
atom_type_num = len(atom_types)  # := 37.
59 |
# One-letter -> three-letter residue codes for the 20 standard amino acids,
# in the order A R N D C Q E G H I L K M F P S T W Y V.
restype_1to3 = dict(
    zip(
        "ARNDCQEGHILKMFPSTWYV",
        (
            "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
            "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
        ).split(),
    )
)
82 |
--------------------------------------------------------------------------------
/src/esm/utils/sampling.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from src.esm.sdk.api import (
6 | SamplingConfig,
7 | SamplingTrackConfig,
8 | )
9 | from src.esm.tokenization import (
10 | TokenizerCollection,
11 | get_invalid_tokenizer_ids,
12 | )
13 | from src.esm.tokenization.function_tokenizer import (
14 | InterProQuantizedTokenizer,
15 | )
16 | from src.esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS
17 |
18 |
def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
    """Build a ``SamplingConfig`` with one default ``SamplingTrackConfig`` per track.

    Every attrs field of ``SamplingConfig`` is treated as a track name and is
    looked up as an attribute of ``tokenizers``. Each track config gets:
    - the tokenizer's invalid ids (from ``get_invalid_tokenizer_ids``);
    - temperature 1.0 and top_p 1.0;
    - ``only_sample_masked_tokens=True``, except for the "secondary_structure",
      "sasa" and "function" tracks.
    """
    sampling_config = SamplingConfig()
    # TODO: Add different mask and padding tokens for all tracks
    # Some tracks have the same pad and mask, which causes ambiguity when sampling
    unmasked_tracks = ["secondary_structure", "sasa", "function"]
    for track_field in attr.fields(SamplingConfig):
        track = track_field.name
        track_config = SamplingTrackConfig(
            invalid_ids=get_invalid_tokenizer_ids(getattr(tokenizers, track)),
            temperature=1.0,
            top_p=1.0,
            only_sample_masked_tokens=track not in unmasked_tracks,
        )
        setattr(sampling_config, track, track_config)
    return sampling_config
39 |
40 |
def sample_logits(
    logits: torch.Tensor,
    temperature: float | torch.Tensor,
    top_p: float | torch.Tensor = 1.0,
):
    """Default sampling from logits.

    Args:
        logits: tensor of shape (..., vocab_size).
        temperature: scalar or tensor broadcastable to ``logits.shape[:-1]``.
            If every temperature is 0, the argmax is returned. A mix of zero
            and non-zero temperatures is not supported (AssertionError).
        top_p: scalar or tensor broadcastable to ``logits.shape[:-1]``.
            Nucleus filtering (``top_p_logits``) is applied when any entry is
            below 1.0.

    Returns:
        Integer tensor of sampled token ids with shape ``logits.shape[:-1]``.
    """
    # A per-position tensor top_p cannot be used directly as a Python bool
    # (``top_p < 1.0`` is ambiguous for >1 element), so reduce it with any().
    if isinstance(top_p, torch.Tensor):
        apply_top_p = bool((top_p < 1.0).any())
    else:
        apply_top_p = top_p < 1.0
    if apply_top_p:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = _tensorize_like(temperature, logits)

    if torch.all(temperature == 0):
        return logits.argmax(-1)

    assert not torch.any(temperature == 0), "Partial temperature 0 not supported."

    batch_dims = logits.size()[:-1]
    flat_logits = logits.reshape(-1, logits.shape[-1])

    # Draw one id per row from the temperature-scaled distribution.
    probs = F.softmax(flat_logits / temperature[..., None], dim=-1)
    ids = torch.multinomial(probs, 1).squeeze(1)

    # Pass the Size itself: for 1-D logits batch_dims is empty, and
    # ``reshape(*batch_dims)`` would call reshape() with no shape at all.
    return ids.reshape(batch_dims)
73 |
74 |
def sample_function_logits(
    logits: torch.Tensor,
    tokenizer: InterProQuantizedTokenizer,
    top_p: float | torch.Tensor = 1.0,
    temperature: float | torch.Tensor = 1.0,
    p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode function-token ids from per-position, per-depth logits.

    Args:
        logits: tensor of shape (L, D, V); D must equal ``tokenizer.depth``.
        tokenizer: provides ``depth`` and ``vocab_to_index``. The id of the
            empty string ``""`` is used as the "no function" token.
        top_p: nucleus threshold (scalar or tensor). Filtering is applied when
            any entry is below 1.0.
        temperature: softmax temperature (scalar, or tensor broadcastable to (L, D)).
        p_none_threshold: a position whose "no function" probability, averaged
            over the D depths, exceeds this value gets the none token at every
            depth.

    Returns:
        ``(ids, log_p)``:
        - ``ids`` has shape (L, D) and is chosen by argmax.
        - ``log_p`` holds the (L, D, V) log-probabilities, with the none token
          set to -inf at positions that were not assigned "no function".
    """
    _, depth, _ = logits.shape
    assert depth == tokenizer.depth

    # Tensor-safe check: a multi-element tensor top_p cannot be used as a bool.
    if isinstance(top_p, torch.Tensor):
        apply_top_p = bool((top_p < 1.0).any())
    else:
        apply_top_p = top_p < 1.0
    if apply_top_p:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = torch.ones_like(logits[..., 0]) * temperature

    log_p = F.log_softmax(logits / temperature[..., None], dim=-1)  # (L, D, V)

    none_index = tokenizer.vocab_to_index[""]

    # Choose which positions have no predicted function: average the none
    # probability over depths ("ensemble of predictions") and threshold it.
    p_none = torch.exp(log_p[..., none_index]).mean(dim=-1)  # (L,)
    where_none = p_none > p_none_threshold  # (L,)

    # Set the probability of the none token to 0 at all not-none positions,
    # so the argmax there picks a real label.
    log_p[~where_none, :, none_index] = -torch.inf

    ids = torch.argmax(log_p, dim=-1)  # (L, D)
    ids[where_none, :] = none_index

    return ids, log_p
105 |
106 |
def sample_residue_annotation_logits(
    logits: torch.Tensor, annotation_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select up to MAX_RESIDUE_ANNOTATIONS residue annotations per position.

    Labels are scored independently with a sigmoid (multi-label), not a softmax.

    Args:
        logits: tensor whose last dimension indexes annotation labels,
            e.g. (L, V).
        annotation_threshold: minimum sigmoid probability for a label to be kept.

    Returns:
        ``(idx, logprobs)``, both of shape (..., MAX_RESIDUE_ANNOTATIONS):
        - ``idx`` holds the label indices with the highest logits, in
          descending order. Slots whose probability is below
          ``annotation_threshold`` are set to 0 (presumably the tokenizer's
          pad/empty id; TODO confirm).
        - ``logprobs`` holds the log-sigmoid scores of the selected labels.
          These are not altered for rejected slots.
    """
    # Take the top residue annotations by logit.
    top_idx = logits.argsort(dim=-1, descending=True)[..., :MAX_RESIDUE_ANNOTATIONS]
    top_logprobs = torch.gather(F.logsigmoid(logits), -1, top_idx)  # (L, MAX_R)

    # Keep only positive predictions.
    is_negative = top_logprobs.exp() < annotation_threshold
    top_idx[is_negative] = 0

    return top_idx, top_logprobs
125 |
126 |
127 | def top_p_logits(
128 | logits: torch.Tensor,
129 | top_p: float | torch.Tensor,
130 | ) -> torch.Tensor:
131 | top_p = _tensorize_like(top_p, logits)
132 |
133 | batch_dims = logits.size()[:-1]
134 | logits = logits.reshape(-1, logits.shape[-1])
135 |
136 | # Sort logits in descending order and extract the mask for the top_p
137 | sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
138 | cumsum_logits = sorted_logits.softmax(-1).cumsum(-1)
139 | top_p_mask = cumsum_logits <= top_p[:, None]
140 |
141 | # Make sure at least one token is sampled
142 | top_p_mask[:, 0] = True
143 |
144 | # Mask out the logits that are not in the top_p
145 | batch_indices_to_mask, _ = torch.where(~top_p_mask)
146 | vocab_indices_to_mask = sorted_indices[~top_p_mask]
147 | logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min
148 |
149 | return logits.reshape(*batch_dims, -1)
150 |
151 |
152 | def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
153 | if isinstance(value, (float, int)):
154 | value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
155 | return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)
156 |
--------------------------------------------------------------------------------
/src/esm/utils/structure/aligner.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import replace
4 | from typing import TYPE_CHECKING
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from src.esm.utils.structure.protein_structure import (
10 | compute_affine_and_rmsd,
11 | )
12 |
13 | if TYPE_CHECKING:
14 | from src.esm.utils.structure.protein_chain import ProteinChain
15 |
16 |
class Aligner:
    """Rigid superposition of one protein chain onto another.

    The transform is fitted once in ``__init__``, over the atoms present in
    both chains. It can then be applied to a chain with ``apply``.
    """

    def __init__(
        self,
        mobile: ProteinChain,
        target: ProteinChain,
        only_use_backbone: bool = False,
        use_reflection: bool = False,
    ):
        """
        Aligns a mobile protein chain against a target protein chain.

        Args:
            mobile (ProteinChain): Protein chain to be aligned.
            target (ProteinChain): Protein chain target.
            only_use_backbone (bool): Whether to only use backbone atoms
                (the first three atom37 slots).
            use_reflection (bool): Whether to align to the target's reflection
                (target coordinates negated).
        """
        # Check proteins must have same number of residues
        assert len(mobile) == len(target)

        # Determine overlapping atoms
        joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype(
            bool
        )

        # Backbone atoms are first sites in atom37 representation
        if only_use_backbone:
            joint_atom37_mask[:, 3:] = False

        # Extract matching atom positions and convert to batched tensors
        mobile_atom_tensor = (
            torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0)
        )
        target_atom_tensor = (
            torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0)
        )
        joint_atom37_mask = (
            torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0)
        )

        # If using reflection flip target
        if use_reflection:
            target_atom_tensor = -target_atom_tensor

        # Compute alignment and rmsd
        affine3D, rmsd = compute_affine_and_rmsd(
            mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask
        )
        self._affine3D = affine3D
        self._rmsd = rmsd.item()

    @property
    def rmsd(self):
        """RMSD returned by ``compute_affine_and_rmsd`` for the fit, as a Python number."""
        return self._rmsd

    def apply(self, mobile: ProteinChain) -> ProteinChain:
        """Apply the fitted alignment to a protein chain.

        Only atoms present in ``mobile.atom37_mask`` are transformed. Every
        other atom37 slot of the result is NaN. A new chain is returned via
        ``dataclasses.replace``; ``mobile`` itself is not modified.
        """
        # Cast to bool exactly as __init__ does. A 0/1 integer mask would
        # otherwise be treated by numpy as integer (fancy) indices, not a mask.
        atom_mask = mobile.atom37_mask.astype(bool)

        # Extract atom positions and convert to batched tensors
        mobile_atom_tensor = (
            torch.from_numpy(mobile.atom37_positions[atom_mask])
            .type(torch.float32)
            .unsqueeze(0)
        )

        # Transform atom arrays.
        # NOTE(review): the transform was fitted on float64 tensors in __init__
        # while these points are float32; confirm Affine3D.apply handles the mix.
        aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0)

        # Rebuild atom37 positions
        aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan)
        aligned_atom37_positions[atom_mask] = aligned_atom_tensor

        return replace(mobile, atom37_positions=aligned_atom37_positions)
89 |
--------------------------------------------------------------------------------
/src/esm/utils/structure/lddt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 | from src.esm.utils import residue_constants as RC
5 |
6 |
def compute_lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """
    Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
        Nstates:
            all_atom_pred_pos can contain multiple states in the first dimension which corresponds
            to outputs from different layers of a model (e.g. each IPA block). The return size will
            be [Nstates x Batch size] if this is included.
        Natoms:
            LDDT can be computed for all atoms or some atoms. The atom dimension should contain the
            *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g.,
            this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.

    Args:
        all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms) x 3]): Tensor of predicted positions
        all_atom_positions (Tensor[float], [B x (L * Natoms) x 3]): Tensor of true positions
        all_atom_mask (Tensor[float], [B x (L * Natoms) x 1]): Tensor of masks, indicating whether an atom exists.
            Note the trailing singleton dimension: the atom count is read from ``shape[-2]``.
        cutoff (float): Max distance, in the true structure, for an atom pair to be scored.
        eps (float): Small constant that avoids sqrt(0) and division by zero.
        per_residue (bool): Whether to return per-residue or full-protein lddt.

    Returns:
        LDDT Tensor:
            if per_residue:
                Tensor[float], [(Nstates x) B x (L * Natoms)]
            else:
                Tensor[float], [(Nstates x) B]
        Atoms with no scored pair (e.g. masked-out atoms) get a score of ~1,
        because the masked sums reduce to eps / eps.
    """
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    # A pair (i, j) is scored only if it lies within `cutoff` in the true
    # structure, both atoms exist, and i != j.
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * all_atom_mask.transpose(-1, -2)
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    # Fraction of the 0.5 / 1 / 2 / 4 tolerance thresholds that each pair meets.
    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score
75 |
76 |
def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """LDDT restricted to CA atoms.

    The CA slot (``RC.atom_order["CA"]``) is selected from the true positions
    and the mask, and the result is delegated to ``compute_lddt``. The mask
    keeps a trailing singleton dimension.

    A 3-D ``all_atom_pred_pos`` is passed through unchanged; it is presumably
    already CA-only, e.g. (B, L, 3). Otherwise its CA slot is selected too.
    """
    ca_index = RC.atom_order["CA"]
    pred_ca = (
        all_atom_pred_pos
        if all_atom_pred_pos.dim() == 3
        else all_atom_pred_pos[..., ca_index, :]
    )
    true_ca = all_atom_positions[..., ca_index, :]
    ca_mask = all_atom_mask[..., ca_index : (ca_index + 1)]  # keep dim
    return compute_lddt(
        pred_ca,
        true_ca,
        ca_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
99 |
--------------------------------------------------------------------------------
/src/esm/utils/structure/normalize_coordinates.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | import numpy as np
4 | import torch
5 | from torch import Tensor
6 |
7 | from src.esm.utils import residue_constants as RC
8 | from src.esm.utils.structure.affine3d import Affine3D
9 |
# Constrained type variable: a numpy array or a torch tensor. It lets helpers
# such as index_by_atom_name return the same container type they receive.
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, torch.Tensor)
11 |
12 |
def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D:
    """Build backbone frames from N/CA/C coordinates.

    ``bb_positions`` must have size 3 along dim -2, holding atoms N, CA, C in
    that order; the last dim is presumably xyz. The frame comes from
    ``Affine3D.from_graham_schmidt(C, CA, N)``.
    """
    n_xyz, ca_xyz, c_xyz = bb_positions.unbind(dim=-2)
    return Affine3D.from_graham_schmidt(c_xyz, ca_xyz, n_xyz)
16 |
17 |
def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    """Select atom37 slots by atom name.

    Args:
        atom37: numpy array or torch tensor whose axis ``dim`` indexes atom37 slots.
        atom_names: either a single name, in which case that axis is dropped
            from the result, or a list of names, in which case the axis is
            kept and ordered as given.
        dim: axis holding the atom37 slots. Negative values count from the end.

    Returns:
        The selected entries, in the same container type as ``atom37``.
    """
    single = isinstance(atom_names, str)
    names = [atom_names] if single else atom_names
    slots = [RC.atom_order[name] for name in names]
    axis = dim % atom37.ndim
    selector = tuple(slots if i == axis else slice(None) for i in range(atom37.ndim))
    selected = atom37[selector]  # type: ignore
    return selected.squeeze(axis) if single else selected
32 |
33 |
def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
    """Compute a single frame that can be used to normalize a protein's coordinates.

    The computation proceeds in two steps:
    1. The N, CA and C positions are each averaged over the residues whose
       three backbone atoms are finite in every coordinate. If no residue
       qualifies, the averages are zero.
    2. The three mean points are turned into one frame by
       ``atom3_to_backbone_frames`` (Gram-Schmidt), presumably with the mean
       CA as origin.

    Args:
        coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates

    Returns:
        Affine3D: tensor of Affine3D frame
    """
    backbone = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)
    # A residue counts only if all of N, CA and C are finite in every coordinate.
    residue_ok = torch.isfinite(backbone).all(dim=-1).all(dim=-1)
    masked_backbone = backbone.masked_fill(~residue_ok[..., None, None], 0)
    counts = residue_ok.sum(-1)[..., None, None] + 1e-8
    mean_n_ca_c = masked_backbone.sum(-3) / counts
    return atom3_to_backbone_frames(mean_n_ca_c.float())
57 |
58 |
def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor:
    """Express ``coords`` in the local system of a single ``frame``.

    The inverse of ``frame`` is applied to every coordinate.
    - If the frame's translation is zero, the input is returned unchanged.
    - Any infinite input entry (positive or negative) is written back as +inf.

    Args:
        coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
        frame (Affine3D): Affine3D frame

    Returns:
        torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates
    """
    transformed = frame[..., None, None].invert().apply(coords)

    # Frame "validity" is judged by a non-zero translation norm.
    has_frame = frame.trans.norm(dim=-1) > 0

    inf_entries = torch.isinf(coords)
    result = torch.where(has_frame[..., None, None, None], transformed, coords)
    result.masked_fill_(inf_entries, torch.inf)

    return result
79 |
80 |
def normalize_coordinates(coords: Tensor) -> Tensor:
    """Express ``coords`` in the protein's own normalization frame.

    Shorthand for ``get_protein_normalization_frame`` followed by
    ``apply_frame_to_coords``.
    """
    frame = get_protein_normalization_frame(coords)
    return apply_frame_to_coords(coords, frame)
83 |
--------------------------------------------------------------------------------
/src/esm/utils/structure/predicted_aligned_error.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from src.esm.utils.structure.affine3d import Affine3D
5 |
6 |
7 | def masked_mean(
8 | mask: torch.Tensor,
9 | value: torch.Tensor,
10 | dim: int | None | tuple[int, ...] = None,
11 | eps=1e-10,
12 | ) -> torch.Tensor:
13 | """Compute the mean of `value` where only positions where `mask == true` are
14 | counted.
15 | """
16 | mask = mask.expand(*value.shape)
17 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
18 |
19 |
20 | def _pae_bins(
21 | max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu")
22 | ):
23 | bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device)
24 | step = max_bin / (num_bins - 2)
25 | bin_centers = bins + step / 2
26 | bin_centers = torch.cat(
27 | [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
28 | )
29 | return bin_centers
30 |
31 |
32 | def _compute_pae_masks(mask: torch.Tensor):
33 | square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool()
34 | return square_mask
35 |
36 |
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    aa_mask: torch.Tensor,
    sequence_id: torch.Tensor | None = None,
    max_bin: float = 31,
) -> torch.Tensor:
    """Expected aligned error per residue pair from binned logits.

    Args:
        logits: logits over error bins in the last dimension, presumably of
            shape (..., L, L, num_bins); TODO confirm.
        aa_mask: residue mask. Pairs involving a masked residue have all their
            logits replaced by the dtype minimum, i.e. a uniform distribution.
        sequence_id: accepted but unused.
        max_bin: upper edge of the binning passed to ``_pae_bins``.

    Returns:
        ``sum(probability * bin_center)`` over the last (bin) dimension.
    """
    bin_centers = _pae_bins(max_bin, logits.shape[-1], logits.device)
    pair_mask = _compute_pae_masks(aa_mask)
    fill_value = torch.finfo(logits.dtype).min
    masked_logits = logits.masked_fill(~pair_mask.unsqueeze(-1), fill_value)
    probs = masked_logits.softmax(dim=-1)
    return (probs * bin_centers).sum(dim=-1)
49 |
50 |
@torch.no_grad
def compute_tm(
    logits: torch.Tensor,
    aa_mask: torch.Tensor,
    max_bin: float = 31.0,
):
    """Predicted TM-score from binned aligned-error logits (no gradients).

    For each residue i:
    - the expected TM term ``1 / (1 + (d / d0)^2)`` is computed per pair
      (sum over bins);
    - it is averaged over valid residues j.

    The maximum over i is returned. Here
    ``d0 = 1.24 * (max(N, 19) - 15)^(1/3) - 1.8``, with N the number of valid
    residues in ``aa_mask``.
    """
    pair_mask = _compute_pae_masks(aa_mask)
    num_residues = aa_mask.sum(-1, keepdim=True)
    bin_centers = _pae_bins(max_bin, logits.shape[-1], logits.device)
    d0 = 1.24 * (num_residues.clamp_min(19) - 15) ** (1 / 3) - 1.8
    tm_term_per_bin = 1.0 / (1 + (bin_centers / d0.unsqueeze(-1)) ** 2)

    fill_value = torch.finfo(logits.dtype).min
    probs = logits.masked_fill(~pair_mask.unsqueeze(-1), fill_value).softmax(dim=-1)
    # Expected TM term per pair: sum over bins.
    per_pair = (probs * tm_term_per_bin.unsqueeze(-2)).sum(dim=-1)
    # Mean over residues j.
    per_anchor = masked_mean(pair_mask, per_pair, dim=-1)
    # Max over anchor residues i.
    return per_anchor.max(dim=-1).values
71 |
72 |
def tm_loss(
    logits: torch.Tensor,
    pred_affine: torch.Tensor,
    targ_affine: torch.Tensor,
    targ_mask: torch.Tensor,
    tm_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
    max_bin: float = 31,
):
    """Cross-entropy loss for binned aligned-error logits.

    Target bins (computed without gradients):
    - Each residue's translation is expressed in every residue's inverted
      frame, for both the predicted and the target affines.
    - The squared distance between the two is binned against the squared
      edges ``linspace(0, max_bin, num_bins - 1)``.

    Reduction:
    - The per-pair cross entropy (bin axis moved to dim 1 for
      ``F.cross_entropy``) is averaged over valid pairs from ``targ_mask``.
    - The result is then averaged over the remaining (presumably batch)
      dimension, weighted by ``tm_mask`` when given.

    ``sequence_id`` is accepted but unused.
    """
    predicted = Affine3D.from_tensor(pred_affine)
    target = Affine3D.from_tensor(targ_affine)

    def _points_in_local_frames(affine: Affine3D):
        # All translations, expressed in each residue's inverted frame.
        points = affine.trans[..., None, :, :]
        return affine.invert()[..., None].apply(points)

    with torch.no_grad():
        sq_error = (
            (_points_in_local_frames(predicted) - _points_in_local_frames(target))
            .square()
            .sum(dim=-1)
        )

        num_bins = logits.shape[-1]
        sq_edges = torch.linspace(0, max_bin, num_bins - 1, device=logits.device).square()
        # Bin id = number of squared edges strictly below the squared error.
        true_bins = (sq_error[..., None] > sq_edges).sum(dim=-1).long()

    errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none")
    pair_mask = _compute_pae_masks(targ_mask)
    loss = masked_mean(pair_mask, errors, dim=(-1, -2))

    if tm_mask is None:
        return loss.mean()
    return masked_mean(tm_mask, loss, dim=None)
109 |
--------------------------------------------------------------------------------
/src/esm/utils/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 | from typing import Union
7 |
# A filesystem path given as a plain string or a pathlib.Path.
PathLike = Union[str, Path]
# A path (either form above) or an in-memory text buffer. Written out flat;
# this equals Union[PathLike, io.StringIO], because typing.Union flattens
# nested unions.
PathOrBuffer = Union[str, Path, io.StringIO]
10 |
11 |
@dataclass
class FunctionAnnotation:
    """A function label attached to a contiguous range of residues.

    Fields:
        label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs
        start (int): Start index of this annotation. 1-indexed, inclusive.
        end (int): End index of this annotation. 1-indexed, inclusive.
    """

    label: str
    start: int
    end: int

    def to_tuple(self) -> tuple[str, int, int]:
        """Return the annotation as a plain ``(label, start, end)`` tuple."""
        return (self.label, self.start, self.end)
28 |
--------------------------------------------------------------------------------
/src/esmfold.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import esm
3 | import os
4 | import gc
5 | import argparse
6 | import biotite.structure.io as bsio
7 | import pandas as pd
8 | from tqdm import tqdm
9 | from Bio import SeqIO
10 | from transformers import AutoTokenizer, EsmForProteinFolding
11 |
12 | from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
13 | from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
14 |
def read_fasta(file_path, key):
    """Read a single-record FASTA file with ``SeqIO.read``.

    Returns ``str(record.<key>)``, e.g. ``key="seq"`` for the sequence.
    """
    record = SeqIO.read(file_path, 'fasta')
    return str(getattr(record, key))
17 |
def read_multi_fasta(file_path):
    """
    Parse a multi-record FASTA file.

    params:
        file_path: path to a fasta file
    return:
        a dictionary mapping each header line (whitespace-stripped, including
        the leading '>') to its sequence. Wrapped sequence lines are
        concatenated. Records with an empty sequence are omitted, and a
        repeated header keeps its last sequence.
    raises:
        ValueError: if sequence data appears before the first '>' header.
    """
    sequences = {}
    header = None
    chunks = []
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if chunks:
                    sequences[header] = ''.join(chunks)
                header = line
                chunks = []
            elif line:
                # Without this check a file starting with sequence data hit an
                # unbound `header` (NameError) when the record was stored.
                if header is None:
                    raise ValueError(
                        f"{file_path}: sequence data found before the first '>' header"
                    )
                chunks.append(line)
    if chunks:
        sequences[header] = ''.join(chunks)
    return sequences
40 |
def convert_outputs_to_pdb(outputs):
    """Turn folding-model outputs into one PDB string per batch element.

    Steps:
    1. The last entry of ``outputs["positions"]`` (atom14) is converted to
       atom37.
    2. Every output tensor is moved to CPU numpy.
    3. Each batch element becomes an ``OFProtein``:
       - ``plddt`` values are written as B-factors;
       - residue indices are shifted to start at 1;
       - ``chain_index`` is used when present.
    """
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    np_outputs = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    atom_exists = np_outputs["atom37_atom_exists"]
    chain_index = np_outputs.get("chain_index")

    pdbs = []
    for i, aatype in enumerate(np_outputs["aatype"]):
        protein = OFProtein(
            aatype=aatype,
            atom_positions=final_atom_positions[i],
            atom_mask=atom_exists[i],
            residue_index=np_outputs["residue_index"][i] + 1,
            b_factors=np_outputs["plddt"][i],
            chain_index=None if chain_index is None else chain_index[i],
        )
        pdbs.append(to_pdb(protein))
    return pdbs
62 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Predict structures with ESMFold (facebook/esmfold_v1)."
    )
    parser.add_argument("--sequence", type=str, default=None,
                        help="Single sequence to fold; chains may be separated by ':'.")
    parser.add_argument("--fasta_file", type=str, default=None,
                        help="Multi-record FASTA; each record is written to <out_dir>/<name>.ef.pdb.")
    parser.add_argument("--fasta_chunk_num", type=int, default=None,
                        help="Split the records of --fasta_file into this many chunks.")
    parser.add_argument("--fasta_chunk_id", type=int, default=None,
                        help="0-based index of the chunk to process (with --fasta_chunk_num).")
    parser.add_argument("--fasta_dir", type=str, default=None,
                        help="Directory of single-record FASTA files.")
    parser.add_argument("--out_dir", type=str,
                        help="Output directory for --fasta_file / --fasta_dir.")
    parser.add_argument("--out_file", type=str, default="result.pdb",
                        help="Output PDB path for --sequence.")
    parser.add_argument("--out_info_file", type=str, default=None,
                        help="Optional CSV of name and mean pLDDT (B-factor) for --fasta_file.")
    parser.add_argument("--fold_chunk_size", type=int,
                        help="Chunk size for axial attention in the folding trunk.")
    args = parser.parse_args()

    # Alternative loader (fair-esm package):
    # model = esm.pretrained.esmfold_v1()
    # model = model.eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

    model = model.cuda()
    # Optional: run the ESM language-model stem in fp16.
    # model.esm = model.esm.half()
    torch.backends.cuda.matmul.allow_tf32 = True
    # Optionally set a chunk size for axial attention to reduce memory.
    # Smaller chunks need less memory at the cost of slower inference.
    if args.fold_chunk_size is not None:
        model.trunk.set_chunk_size(args.fold_chunk_size)

    if args.fasta_file is not None:
        seq_dict = read_multi_fasta(args.fasta_file)
        os.makedirs(args.out_dir, exist_ok=True)
        names, sequences = list(seq_dict.keys()), list(seq_dict.values())
        if args.fasta_chunk_num is not None:
            chunk_size = len(names) // args.fasta_chunk_num + 1
            start = args.fasta_chunk_id * chunk_size
            end = min((args.fasta_chunk_id + 1) * chunk_size, len(names))
            names, sequences = names[start:end], sequences[start:end]

        out_info_dict = {"name": [], "plddt": []}
        bar = tqdm(zip(names, sequences), total=len(names))
        for name, sequence in bar:
            bar.set_description(name)
            # Header ">id description" -> "id".
            name = name[1:].split(" ")[0]
            out_file = os.path.join(args.out_dir, f"{name}.ef.pdb")
            if not os.path.exists(out_file):
                # Multimer prediction can be done with chains separated by ':'
                try:
                    tokenized_input = tokenizer(
                        [sequence], return_tensors="pt", add_special_tokens=False
                    )['input_ids'].cuda()
                    with torch.no_grad():
                        output = model(tokenized_input)
                except Exception as err:
                    # Best effort: skip this record (e.g. presumably CUDA OOM on
                    # very long sequences), release cached GPU memory, continue.
                    # Ctrl-C (KeyboardInterrupt) is no longer swallowed here.
                    print(f"Failed to predict {name}: {err}")
                    gc.collect()
                    torch.cuda.empty_cache()
                    continue
                gc.collect()
                pdb = convert_outputs_to_pdb(output)
                with open(out_file, "w") as f:
                    f.write("\n".join(pdb))

            # Both freshly predicted and pre-existing structures are reported.
            out_info_dict["name"].append(name)
            struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
            out_info_dict["plddt"].append(struct.b_factor.mean())

        if args.out_info_file is not None:
            pd.DataFrame(out_info_dict).to_csv(args.out_info_file, index=False)

    if args.fasta_dir is not None:
        os.makedirs(args.out_dir, exist_ok=True)
        proteins = sorted(os.listdir(args.fasta_dir))
        bar = tqdm(proteins)
        for p in bar:
            # Strip the file extension (previously p[:-6], which only
            # worked for ".fasta").
            name = os.path.splitext(p)[0]
            bar.set_description(name)
            out_file = os.path.join(args.out_dir, f"{name}.ef.pdb")
            if os.path.exists(out_file):
                continue
            bar.set_description(p)
            sequence = read_fasta(os.path.join(args.fasta_dir, p), "seq")
            # Multimer prediction can be done with chains separated by ':'
            tokenized_input = tokenizer(
                [sequence], return_tensors="pt", add_special_tokens=False
            )['input_ids'].cuda()

            with torch.no_grad():
                output = model(tokenized_input)

            pdb = convert_outputs_to_pdb(output)
            with open(out_file, "w") as f:
                f.write("\n".join(pdb))

            struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
            print(p, struct.b_factor.mean())
    elif args.sequence is not None:
        sequence = args.sequence
        # Multimer prediction can be done with chains separated by ':'

        with torch.no_grad():
            output = model.infer_pdb(sequence)

        with open(args.out_file, "w") as f:
            f.write(output)

        struct = bsio.load_structure(args.out_file, extra_fields=["b_factor"])
        print(struct.b_factor.mean())
--------------------------------------------------------------------------------
/src/models/__pycache__/adapter.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/models/__pycache__/adapter.cpython-312.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/pooling.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/models/__pycache__/pooling.cpython-312.pyc
--------------------------------------------------------------------------------
/src/models/pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers.activations import ACT2FN
5 |
class MaskedConv1d(nn.Conv1d):
    """1-D convolution on channels-last input, with an optional multiplicative mask.

    Takes the same arguments as ``torch.nn.Conv1d`` except ``padding``, which is
    derived as ``dilation * (kernel_size - 1) // 2``. This keeps the length
    unchanged for odd kernel sizes with stride 1.

    Shape:
        Input: (N, L, in_channels)
        input_mask: (N, L, 1), optional; multiplied into the input before convolving
        Output: (N, L', out_channels), with L' == L for odd kernel_size and stride 1
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param kernel_size: the kernel width
        :param stride: filter shift
        :param dilation: dilation factor
        :param groups: perform depth-wise convolutions
        :param bias: adds learnable bias to output
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=dilation * (kernel_size - 1) // 2,
        )

    def forward(self, x, input_mask=None):
        if input_mask is not None:
            x = x * input_mask
        # nn.Conv1d works on (N, C, L): move channels in, convolve, move back.
        channels_first = x.transpose(1, 2)
        return super().forward(channels_first).transpose(1, 2)
52 |
53 |
class Attention1dPooling(nn.Module):
    """Attention-weighted pooling over the sequence dimension.

    The pooling proceeds as follows:
    1. A kernel-size-1 ``MaskedConv1d`` gives every position one score.
    2. Scores of masked-out positions become -inf.
    3. A softmax over the sequence turns the scores into weights.
    4. The output is the weighted sum of ``x`` over positions.

    A row with every position masked gets NaN weights (softmax of all -inf).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.layer = MaskedConv1d(hidden_size, 1, 1)

    def forward(self, x, input_mask=None):
        """
        Args:
            x: (N, L, hidden_size) features.
            input_mask: optional mask viewable as (N, L); zero entries are excluded.

        Returns:
            ``(pooled, weights)`` of shapes (N, hidden_size) and (N, L, 1).
        """
        batch_size = x.shape[0]
        scores = self.layer(x).view(batch_size, -1)
        if input_mask is not None:
            keep = input_mask.view(batch_size, -1).bool()
            scores = scores.masked_fill_(~keep, float("-inf"))
        weights = F.softmax(scores, dim=-1).view(batch_size, -1, 1)
        pooled = (weights * x).sum(dim=1)
        return pooled, weights
70 |
class Attention1dPoolingProjection(nn.Module):
    """Two-layer MLP head: Linear -> Dropout -> ReLU -> Linear(num_labels)."""

    def __init__(self, hidden_size, num_labels, dropout=0.25) -> None:
        super().__init__()
        # Creation order (linear before final) fixes both the parameter
        # initialization order and the state_dict keys.
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.final = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        hidden = self.relu(self.dropout(self.linear(x)))
        return self.final(hidden)
85 |
class Attention1dPoolingHead(nn.Module):
    """Attention pooling followed by an MLP projection.

    Maps [batch x sequence x hidden] features to [batch x num_labels].
    """

    def __init__(
        self, hidden_size: int, num_labels: int, dropout: float = 0.25, return_attentions: bool = False
    ):
        super(Attention1dPoolingHead, self).__init__()
        self.return_attentions = return_attentions
        self.attention1d = Attention1dPooling(hidden_size)
        self.attention1d_projection = Attention1dPoolingProjection(hidden_size, num_labels, dropout)

    def forward(self, x, input_mask=None):
        """
        Args:
            x: (batch, seq_len, hidden_size) features.
            input_mask: optional (batch, seq_len) mask; zero entries are ignored.
                When omitted, every position is pooled.

        Returns:
            (batch, num_labels) outputs. When ``return_attentions`` is set, the
            (batch, seq_len, 1) attention weights are also returned.
        """
        # Only reshape a mask that was actually given, so the default
        # input_mask=None pools over every position instead of raising.
        if input_mask is not None:
            input_mask = input_mask.unsqueeze(-1)
        pooled, attn_weights = self.attention1d(x, input_mask=input_mask)
        logits = self.attention1d_projection(pooled)
        if self.return_attentions:
            return logits, attn_weights
        return logits
104 |
class MeanPooling(nn.Module):
    """Mean Pooling for sentence-level classification tasks."""

    def __init__(self):
        super().__init__()

    def forward(self, features, input_mask=None):
        """
        Args:
            features: (batch, seq_len, hidden) tensor.
            input_mask: optional (batch, seq_len) mask; nonzero/True marks real
                positions. Bool, int and float masks are accepted.
        Returns:
            (batch, hidden) mean over the unmasked positions. A row whose mask
            is entirely zero yields a zero vector (for finite features) instead
            of NaN.
        """
        if input_mask is None:
            return torch.mean(features, dim=1)
        # Zero out padded positions before summing over the sequence.
        masked_features = features * input_mask.unsqueeze(2)
        sum_features = torch.sum(masked_features, dim=1)
        token_counts = input_mask.sum(dim=1, keepdim=True).to(sum_features.dtype)
        # Guard all-padding rows against 0/0. Any count > 0 is unaffected.
        token_counts = token_counts.clamp(min=torch.finfo(token_counts.dtype).tiny)
        return sum_features / token_counts
120 |
121 |
class MeanPoolingProjection(nn.Module):
    """Projection head for mean-pooled features.

    Pipeline: dropout -> dense -> GELU -> dropout -> out_proj.
    """

    def __init__(self, hidden_size, num_labels, dropout=0.25):
        super().__init__()
        # Registration order is kept so parameter initialisation and
        # state_dict keys are unchanged.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, mean_pooled_features):
        """Map (..., hidden_size) pooled features to (..., num_labels) logits."""
        hidden = self.dense(self.dropout(mean_pooled_features))
        hidden = self.dropout(ACT2FN['gelu'](hidden))
        return self.out_proj(hidden)
138 |
139 |
class MeanPoolingHead(nn.Module):
    """Sentence-level head: masked mean pooling followed by MeanPoolingProjection."""

    def __init__(self, hidden_size, num_labels, dropout=0.25):
        super().__init__()
        self.mean_pooling = MeanPooling()
        self.mean_pooling_projection = MeanPoolingProjection(hidden_size, num_labels, dropout)

    def forward(self, features, input_mask=None):
        """Pool ``features`` over the sequence, then project to ``num_labels`` outputs."""
        pooled = self.mean_pooling(features, input_mask=input_mask)
        return self.mean_pooling_projection(pooled)
152 |
153 |
class LightAttentionPoolingHead(nn.Module):
    """Light Attention pooling head.

    Two parallel 1D convolutions over the sequence produce value features and
    per-channel attention logits. Two summaries are concatenated and classified
    with a small MLP:
    - an attention-weighted sum of the values;
    - a max over the values.
    Both summaries cover only valid (non-padded) positions.
    """

    def __init__(self, hidden_size=1280, num_labels=11, dropout=0.25, kernel_size=9, conv_dropout: float = 0.25):
        super().__init__()

        # padding = kernel_size // 2 keeps the sequence length unchanged for
        # odd kernel sizes (an even kernel would add one position).
        self.feature_convolution = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1,
                                             padding=kernel_size // 2)
        self.attention_convolution = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1,
                                               padding=kernel_size // 2)

        self.softmax = nn.Softmax(dim=-1)

        self.dropout = nn.Dropout(conv_dropout)

        self.linear = nn.Sequential(
            nn.Linear(2 * hidden_size, 32),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        self.output = nn.Linear(32, num_labels)

    def forward(self, x: torch.Tensor, mask, **kwargs) -> torch.Tensor:
        """
        Args:
            x: [batch_size, sequence_length, hidden_size] embedding tensor that should be classified
            mask: [batch_size, sequence_length] mask for the zero padding used for the shorter
                sequences in the batch. Padding positions are False/0; real positions are True/nonzero.

        Returns:
            classification: [batch_size, num_labels] tensor with logits
        """
        x = x.permute(0, 2, 1)  # [batch_size, hidden_size, sequence_length]
        o = self.dropout(self.feature_convolution(x))  # [batch_size, hidden_size, sequence_length]
        attention = self.attention_convolution(x)  # [batch_size, hidden_size, sequence_length]

        padding = ~mask.bool()[:, None, :]  # [batch_size, 1, sequence_length], True at padding
        # A dtype-aware fill value: the former -1e9 overflows float16, where
        # masked_fill raises. Padding still receives exactly zero attention weight.
        attention = attention.masked_fill(padding, torch.finfo(attention.dtype).min)

        o1 = torch.sum(o * self.softmax(attention), dim=-1)  # [batch_size, hidden_size]
        # Restrict the max to valid positions. Otherwise conv outputs at padded
        # positions (bias plus spill-over from the last residues) leak in, and the
        # prediction depends on how much padding the batch contains.
        o2, _ = torch.max(o.masked_fill(padding, torch.finfo(o.dtype).min), dim=-1)  # [batch_size, hidden_size]
        o = torch.cat([o1, o2], dim=-1)  # [batch_size, 2*hidden_size]
        o = self.linear(o)  # [batch_size, 32]
        return self.output(o)  # [batch_size, num_labels]
--------------------------------------------------------------------------------
/src/utils/__pycache__/data_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/utils/__pycache__/data_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/loss_fn.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/utils/__pycache__/loss_fn.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/metrics.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4protein/VenusVaccine/a05435581246573845d8a32336340e08c821158d/src/utils/__pycache__/metrics.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import biotite
3 | import numpy as np
4 | import torch.utils.data as data
5 | from typing import List
6 | from biotite.structure.residues import get_residues
7 | from biotite.sequence import ProteinSequence
8 | from biotite.structure.io import pdbx, pdb
9 | from biotite.structure import filter_backbone
10 | from biotite.structure import get_chains
11 |
def load_structure(fpath, chain=None):
    """
    Load the backbone atoms of a PDB or mmCIF file, optionally restricted to some chains.

    Args:
        fpath: path (str or path-like) to a file whose name ends in 'cif' or 'pdb'
        chain: a chain id, a list/tuple of chain ids, or None to keep all chains
    Returns:
        biotite.structure.AtomArray containing only backbone atoms of model 1
    Raises:
        ValueError: if the file extension is unsupported, the file contains no
            chains, or a requested chain is missing.
    """
    fpath = str(fpath)
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        # Previously fell through and crashed with UnboundLocalError on `structure`.
        raise ValueError(f'Unsupported structure file format: {fpath}')
    structure = structure[filter_backbone(structure)]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, (list, tuple)):
        chain_ids = list(chain)
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')
    # Vectorised membership test instead of a per-atom Python loop.
    structure = structure[np.isin(structure.chain_id, chain_ids)]
    return structure
45 |
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Gather the coordinates of the named atoms for every residue of ``struct``.

    Example for atoms argument: ["N", "CA", "C"]

    Each residue contributes a (len(atoms), 3) array. A row is NaN when that
    atom is absent from the residue. ``biotite.structure.apply_residue_wise``
    combines the per-residue arrays.

    Raises:
        RuntimeError: if a residue contains the same atom name more than once.
    """
    def filterfn(s, axis=None):
        # filters[i, j] is True when atom i of the residue is named atoms[j].
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        counts = filters.sum(0)
        if (counts > 1).any():
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax selects the first match. A missing atom yields index 0, and
        # that row is overwritten with NaN below.
        index = filters.argmax(0)
        coords = s[index].coord
        coords[counts == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)
61 |
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        coords: an L x 3 x 3 array of N, CA, C coordinates, one entry per
        residue. Missing atoms are NaN. Only coordinates are returned; the
        sequence is available via ``extract_seq_from_pdb``.
    """
    # The sequence used to be computed here and then discarded; besides the
    # wasted work, that made unknown residue names abort coordinate extraction.
    return get_atom_coords_residuewise(["N", "CA", "C"], structure)
75 |
def extract_seq_from_pdb(pdb_file, chain=None):
    """
    Read a structure file and return its one-letter amino-acid sequence.

    Args:
        pdb_file: path to a .pdb or .cif file, loaded with ``load_structure``
        chain: chain id, list of chain ids, or None for all chains
    Returns:
        str: the one-letter codes of the loaded (backbone-filtered) residues,
        concatenated. Residue names that
        ``ProteinSequence.convert_letter_3to1`` cannot convert propagate its error.
    """
    residue_names = get_residues(load_structure(pdb_file, chain))[1]
    return ''.join(map(ProteinSequence.convert_letter_3to1, residue_names))
87 |
88 |
class BatchSampler(data.Sampler):
    '''
    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    An item is added to the current batch only while
    (largest node count in the batch) * (batch size) stays within
    ``max_batch_nodes``. Presumably this bounds the size of the batch once
    every member is padded to the longest one. Items whose own node count
    exceeds ``max_batch_nodes`` are dropped.

    Batches are formed once, in ``__init__``. Iterating again reuses the same
    batches, so with ``shuffle=True`` the order is drawn once rather than per epoch.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_batch_nodes: the maximum number of nodes in any batch,
                            including batches of a single element
    :param shuffle: if `True`, batches in shuffled order
    '''
    def __init__(self, node_counts, max_batch_nodes=10000, shuffle=True):

        self.node_counts = node_counts
        self.idx = [i for i, n in enumerate(node_counts) if n <= max_batch_nodes]
        self.shuffle = shuffle
        self.max_batch_nodes = max_batch_nodes
        self._form_batches()

    def _form_batches(self):
        """Greedily pack ``self.idx`` (shuffled first if requested) into batches."""
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        idx = self.idx
        # Walk with a cursor; the former `idx = idx[1:]` copied the remaining
        # list for every element, making batching quadratic in dataset size.
        pos, total = 0, len(idx)
        while pos < total:
            batch = []
            max_n_node = 0
            while pos < total:
                n_node = self.node_counts[idx[pos]]
                if max(n_node, max_n_node) * (len(batch) + 1) > self.max_batch_nodes:
                    break
                if n_node > max_n_node:
                    max_n_node = n_node
                batch.append(idx[pos])
                pos += 1
            self.batches.append(batch)

    def __len__(self):
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        yield from self.batches
131 |
132 |
def top_k_accuracy(labels, probas, ratio=0.3):
    """
    Fraction of positives (label == 1) among the highest-scoring samples.

    Args:
        labels: binary labels, one per sample.
        probas: scores used for ranking, one per sample.
        ratio: share of samples to keep. Defaults to 0.3, the previously
            hard-coded value.
    Returns:
        float precision among the top k = max(1, int(len(labels) * ratio))
        samples, capped at len(labels). NaN for empty input.
    """
    probas, labels = np.asarray(probas), np.asarray(labels)
    n = len(labels)
    if n == 0:
        return float("nan")
    # int() truncates to 0 for fewer than 4 samples (at ratio 0.3). Then
    # `[-0:]` selected every sample and the division was by zero, so keep
    # at least one sample.
    k = min(n, max(1, int(n * ratio)))
    topk = probas.argsort()[-k:]
    correct = labels[topk] == 1
    return float(correct.sum() / k)
139 |
def plot_roc_curve(y_true, y_pred, save_fig=None):
    """
    Draw a ROC curve and a chance diagonal on the current matplotlib axes.

    Args:
        y_true: ground-truth binary labels.
        y_pred: scores for the positive class.
        save_fig: optional output path. When truthy, the figure is saved at
            300 dpi and then closed; otherwise it is left open for the caller
            (e.g. to overlay more curves or call ``plt.show()``).
    """
    import matplotlib.pyplot as plt
    from sklearn import metrics

    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred)
    roc_auc = metrics.roc_auc_score(y_true, y_pred)
    curve_label = f'Our (AUC = {roc_auc:.2f})'
    plt.plot(fpr, tpr, label=curve_label)
    plt.plot([0, 1], [0, 1], 'k--')  # random-classifier reference line
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    if not save_fig:
        return
    plt.savefig(save_fig, dpi=300, bbox_inches='tight')
    plt.close()
154 |
--------------------------------------------------------------------------------
/src/utils/loss_fn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
class MultiClassFocalLossWithAlpha(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    Per sample: ``-alpha[target] * (1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the true class.
    """

    def __init__(self, num_classes, alpha=None, gamma=1, reduction='mean', device="cuda"):
        """
        Args:
            num_classes: number of classes; sizes the uniform weights used when alpha is None.
            alpha: per-class weights (sequence or tensor), or None for all ones.
            gamma: focusing exponent applied to (1 - p_t).
            reduction: 'mean', 'sum', or any other value for per-sample losses.
            device: device the weight tensor is moved to at construction time.
        """
        super().__init__()
        if alpha is None:
            alpha = torch.ones(num_classes, dtype=torch.float32)
        # Previously the ones-vector was immediately overwritten by
        # torch.tensor(None), so the default alpha=None always crashed.
        self.alpha = torch.as_tensor(alpha).to(device)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        """
        Args:
            pred: logits of shape (N, C).
            target: class indices of shape (N,).
        Returns:
            scalar loss for 'mean'/'sum', otherwise the (N,) per-sample losses.
        """
        # Follow the target's device so indexing works even if `device` differed.
        alpha = self.alpha.to(target.device)[target]
        log_softmax = torch.log_softmax(pred, dim=1)
        logpt = torch.gather(log_softmax, dim=1, index=target.view(-1, 1)).view(-1)
        ce_loss = -logpt
        pt = torch.exp(logpt)
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == "mean":
            return torch.mean(focal_loss)
        if self.reduction == "sum":
            return torch.sum(focal_loss)
        return focal_loss
--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics.classification import MultilabelAveragePrecision
3 |
4 |
def count_f1_max(pred, target):
    """
    F1 score with the optimal threshold, Copied from TorchDrug.

    This function first enumerates all possible thresholds for deciding positive and negative
    samples, and then picks the threshold with the maximal F1 score.

    Every distinct position in the global descending order of ``pred`` acts as
    a threshold. At each threshold:
    - precision is averaged over the rows that have at least one prediction at
      or above it;
    - recall is averaged over all B rows.
    Ties in ``pred`` are broken by argsort order.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Tensor: 0-dim tensor holding the maximal F1 over all thresholds.
    """

    # Sort each row by descending score and reorder the targets to match.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    # Per-row precision/recall after keeping the top-1, top-2, ... predictions.
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    # 1e-10 avoids division by zero for rows with no positive targets.
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Flag each row's top-scoring entry, mapped back to its original column.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # Global descending order over all B*N predictions (flat original indices).
    all_order = pred.flatten().argsort(descending=True)
    # Convert per-row sorted column indices to flat indices, then invert that
    # permutation: inv_order maps a flat original index to its flat
    # position in the row-sorted precision/recall tensors.
    order = (
        order
        + torch.arange(order.shape[0], device=order.device).unsqueeze(1)
        * order.shape[1]
    )
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Per-step increments of each row's precision. At a row's first entry the
    # increment is the full value. `all_order - 1` is still evaluated there
    # (and wraps to -1 for index 0) but torch.where discards it.
    all_precision = precision[all_order] - torch.where(
        is_start, torch.zeros_like(precision), precision[all_order - 1]
    )
    # Cumulative sum = sum of current precisions of rows seen so far; divide
    # by the number of rows that have started.
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - torch.where(
        is_start, torch.zeros_like(recall), recall[all_order - 1]
    )
    # Recall is averaged over every row, including rows not yet started.
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()
48 |
49 |
class MultilabelF1Max(MultilabelAveragePrecision):
    """Multilabel Fmax metric.

    Inherits accumulation from MultilabelAveragePrecision and overrides
    ``compute`` to run ``count_f1_max`` on the concatenated ``preds`` and
    ``target`` state lists.
    """

    def compute(self):
        # NOTE(review): relies on the parent keeping raw `preds`/`target` lists,
        # which presumably requires constructing it with thresholds=None — confirm.
        stacked_preds = torch.cat(self.preds)
        stacked_targets = torch.cat(self.target)
        return count_f1_max(stacked_preds, stacked_targets)
--------------------------------------------------------------------------------