├── .gitignore ├── figs └── architecture_v2.PNG ├── __init__.py ├── src ├── __init__.py ├── logging.py ├── lr_scheduler.py ├── models_ab.py ├── configs.py ├── lora.py ├── training.py ├── swe_pooling.py ├── alignment.py ├── data_module.py ├── utils.py └── models.py ├── setup.py ├── LICENSE ├── README.md └── environment.yml /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /figs/architecture_v2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/ImmuneCLIP/HEAD/figs/architecture_v2.PNG -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # This file can be empty, or you can use it to define what gets imported 2 | # when you do `from project import *` -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # This file can be empty, or you can use it to define what gets imported 2 | # when you do `from project import *` -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='ImmunoAlign', 5 | version='0.0.1', 6 | packages=find_packages(), 7 | install_requires=[ 8 | 'torch', 9 | 'pytorch-lightning', 10 | # add other dependencies here 11 | ], 12 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kundaje Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 4 | from typing import List 5 | from argparse import Namespace 6 | from dataclasses import is_dataclass 7 | 8 | def call_backs(output_dir, save_ckpts=False): 9 | lrmonitor_callback = LearningRateMonitor(logging_interval='step') 10 | 11 | # checkpoint callback options: 12 | checkpoint_callback = ModelCheckpoint( 13 | monitor='val_loss', 14 | dirpath=os.path.join(output_dir, 'checkpoints'), 15 | filename="model-{epoch:03d}-{val_loss:.4f}", 16 | save_top_k=2, 17 | mode='min', 18 | save_last=True, 19 | ) 20 | 21 | if save_ckpts: 22 | callbacks = [checkpoint_callback, lrmonitor_callback] 23 | else: 24 | callbacks = [lrmonitor_callback] 25 | 26 | return callbacks 27 | 28 | def combine_args_and_configs(args: Namespace, dataclasses: List): 29 | if not isinstance(args, dict): 30 | args = vars(args).items() 31 | else: 32 | args = args.items() 33 | for name, value in args: 34 | if value is not None: 35 | for obj in dataclasses: 36 | if is_dataclass(obj) and hasattr(obj, name): 37 | print("overwriting default", name, value) 38 | setattr(obj, name, value) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ImmuneCLIP 2 | Code for the paper "Sequence-based TCR-Peptide Representations Using Cross-Epitope Contrastive Fine-tuning of Protein Language Models." [[paper](https://link.springer.com/chapter/10.1007/978-3-031-90252-9_3)] (_RECOMB 2025_) 3 | 4 | 🚧 This repository is under active construction. 🚧 5 | 6 | ## Model Components: 7 | 8 | 9 | * Epitope Encoder 10 | * PEFT-adapted Protein Language Models (e.g. ESM-2, ESM-3) 11 | * Default: Using LoRA (rank = 8) on the last 8 transformer layers 12 | * Projection layer: FC linear layer (dim $d_{e} \rightarrow d_p$) 13 | * $d_{e}$ is the original PLM dimension, $d_p$ is the projection dimension 14 | 15 | * Receptor Encoder 16 | * PEFT-adapted Protein Language Models (e.g. ESM-2, ESM-3) **or** BCR/TCR Language Models (e.g. AbLang, TCR-BERT, etc.)
17 | * Default: Using LoRA (rank = 8) on the last 4 transformer layers 18 | * Projection layer: FC linear layer (dim $d_{r} \rightarrow d_p$) 19 | * $d_{r}$ is the original receptor LM dimension, $d_p$ is the projection dimension 20 | 21 | ## Dataset: 22 | * MixTCRPred Dataset ([paper](https://github.com/GfellerLab/MixTCRpred/tree/main)) 23 | * Contains a curated mixture of TCR-pMHC sequence data from IEDB, VDJdb, 10x Genomics, and McPAS-TCR 24 | 25 | ## Pre-trained Weights: 26 | * The pre-trained weights for ImmuneCLIP are deposited at [Zenodo](https://zenodo.org/records/14962685) 27 | 28 | ## CLI: 29 | ### Environment Variables 30 | To run this application, set the following environment variable: 31 | ``` 32 | WANDB_OUTPUT_DIR= 33 | ``` 34 | 35 | Additionally, if training on top of a custom in-house TCR model, the following path needs to be set: 36 | ``` 37 | INHOUSE_MODEL_CKPT_PATH= 38 | ``` 39 | 40 | 41 | ### Training 42 | ``` 43 | # go to the root directory of the repo, and then run: 44 | python -m src.training --run-id [RUN_ID] --dataset-path [PATH_TO_DATASET] --stage fit --max-epochs 100 \ 45 | --receptor-model-name [esm2|tcrlang|tcrbert] --projection-dim 512 --gpus-used [GPU_IDX] --lr 1e-3 \ 46 | --batch-size 8 --output-dir [CHECKPOINTS_OUTPUT_DIR] [--mask-seqs] 47 | ``` 48 | 49 | ### Evaluation 50 | ``` 51 | # currently, running the model in the test stage embeds the test-set epitope/receptor pairs with the fine-tuned model and saves them. 52 | python -m src.training --run-id [RUN_ID] --dataset-path [PATH_TO_DATASET] --stage test --from-checkpoint [CHECKPOINT_PATH] \ 53 | --projection-dim 512 --receptor-model-name [esm2|tcrlang|tcrbert] --gpus-used [GPU_IDX] --batch-size 8 \ 54 | --save-embed-path [PATH_FOR_SAVING_EMBEDS] 55 | ``` 56 | -------------------------------------------------------------------------------- /src/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | class CosineAnnealingWarmUpRestarts(_LRScheduler): 5 | def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1): 6 | if T_0 <= 0 or not isinstance(T_0, int): 7 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 8 | if T_mult < 1 or not isinstance(T_mult, int): 9 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 10 | if T_up < 0 or not isinstance(T_up, int): 11 | raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up)) 12 | self.T_0 = T_0 13 | self.T_mult = T_mult 14 | self.base_eta_max = eta_max 15 | self.eta_max = eta_max 16 | self.T_up = T_up 17 | self.T_i = T_0 18 | self.gamma = gamma 19 | self.cycle = 0 20 | self.T_cur = last_epoch 21 | super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch) 22 | 23 | def get_lr(self): 24 | if self.T_cur == -1: 25 | return self.base_lrs 26 | elif self.T_cur < self.T_up: 27 | return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs] 28 | else: 29 | return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2 30 | for base_lr in self.base_lrs] 31 | 32 | def step(self, epoch=None): 33 | if epoch is None: 34 | epoch = self.last_epoch + 1 35 | self.T_cur = self.T_cur + 1 36 | if self.T_cur >= self.T_i: 37 | self.cycle += 1 38 | self.T_cur = self.T_cur - self.T_i 39 | self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up 40 | else: 41 | if epoch >=
self.T_0: 42 | if self.T_mult == 1: 43 | self.T_cur = epoch % self.T_0 44 | self.cycle = epoch // self.T_0 45 | else: 46 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 47 | self.cycle = n 48 | self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) 49 | self.T_i = self.T_0 * self.T_mult ** (n) 50 | else: 51 | self.T_i = self.T_0 52 | self.T_cur = epoch 53 | 54 | self.eta_max = self.base_eta_max * (self.gamma**self.cycle) 55 | self.last_epoch = math.floor(epoch) 56 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 57 | param_group['lr'] = lr -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: immuneclip 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1 6 | - _openmp_mutex=5.1 7 | - bzip2=1.0.8 8 | - ca-certificates=2024.3.11 9 | - ld_impl_linux-64=2.38 10 | - libffi=3.4.4 11 | - libgcc-ng=11.2.0 12 | - libgomp=11.2.0 13 | - libstdcxx-ng=11.2.0 14 | - libuuid=1.41.5 15 | - ncurses=6.4 16 | - openssl=3.0.14 17 | - pip=24.0 18 | - python=3.10.14 19 | - readline=8.2 20 | - setuptools=69.5.1 21 | - sqlite=3.45.3 22 | - tk=8.6.14 23 | - wheel=0.43.0 24 | - xz=5.4.6 25 | - zlib=1.2.13 26 | - pip: 27 | - git+https://github.com/oxpig/AbLang.git 28 | - git+https://github.com/oxpig/AbLang2.git 29 | - accelerate==0.32.1 30 | - aiohttp==3.9.5 31 | - aiosignal==1.3.1 32 | - anyio==4.4.0 33 | - argon2-cffi==23.1.0 34 | - argon2-cffi-bindings==21.2.0 35 | - arrow==1.3.0 36 | - asttokens==2.4.1 37 | - async-lru==2.0.4 38 | - async-timeout==4.0.3 39 | - attrs==23.2.0 40 | - babel==2.15.0 41 | - beautifulsoup4==4.12.3 42 | - biopython==1.84 43 | - biotite==0.41.2 44 | - bleach==6.1.0 45 | - brotli==1.1.0 46 | - certifi==2024.7.4 47 | - cffi==1.16.0 48 | - charset-normalizer==3.3.2 49 | - click==8.1.7 50 | - comm==0.2.2 51 | - debugpy==1.8.2 52 | - decorator==5.1.1 53 | - defusedxml==0.7.1 54 | - docker-pycreds==0.4.0 55 | - einops==0.8.0 56 | - esm==3.0.0 57 | - exceptiongroup==1.2.1 58 | - executing==2.0.1 59 | - fastjsonschema==2.20.0 60 | - filelock==3.15.4 61 | - fqdn==1.5.1 62 | - frozenlist==1.4.1 63 | - fsspec==2024.6.1 64 | - gitdb==4.0.11 65 | - gitpython==3.1.43 66 | - h11==0.14.0 67 | - httpcore==1.0.5 68 | - httpx==0.27.0 69 | - huggingface-hub==0.23.4 70 | - idna==3.7 71 | - ipykernel==6.29.5 72 | - ipython==8.26.0 73 | - ipywidgets==8.1.3 74 | - isoduration==20.11.0 75 | - jedi==0.19.1 76 | - jinja2==3.1.4 77 | - joblib==1.4.2 78 | - json5==0.9.25 79 | - jsonpointer==3.0.0 80 | - jsonschema==4.23.0 81 | - jsonschema-specifications==2023.12.1 82 | - jupyter-client==8.6.2 83 | - jupyter-core==5.7.2 84 | - jupyter-events==0.10.0 85 | - jupyter-lsp==2.2.5 86 | - jupyter-server==2.14.1 87 | - jupyter-server-terminals==0.5.3 88 | - jupyterlab==4.2.3 89 | - jupyterlab-pygments==0.3.0 90 | - jupyterlab-server==2.27.2 91 | - jupyterlab-widgets==3.0.11 92 | - lightning==2.3.3 93 | - lightning-utilities==0.11.3.post0 94 | - llvmlite==0.43.0 95 | - markupsafe==2.1.5 96 | - matplotlib-inline==0.1.7 97 | - mistune==3.0.2 98 | - mpmath==1.3.0 99 | - msgpack==1.0.8 100 | - msgpack-numpy==0.4.8 101 | - multidict==6.0.5 102 | - nbclient==0.10.0 103 | - nbconvert==7.16.4 104 | - nbformat==5.10.4 105 | - nest-asyncio==1.6.0 106 | - networkx==3.3 107 | - notebook-shim==0.2.4 108 | - numba==0.60.0 109 | - numpy==1.26.4 110 | - nvidia-cublas-cu12==12.1.3.1 111 | - 
nvidia-cuda-cupti-cu12==12.1.105 112 | - nvidia-cuda-nvrtc-cu12==12.1.105 113 | - nvidia-cuda-runtime-cu12==12.1.105 114 | - nvidia-cudnn-cu12==8.9.2.26 115 | - nvidia-cufft-cu12==11.0.2.54 116 | - nvidia-curand-cu12==10.3.2.106 117 | - nvidia-cusolver-cu12==11.4.5.107 118 | - nvidia-cusparse-cu12==12.1.0.106 119 | - nvidia-nccl-cu12==2.20.5 120 | - nvidia-nvjitlink-cu12==12.5.82 121 | - nvidia-nvtx-cu12==12.1.105 122 | - overrides==7.7.0 123 | - packaging==24.1 124 | - pandas==2.2.2 125 | - pandocfilters==1.5.1 126 | - parso==0.8.4 127 | - peft==0.11.1 128 | - pexpect==4.9.0 129 | - pillow==10.4.0 130 | - platformdirs==4.2.2 131 | - prometheus-client==0.20.0 132 | - prompt-toolkit==3.0.47 133 | - protobuf==5.27.2 134 | - psutil==6.0.0 135 | - ptyprocess==0.7.0 136 | - pure-eval==0.2.2 137 | - pycparser==2.22 138 | - pygments==2.18.0 139 | - python-dateutil==2.9.0.post0 140 | - python-json-logger==2.0.7 141 | - pytorch-lightning==2.3.3 142 | - pytz==2024.1 143 | - pyyaml==6.0.1 144 | - pyzmq==26.0.3 145 | - referencing==0.35.1 146 | - regex==2024.5.15 147 | - requests==2.32.3 148 | - rfc3339-validator==0.1.4 149 | - rfc3986-validator==0.1.1 150 | - rotary-embedding-torch==0.6.4 151 | - rpds-py==0.19.0 152 | - safetensors==0.4.3 153 | - scikit-learn==1.5.1 154 | - scipy==1.14.0 155 | - send2trash==1.8.3 156 | - sentry-sdk==2.8.0 157 | - setproctitle==1.3.3 158 | - six==1.16.0 159 | - smmap==5.0.1 160 | - sniffio==1.3.1 161 | - soupsieve==2.5 162 | - stack-data==0.6.3 163 | - sympy==1.13.0 164 | - terminado==0.18.1 165 | - threadpoolctl==3.5.0 166 | - tinycss2==1.3.0 167 | - tokenizers==0.19.1 168 | - tomli==2.0.1 169 | - torch==2.3.1 170 | - torchmetrics==1.4.0.post0 171 | - torchtext==0.18.0 172 | - torchvision==0.18.1 173 | - tornado==6.4.1 174 | - tqdm==4.66.4 175 | - traitlets==5.14.3 176 | - transformers==4.42.3 177 | - triton==2.3.1 178 | - types-python-dateutil==2.9.0.20240316 179 | - typing-extensions==4.12.2 180 | - tzdata==2024.1 181 | - uri-template==1.3.0 182 | - urllib3==2.2.2 183 | - wandb==0.17.4 184 | - wcwidth==0.2.13 185 | - webcolors==24.6.0 186 | - webencodings==0.5.1 187 | - websocket-client==1.8.0 188 | - widgetsnbextension==4.0.11 189 | - yarl==1.9.4 190 | -------------------------------------------------------------------------------- /src/models_ab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import re 5 | 6 | from .utils import get_sequence_embeddings, insert_spaces, get_attention_mask, apply_masking_seq 7 | from .swe_pooling import SWE_Pooling 8 | 9 | class AntibodyEncoderAbLang(nn.Module): 10 | def __init__(self, input_dim, projection_dim, ln_cfg, device='cpu'): 11 | super().__init__() 12 | from .lora import setup_peft_ablang 13 | from .configs import peft_config_ablang 14 | 15 | # load the LoRA adapted AbLang HL Models here: 16 | self.ablang_H_lora, self.ablang_H_tokenizer = setup_peft_ablang(peft_config_ablang, chain='H') 17 | self.ablang_L_lora, self.ablang_L_tokenizer = setup_peft_ablang(peft_config_ablang, chain='L') 18 | 19 | self.proj_head = nn.Sequential( 20 | nn.Linear(input_dim, projection_dim), 21 | nn.LayerNorm(projection_dim), 22 | ) 23 | 24 | self.device = device 25 | 26 | def forward(self, x): 27 | H_seqs, L_seqs = x 28 | H_seqs_tokens = self.process_seqs(H_seqs, chain='H') 29 | L_seqs_tokens = self.process_seqs(L_seqs, chain='L') 30 | 31 | try: 32 | H_outputs = self.ablang_H_lora(**H_seqs_tokens) 33 | except: 34 | print("Error in feeding H 
sequences") 35 | 36 | print("H seq:", H_seqs) 37 | print("H seq tokens max:", torch.max(H_seqs_tokens['input_ids'])) 38 | print("self.ablang_H_lora: ", self.ablang_H_lora) 39 | 40 | raise ValueError 41 | 42 | try: 43 | L_outputs = self.ablang_L_lora(**L_seqs_tokens) 44 | except: 45 | print("Error in feeding L sequences") 46 | 47 | print("L seq:", L_seqs) 48 | print("L seq tokens max:", torch.max(L_seqs_tokens['input_ids'])) 49 | print("self.ablang_L_lora: ", self.ablang_L_lora) 50 | 51 | 52 | H_outputs = get_sequence_embeddings(H_seqs_tokens, H_outputs) 53 | L_outputs = get_sequence_embeddings(L_seqs_tokens, L_outputs) 54 | 55 | Ab_seq_embeds = torch.cat((H_outputs, L_outputs), dim=-1) 56 | 57 | return self.proj_head(Ab_seq_embeds) 58 | 59 | def process_seqs(self, seqs, chain): 60 | ''' 61 | seqs: tuple of sequences 62 | ''' 63 | 64 | # format the seq strings accordingly to AbLang: 65 | seqs = [insert_spaces(seq) for seq in seqs] 66 | # seqs = [' '.join(seq) for seq in seqs] 67 | 68 | if chain == 'H': 69 | seqs_tokens = self.ablang_H_tokenizer(seqs, return_tensors="pt", padding=True) 70 | else: 71 | seqs_tokens = self.ablang_L_tokenizer(seqs, return_tensors="pt", padding=True) 72 | 73 | return seqs_tokens.to(self.device) 74 | 75 | 76 | class AntibodyEncoderAbLang2(nn.Module): 77 | def __init__(self, input_dim, projection_dim, ln_cfg, device='cpu'): 78 | super().__init__() 79 | from .lora import setup_peft_ablang2 80 | from .configs import peft_config_ablang2 81 | 82 | self.ablang2_lora, self.ablang2_tokenizer = setup_peft_ablang2(peft_config_ablang2, receptor_type='BCR', device=device, no_lora=ln_cfg.no_lora) 83 | self.padding_idx = 21 84 | 85 | self.proj_head = nn.Sequential( 86 | nn.Linear(input_dim, projection_dim), 87 | nn.LayerNorm(projection_dim), 88 | ) 89 | 90 | self.device = device 91 | 92 | def forward(self, x): 93 | seq_tokens = self.process_seqs(x) 94 | 95 | # print("seq tokens:", seq_tokens) 96 | 97 | # feed to AbLang2 98 | rescoding = self.ablang2_lora(seq_tokens) 99 | 100 | # process AbLang2 outputs 101 | seq_inputs = {'attention_mask': ~(seq_tokens == self.padding_idx)} 102 | model_output = {'last_hidden_state': rescoding.last_hidden_states} 103 | 104 | seq_outputs = get_sequence_embeddings(seq_inputs, model_output, is_sep=False, is_cls=False) 105 | 106 | return self.proj_head(seq_outputs) 107 | 108 | def process_seqs(self, seqs): 109 | H_seqs, L_seqs = seqs 110 | 111 | # format the seq strings accordingly to AbLang2: 112 | ab_seqs = [f"{H_seqs[i]}|{L_seqs[i]}" for i in range(len(H_seqs))] 113 | 114 | seqs_tokens = self.ablang2_tokenizer(ab_seqs, pad=True, w_extra_tkns=False, device=self.device) 115 | 116 | return seqs_tokens 117 | 118 | 119 | class AntibodyEncoderAntiberta2(nn.Module): 120 | def __init__(self, input_dim, projection_dim, ln_cfg, device='cpu'): 121 | super().__init__() 122 | from .lora import setup_peft_aberta2 123 | from .configs import peft_config_aberta2 124 | 125 | self.aberta2_lora, self.aberta2_tokenizer = setup_peft_aberta2(peft_config_aberta2) 126 | 127 | self.proj_head = nn.Sequential( 128 | nn.Linear(input_dim, projection_dim), 129 | nn.LayerNorm(projection_dim), 130 | ) 131 | 132 | self.device = device 133 | 134 | def forward(self, x): 135 | seq_tokens = self.process_seqs(x) 136 | 137 | try: 138 | # feed to AntiBERTa 139 | rescoding = self.aberta2_lora(**seq_tokens) 140 | except: 141 | print("seqs:", x) 142 | print("seq tokens:", seq_tokens) 143 | raise ValueError 144 | 145 | seq_embeds = get_sequence_embeddings(seq_tokens, rescoding) 146 | 147 | 
return self.proj_head(seq_embeds) 148 | 149 | def process_seqs(self, seqs): 150 | H_seqs, L_seqs = seqs 151 | 152 | # format the seq strings accordingly to Antiberta2: 153 | ab_seqs = [f"{insert_spaces(H_seqs[i])} [SEP] {insert_spaces(L_seqs[i])}" for i in range(len(H_seqs))] 154 | 155 | seqs_tokens = self.aberta2_tokenizer(ab_seqs, return_tensors="pt", padding=True) 156 | 157 | return seqs_tokens.to(self.device) 158 | -------------------------------------------------------------------------------- /src/configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | 4 | from peft import LoraConfig, TaskType 5 | 6 | output_dir_path = os.getenv('WANDB_OUTPUT_DIR') 7 | 8 | @dataclass 9 | class LightningConfig: 10 | max_epochs: int = 10 11 | lr: float = 1e-4 12 | weight_decay: float = 0.01 13 | batch_size: int = 4 14 | num_workers_train: int = 8 15 | torch_device: str = 'gpu' 16 | dataset_path: str = None 17 | output_dir: str = output_dir_path 18 | include_mhc: bool = False 19 | mhc_groove_only: bool = False 20 | unique_epitopes: bool = False 21 | mask_seqs: bool = False 22 | mask_prob: float = 0.15 23 | swe_pooling: bool = False 24 | save_embed_path: str = None 25 | no_lora: bool = False 26 | mse_weight: float = 0. 27 | weigh_epitope_count: bool = False 28 | oversample: bool = False 29 | regular_ft: bool = False 30 | fewshot_ratio: float = None 31 | lr_scheduler: str = 'cos_anneal' 32 | 33 | 34 | @dataclass 35 | class EncoderProjectionConfigAbLang: 36 | epitope_input_dim: int = 1280 37 | receptor_input_dim: int = 1536 38 | projection_dim: int = 512 39 | temperature: float = 0.07 40 | receptor_model_name: str = 'ablang' 41 | 42 | @dataclass 43 | class EncoderProjectionConfigAbLang2: 44 | epitope_input_dim: int = 1280 45 | receptor_input_dim: int = 480 46 | projection_dim: int = 512 47 | temperature: float = 0.07 48 | receptor_model_name: str = 'ablang2' 49 | 50 | @dataclass 51 | class EncoderProjectionConfigAntiberta2: 52 | epitope_input_dim: int = 1280 53 | receptor_input_dim: int = 1024 54 | projection_dim: int = 512 55 | temperature: float = 0.07 56 | receptor_model_name: str = 'antiberta2' 57 | 58 | @dataclass 59 | class EncoderProjectionConfigTCRBert: 60 | epitope_input_dim: int = 1280 61 | receptor_input_dim: int = 1536 62 | hidden_dim: int = None 63 | projection_dim: int = 512 64 | temperature: float = 0.07 65 | receptor_model_name: str = 'tcrbert' 66 | 67 | @dataclass 68 | class EncoderProjectionConfigTCRLang: 69 | epitope_input_dim: int = 1280 70 | receptor_input_dim: int = 480 71 | hidden_dim: int = None 72 | projection_dim: int = 512 73 | temperature: float = 0.07 74 | receptor_model_name: str = 'tcrlang' 75 | 76 | @dataclass 77 | class EncoderProjectionConfigESM2: 78 | epitope_input_dim: int = 1280 79 | receptor_input_dim: int = 1280 80 | hidden_dim: int = None 81 | projection_dim: int = None 82 | temperature: float = 0.07 83 | receptor_model_name: str = 'esm2' 84 | 85 | @dataclass 86 | class EncoderProjectionConfigESM3: 87 | epitope_input_dim: int = 1536 88 | receptor_input_dim: int = 1536 89 | hidden_dim: int = None 90 | projection_dim: int = None 91 | temperature: float = 0.07 92 | receptor_model_name: str = 'esm3' 93 | 94 | @dataclass 95 | class EncoderProjectionConfigInHouse: 96 | epitope_input_dim: int = 1280 97 | receptor_input_dim: int = 768 98 | hidden_dim: int = None 99 | projection_dim: int = None 100 | temperature: float = 0.07 101 | receptor_model_name: str = 'inhouse' 102 | 103 | 
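# ----------------------------------------------------------------------------
# Usage sketch: the projection configs in this file are selected by name via
# get_projection_config() below, and training.py then overwrites their defaults
# in place with any matching, non-None CLI arguments through
# src.logging.combine_args_and_configs. A hypothetical example:
#
#     from argparse import Namespace
#     cfg = get_projection_config('esm2')      # -> EncoderProjectionConfigESM2()
#     combine_args_and_configs(Namespace(projection_dim=512), [cfg])
#     # cfg.projection_dim is now 512; fields not supplied (or passed as None)
#     # keep their dataclass defaults.
# ----------------------------------------------------------------------------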
@dataclass 104 | class EncoderProjectionConfigOneHot: 105 | epitope_input_dim: int = 21 106 | receptor_input_dim: int = 21 107 | hidden_dim: int = None 108 | projection_dim: int = None 109 | temperature: float = 0.07 110 | receptor_model_name: str = 'inhouse' 111 | 112 | # -------------------------------------------------- 113 | # PEFT configs: 114 | peft_config_esm2 = LoraConfig( 115 | r=8, 116 | lora_alpha=32, 117 | lora_dropout=0.1, 118 | bias='none', 119 | layers_to_transform=[32, 31, 30, 29, 28, 27, 26, 25], 120 | task_type=TaskType.FEATURE_EXTRACTION, 121 | target_modules=['attention.self.key', 'attention.self.value'] 122 | ) 123 | 124 | peft_config_esm3 = LoraConfig( 125 | r=8, 126 | lora_alpha=32, 127 | lora_dropout=0.1, 128 | bias='none', 129 | layers_to_transform=[47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36], 130 | task_type=TaskType.FEATURE_EXTRACTION, 131 | target_modules=['attn.layernorm_qkv.1'] 132 | ) 133 | 134 | peft_config_ablang = LoraConfig( 135 | r=8, 136 | lora_alpha=32, 137 | lora_dropout=0.1, 138 | bias='none', 139 | layers_to_transform=[11, 10, 9, 8], 140 | task_type=TaskType.FEATURE_EXTRACTION, 141 | target_modules=['attention.self.query', 'attention.self.value'] 142 | ) 143 | 144 | peft_config_ablang2 = LoraConfig( 145 | r=8, 146 | lora_alpha=32, 147 | lora_dropout=0.1, 148 | bias='none', 149 | # layers_to_transform=[11, 10, 9, 8], 150 | task_type=TaskType.FEATURE_EXTRACTION, 151 | target_modules=".*(8|9|10|11).*[kv]_proj$" 152 | ) 153 | 154 | peft_config_aberta2 = LoraConfig( 155 | r=8, 156 | lora_alpha=32, 157 | lora_dropout=0.1, 158 | bias='none', 159 | layers_to_transform=[15, 14, 13, 12], 160 | task_type=TaskType.FEATURE_EXTRACTION, 161 | target_modules=["attention.self.query", "attention.self.value"] 162 | ) 163 | 164 | peft_config_tcrbert = LoraConfig( 165 | r=8, 166 | lora_alpha=32, 167 | lora_dropout=0.1, 168 | bias='none', 169 | layers_to_transform=[11, 10, 9, 8], 170 | task_type=TaskType.FEATURE_EXTRACTION, 171 | target_modules=["attention.self.key", "attention.self.value"] 172 | ) 173 | 174 | peft_config_inhouse = LoraConfig( 175 | r=8, 176 | lora_alpha=32, 177 | lora_dropout=0.1, 178 | bias='none', 179 | layers_to_transform=[11, 10, 9, 8], 180 | task_type=TaskType.FEATURE_EXTRACTION, 181 | target_modules=["attention.self.key", "attention.self.value"] 182 | ) 183 | # -------------------------------------------------- 184 | 185 | 186 | def get_lightning_config(name='default'): 187 | if name == 'default': 188 | return LightningConfig() 189 | 190 | def get_projection_config(name='ablang'): 191 | if name == 'ablang': 192 | return EncoderProjectionConfigAbLang() 193 | elif name == 'ablang2': 194 | return EncoderProjectionConfigAbLang2() 195 | elif name == 'antiberta2': 196 | return EncoderProjectionConfigAntiberta2() 197 | elif name == 'tcrbert': 198 | return EncoderProjectionConfigTCRBert() 199 | elif name == 'tcrlang': 200 | return EncoderProjectionConfigTCRLang() 201 | elif name == 'esm2': 202 | return EncoderProjectionConfigESM2() 203 | elif name == 'esm3': 204 | return EncoderProjectionConfigESM3() 205 | elif name == 'inhouse': 206 | return EncoderProjectionConfigInHouse() 207 | elif name == 'onehot': 208 | return EncoderProjectionConfigOneHot() 209 | else: 210 | raise ValueError(f"Invalid model name: {name}") 211 | 212 | 213 | def build_lora_config(rank=4, alpha=32, dropout=0.1, bias='none', layers_to_transform=None, 214 | task_type=TaskType.FEATURE_EXTRACTION, target_modules=None): 215 | 216 | return LoraConfig( 217 | r=rank, 218 | lora_alpha=alpha, 219
| lora_dropout=dropout, 220 | bias=bias, 221 | layers_to_transform=layers_to_transform, 222 | task_type=task_type, 223 | target_modules=target_modules 224 | ) 225 | -------------------------------------------------------------------------------- /src/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from peft import get_peft_model 5 | 6 | 7 | def setup_peft_esm2(peft_config, no_lora = False, regular_ft=False): 8 | 9 | from transformers import EsmModel, EsmTokenizer 10 | 11 | # Load the pretrained ESM-2 model 12 | esm_model = EsmModel.from_pretrained('facebook/esm2_t33_650M_UR50D') 13 | esm_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D') 14 | 15 | if regular_ft: 16 | return esm_model, esm_tokenizer 17 | 18 | # Apply LoRA to the model 19 | peft_lm = get_peft_model(esm_model, peft_config) 20 | 21 | #NOT APPLYING LoRA to the model: 22 | if no_lora: 23 | for name, param in esm_model.named_parameters(): 24 | param.requires_grad = False 25 | return esm_model, esm_tokenizer 26 | 27 | # freeze all the layers except the LoRA adapter matrices 28 | for name, param in peft_lm.named_parameters(): 29 | if "lora" in name: 30 | param.requires_grad = True 31 | else: 32 | param.required_grad = False 33 | 34 | return peft_lm, esm_tokenizer 35 | 36 | def setup_peft_esm3(peft_config, no_lora = False): 37 | 38 | from esm.models.esm3 import ESM3 39 | from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig 40 | from esm.tokenization import EsmSequenceTokenizer 41 | 42 | # Load the pretrained ESM-3 model 43 | esm3_model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1") 44 | esm3_tokenizer = EsmSequenceTokenizer() 45 | 46 | #NOT APPLYING LoRA to the model: 47 | if no_lora: 48 | for name, param in esm3_model.named_parameters(): 49 | param.requires_grad = False 50 | return esm3_model, esm3_tokenizer 51 | 52 | # Apply LoRA to the model 53 | peft_lm = get_peft_model(esm3_model, peft_config) 54 | 55 | # freeze all the layers except the LoRA adapter matrices 56 | for name, param in esm3_model.named_parameters(): 57 | if "lora" in name: 58 | param.requires_grad = True 59 | else: 60 | param.required_grad = False 61 | 62 | return esm3_model, esm3_tokenizer 63 | 64 | def setup_peft_ablang(peft_config, chain="H"): 65 | 66 | from transformers import AutoTokenizer, AutoModelForMaskedLM 67 | 68 | if chain == "H": 69 | # Load the pretrained AbLang H model 70 | ablang_tokenizer = AutoTokenizer.from_pretrained("qilowoq/AbLang_heavy", trust_remote_code=True) 71 | ablang_model = AutoModelForMaskedLM.from_pretrained("qilowoq/AbLang_heavy", trust_remote_code=True) 72 | 73 | if chain == "L": 74 | # Load the pretrained AbLang L model 75 | ablang_tokenizer = AutoTokenizer.from_pretrained("qilowoq/AbLang_light", trust_remote_code=True) 76 | ablang_model = AutoModelForMaskedLM.from_pretrained("qilowoq/AbLang_light", trust_remote_code=True) 77 | 78 | # take out the decoder layer, which we don't need 79 | ablang_model = ablang_model.roberta 80 | 81 | # Apply LoRA to the model 82 | peft_lm = get_peft_model(ablang_model, peft_config) 83 | 84 | # freeze all the layers except the LoRA adapter matrices 85 | for name, param in peft_lm.named_parameters(): 86 | if "lora" in name: 87 | param.requires_grad = True 88 | else: 89 | param.required_grad = False 90 | 91 | return peft_lm, ablang_tokenizer 92 | 93 | def setup_peft_ablang2(peft_config, receptor_type='BCR', device='cpu', no_lora=False): 94 | import 
ablang2 95 | 96 | # Load the pretrained AbLang2 model 97 | if receptor_type == 'TCR': 98 | ablang2_module = ablang2.pretrained(model_to_use='tcrlang-paired', random_init=False, device=device) 99 | elif receptor_type == 'BCR': 100 | ablang2_module = ablang2.pretrained(model_to_use='ablang2-paired', random_init=False, device=device) 101 | else: 102 | raise ValueError(f"Receptor type {receptor_type} not supported") 103 | ablang2_model = ablang2_module.AbRep 104 | 105 | # NOT APPLYING LoRA to the model: 106 | if no_lora: 107 | for name, param in ablang2_model.named_parameters(): 108 | param.requires_grad = False 109 | return ablang2_model, ablang2_module.tokenizer 110 | 111 | # Apply LoRA to the model 112 | peft_lm = get_peft_model(ablang2_model, peft_config) 113 | 114 | # freeze all the layers except the LoRA adapter matrices 115 | lora_count = 0 116 | for name, param in ablang2_model.named_parameters(): 117 | if "lora" in name: 118 | lora_count += 1 119 | param.requires_grad = True 120 | else: 121 | param.required_grad = False 122 | assert lora_count >= 4 # make sure we have LoRA adapter matrices 123 | 124 | return ablang2_model, ablang2_module.tokenizer 125 | 126 | def setup_peft_aberta2(peft_config): 127 | from transformers import ( 128 | RoFormerForMaskedLM, 129 | RoFormerTokenizer, 130 | ) 131 | 132 | # Load the pretrained Aberta2 model 133 | aberta2_model = RoFormerForMaskedLM.from_pretrained("alchemab/antiberta2") 134 | aberta2_tokenizer = RoFormerTokenizer.from_pretrained("alchemab/antiberta2") 135 | 136 | # only take the RoFormer module: 137 | aberta2_model = aberta2_model.roformer 138 | 139 | # Apply LoRA to the model 140 | peft_lm = get_peft_model(aberta2_model, peft_config) 141 | 142 | # freeze all the layers except the LoRA adapter matrices 143 | for name, param in peft_lm.named_parameters(): 144 | if "lora" in name: 145 | param.requires_grad = True 146 | else: 147 | param.required_grad = False 148 | 149 | return peft_lm, aberta2_tokenizer 150 | 151 | def setup_peft_tcrbert(peft_config, no_lora=False, regular_ft=False): 152 | from transformers import ( 153 | BertModel, 154 | AutoTokenizer, 155 | ) 156 | 157 | # Load the pretrained TCRBert model 158 | tcrbert_model = BertModel.from_pretrained("wukevin/tcr-bert-mlm-only") 159 | tcrbert_tokenizer = AutoTokenizer.from_pretrained("wukevin/tcr-bert-mlm-only", trust_remote_code=True) 160 | 161 | if regular_ft: 162 | return tcrbert_model, tcrbert_tokenizer 163 | 164 | # Apply LoRA to the model 165 | peft_lm = get_peft_model(tcrbert_model, peft_config) 166 | 167 | # NOT APPLYING LoRA to the model: 168 | if no_lora: 169 | for name, param in tcrbert_model.named_parameters(): 170 | param.requires_grad = False 171 | return tcrbert_model, tcrbert_tokenizer 172 | 173 | # freeze all the layers except the LoRA adapter matrices 174 | for name, param in peft_lm.named_parameters(): 175 | if "lora" in name: 176 | param.requires_grad = True 177 | else: 178 | param.required_grad = False 179 | 180 | return peft_lm, tcrbert_tokenizer 181 | 182 | def setup_peft_inhouse(peft_config, no_lora=False, model_ckpt_path=None): 183 | from .pretrain.model import CdrBERT, getCdrTokenizer, MODEL_CONFIG 184 | 185 | # load the in-house TCR model: 186 | inhouse_tokenizer = getCdrTokenizer() 187 | inhouse_model = CdrBERT(MODEL_CONFIG, inhouse_tokenizer) 188 | inhouse_ckpt = torch.load(model_ckpt_path) 189 | # Remove "model." prefix from keys. 
Artifact of Pytorch Lightning 190 | new_state_dict = {} 191 | for key, value in inhouse_ckpt['state_dict'].items(): 192 | new_key = key.replace("model.", "") 193 | new_state_dict[new_key] = value 194 | inhouse_model.load_state_dict(new_state_dict) 195 | 196 | # Apply LoRA to the model 197 | peft_lm = get_peft_model(inhouse_model, peft_config) 198 | 199 | # NOT APPLYING LoRA to the model: 200 | if no_lora: 201 | for name, param in inhouse_model.named_parameters(): 202 | param.requires_grad = False 203 | return inhouse_model, inhouse_tokenizer 204 | 205 | # freeze all the layers except the LoRA adapter matrices 206 | for name, param in peft_lm.named_parameters(): 207 | if "lora" in name: 208 | param.requires_grad = True 209 | else: 210 | param.required_grad = False 211 | 212 | return peft_lm, inhouse_tokenizer -------------------------------------------------------------------------------- /src/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from pytorch_lightning import Trainer 4 | from lightning.pytorch.accelerators import find_usable_cuda_devices 5 | from pytorch_lightning.utilities import rank_zero_only 6 | import pytorch_lightning.loggers as log 7 | from lightning.pytorch.strategies import DDPStrategy 8 | 9 | import argparse 10 | import os 11 | import wandb 12 | 13 | from .alignment import CLIPModel 14 | from .data_module import EpitopeReceptorDataModule 15 | from .configs import get_lightning_config, get_projection_config 16 | from .logging import call_backs, combine_args_and_configs 17 | 18 | 19 | def setup_parser(): 20 | 21 | # Command line interface arguments and parsing 22 | parser = argparse.ArgumentParser(description='argument parser for training') 23 | parser.add_argument('--batch-size', type=int, default=4, help='Batch size for training') 24 | parser.add_argument('--grad-accum', type=int, default=1, help='Number of gradient accumulation steps') 25 | parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate') 26 | parser.add_argument('--weight-decay', type=float, default=0.01, help='Weight decay parameter for AdamW algorithm') 27 | parser.add_argument('--random-seed', type=int, default=14, help='Random seed for reproducibility') 28 | 29 | # WandB configuration: 30 | parser.add_argument('--entity', type=str, default='lordim', help='entity name') 31 | parser.add_argument('--project', type=str, default='clip_antibody', help='project name') 32 | parser.add_argument('--group', type=str, default='clipbody_test', help='group name') 33 | parser.add_argument('--run-id', type=str, default='clipbody_test', help='run id') 34 | parser.add_argument('--use-wandb', default=False, action='store_true', help='use wandb for logging') 35 | 36 | # Training and Data configuration 37 | parser.add_argument('--receptor-model-name', type=str, default='ablang', help='name of the receptor foundation model') 38 | parser.add_argument('--receptor-type', type=str, default='TCR', help='Is the receptor BCR or TCR') 39 | parser.add_argument('--include-mhc', default=False, action='store_true', help='include MHC sequences alongside epitope in the training data') 40 | parser.add_argument('--mhc-groove-only', default=False, action='store_true', help='only include A1-A2 domains for class I MHC, A1-B1 domains for class II MHC') 41 | parser.add_argument('--unique-epitopes', default=False, action='store_true', help='split the data based on unique epitopes') 42 | parser.add_argument('--no-lora', 
default=False, action='store_true', help='do not use LoRA adapter matrices for the models') 43 | parser.add_argument('--regular-ft', default=False, action='store_true', help='use regular fine-tuning') 44 | parser.add_argument('--mask-seqs', default=False, action='store_true', help='mask the sequences for training') 45 | parser.add_argument('--mask-prob', type=float, default=0.15, help='probability of masking a residue') 46 | parser.add_argument('--mse-weight', type=float, default=0., help='weight for the MSE loss') 47 | parser.add_argument('--weigh-epitope-count', default=False, action='store_true', help='weight the epitope count in the clip loss') 48 | parser.add_argument('--swe-pooling', default=False, action='store_true', help='use SWE pooling for sequence embeddings') 49 | parser.add_argument('--hidden-dim', type=int, default=None, help='dimension of the hidden layer') 50 | parser.add_argument('--projection-dim', type=int, default=None, help='dimension of the projection layer') 51 | parser.add_argument('--lightning-config-name', type=str, default='default') 52 | parser.add_argument('--dataset-path', type=str, required=True, help='path to the dataset') 53 | parser.add_argument('--mhc-path', type=str, default=None, help='path to file with MHC sequence info. Required if --include-mhc is set to True') 54 | parser.add_argument('--oversample', default=False, action='store_true', help='oversample the epitopes with few receptor data') 55 | parser.add_argument('--fewshot-ratio', type=float, default=None, help='ratio of few-shot data to the total data') 56 | parser.add_argument('--lr-scheduler', type=str, default='cos_anneal', help='learning rate scheduler') 57 | 58 | # PyTorch Lightning configuration 59 | parser.add_argument('--torch-device', type=str, default='gpu') 60 | parser.add_argument('--output-dir', type=str, required=True, help='wandb and checkpoint output') 61 | parser.add_argument('--num-gpus', type=int, default=1, help='number of GPUs to use') 62 | parser.add_argument('--max-epochs', type=int, default = 1, required=False) 63 | parser.add_argument('--gpus-used', type=int, nargs='+', required=False, help='which GPUs used for env variable CUDA_VISIBLE_DEVICES') 64 | parser.add_argument('--stage', type=str, default='fit', help='stage of training') 65 | parser.add_argument('--check-val-every-n-epoch', type=int, default=1, help='check validation every n epochs') 66 | parser.add_argument('--val-check-interval', type=float, default=1.0, help='validation check interval') 67 | parser.add_argument('--save-ckpts', default=False, action='store_true', help='save checkpoints') 68 | parser.add_argument('--from-checkpoint', type=str, default=None, help='path to checkpoint') 69 | parser.add_argument('--save-embed-path', type=str, default=None, help='path to save embeddings for eval') 70 | 71 | args = parser.parse_args() 72 | 73 | return args 74 | 75 | 76 | 77 | if __name__ == '__main__': 78 | # utilizing Tensor Cores: 79 | torch.set_float32_matmul_precision('high') 80 | 81 | # setting up the environment variables: 82 | os.environ["TOKENIZERS_PARALLELISM"] = "true" # resolving tokenizers parallelism issue 83 | 84 | args = setup_parser() 85 | 86 | # retrieve the configs: 87 | lightning_config = get_lightning_config() 88 | model_config = get_projection_config(args.receptor_model_name) 89 | 90 | # update configs based on input arguments: 91 | combine_args_and_configs(args, [lightning_config, model_config]) 92 | 93 | # setup callbacks: 94 | if args.stage == 'fit' and args.output_dir is not None: 95 | if 
not os.path.exists(os.path.join(args.output_dir, args.run_id)): 96 | os.makedirs(os.path.join(args.output_dir, args.run_id)) 97 | 98 | output_dir = os.path.join(args.output_dir, args.run_id) 99 | 100 | # get callbacks 101 | callbacks = call_backs(output_dir, args.save_ckpts) 102 | else: 103 | callbacks = None 104 | 105 | # construct PyTorch Lightning Module: 106 | if args.receptor_type == 'TCR': 107 | print("Using TCR data!") 108 | else: 109 | print("Using BCR data!") 110 | tsv_file_path = args.dataset_path 111 | mhc_file_path = args.mhc_path 112 | 113 | if args.mask_seqs: 114 | print("WARNING: Partially masking sequence residues during training") 115 | 116 | if args.unique_epitopes: 117 | print("WARNING: Splitting data based on unique epitopes") 118 | 119 | pl_datamodule = EpitopeReceptorDataModule(tsv_file_path, mhc_file=mhc_file_path, ln_cfg=lightning_config, 120 | batch_size=lightning_config.batch_size, include_mhc=lightning_config.include_mhc, 121 | model_config=model_config, random_seed=args.random_seed) 122 | 123 | # construct the CLIP model: 124 | clip_model = CLIPModel(lightning_config, model_config) 125 | 126 | if rank_zero_only.rank == 0: 127 | # initialize wandb: 128 | if args.use_wandb: 129 | run = wandb.init(project=args.project, 130 | entity=args.entity, 131 | group=args.group, 132 | dir=output_dir, 133 | name=args.run_id, 134 | id=args.run_id, 135 | resume=True if args.from_checkpoint else None, 136 | ) 137 | 138 | run_output_dir = run.dir 139 | wandb_logger = log.WandbLogger(save_dir=run_output_dir, log_model=False) 140 | wandb_logger.watch(clip_model) 141 | 142 | if len(args.gpus_used) > 1: 143 | strat = 'ddp' 144 | if args.regular_ft: 145 | strat = 'ddp_find_unused_parameters_true' 146 | else: 147 | strat = 'auto' 148 | # build PyTorch Lightning Trainer: 149 | trainer = Trainer(max_epochs=args.max_epochs, 150 | logger=wandb_logger if args.use_wandb else None, 151 | accelerator=args.torch_device, 152 | devices=args.gpus_used if args.gpus_used else 1, #TODO: smooth CPU/GPU conversion 153 | enable_progress_bar=True, 154 | callbacks=callbacks if callbacks is not None else None, 155 | accumulate_grad_batches=args.grad_accum, 156 | reload_dataloaders_every_n_epochs=1 if args.oversample else 0, 157 | strategy=strat, 158 | ) 159 | else: 160 | strat = 'ddp' 161 | if args.regular_ft: 162 | strat = 'ddp_find_unused_parameters_true' 163 | # build PyTorch Lightning Trainer: 164 | trainer = Trainer(max_epochs=args.max_epochs, 165 | logger=None, 166 | accelerator=args.torch_device, 167 | devices=args.gpus_used if args.gpus_used else 1, #TODO: smooth CPU/GPU conversion 168 | enable_progress_bar=True, 169 | callbacks=callbacks if callbacks is not None else None, 170 | accumulate_grad_batches=args.grad_accum, 171 | reload_dataloaders_every_n_epochs=1 if args.oversample else 0, 172 | strategy=strat, 173 | ) 174 | 175 | # run the model: 176 | if args.stage == 'fit': 177 | print('Start Training...') 178 | trainer.fit(model=clip_model, datamodule=pl_datamodule, ckpt_path=args.from_checkpoint) 179 | else: 180 | print("**********************") 181 | print("* Inference Mode...
*") 182 | print("**********************") 183 | trainer.test(model=clip_model, datamodule=pl_datamodule, ckpt_path=args.from_checkpoint) 184 | -------------------------------------------------------------------------------- /src/swe_pooling.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contents of this file are from the open source code for 3 | 4 | NaderiAlizadeh, Navid, and Rohit Singh. 5 | Aggregating Residue-Level Protein Language Model Embeddings with Optimal Transport. 6 | bioRxiv (2024): 2024-01. 7 | 8 | MIT License 9 | 10 | Copyright (c) 2024 Navid NaderiAlizadeh and Rohit Singh 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 29 | ''' 30 | 31 | from types import SimpleNamespace 32 | 33 | import os 34 | import pickle as pk 35 | from functools import lru_cache 36 | 37 | import torch 38 | import torch.nn as nn 39 | from torch.nn.utils.rnn import pad_sequence 40 | from torch.utils.data import DataLoader, Dataset 41 | from tqdm import tqdm 42 | 43 | import contextlib 44 | 45 | class Interp1d(torch.autograd.Function): 46 | def __call__(self, x, y, xnew, out=None): 47 | return self.forward(x, y, xnew, out) 48 | 49 | def forward(ctx, x, y, xnew, out=None): 50 | """ 51 | Linear 1D interpolation on the GPU for Pytorch. 52 | This function returns interpolated values of a set of 1-D functions at 53 | the desired query points `xnew`. 54 | This function is working similarly to Matlab™ or scipy functions with 55 | the `linear` interpolation mode on, except that it parallelises over 56 | any number of desired interpolation problems. 57 | The code will run on GPU if all the tensors provided are on a cuda 58 | device. 59 | Parameters 60 | ---------- 61 | x : (N, ) or (D, N) Pytorch Tensor 62 | A 1-D or 2-D tensor of real values. 63 | y : (N,) or (D, N) Pytorch Tensor 64 | A 1-D or 2-D tensor of real values. The length of `y` along its 65 | last dimension must be the same as that of `x` 66 | xnew : (P,) or (D, P) Pytorch Tensor 67 | A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if 68 | _both_ `x` and `y` are 1-D. Otherwise, its length along the first 69 | dimension must be the same as that of whichever `x` and `y` is 2-D. 70 | out : Pytorch Tensor, same shape as `xnew` 71 | Tensor for the output. If None: allocated automatically. 
72 | """ 73 | # making the vectors at least 2D 74 | is_flat = {} 75 | require_grad = {} 76 | v = {} 77 | device = [] 78 | eps = torch.finfo(y.dtype).eps 79 | for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items(): 80 | assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\ 81 | 'at most 2-D.' 82 | if len(vec.shape) == 1: 83 | v[name] = vec[None, :] 84 | else: 85 | v[name] = vec 86 | is_flat[name] = v[name].shape[0] == 1 87 | require_grad[name] = vec.requires_grad 88 | device = list(set(device + [str(vec.device)])) 89 | assert len(device) == 1, 'All parameters must be on the same device.' 90 | device = device[0] 91 | 92 | # Checking for the dimensions 93 | assert (v['x'].shape[1] == v['y'].shape[1] 94 | and ( 95 | v['x'].shape[0] == v['y'].shape[0] 96 | or v['x'].shape[0] == 1 97 | or v['y'].shape[0] == 1 98 | ) 99 | ), ("x and y must have the same number of columns, and either " 100 | "the same number of row or one of them having only one " 101 | "row.") 102 | 103 | reshaped_xnew = False 104 | if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1) 105 | and (v['xnew'].shape[0] > 1)): 106 | # if there is only one row for both x and y, there is no need to 107 | # loop over the rows of xnew because they will all have to face the 108 | # same interpolation problem. We should just stack them together to 109 | # call interp1d and put them back in place afterwards. 110 | original_xnew_shape = v['xnew'].shape 111 | v['xnew'] = v['xnew'].contiguous().view(1, -1) 112 | reshaped_xnew = True 113 | 114 | # identify the dimensions of output and check if the one provided is ok 115 | D = max(v['x'].shape[0], v['xnew'].shape[0]) 116 | shape_ynew = (D, v['xnew'].shape[-1]) 117 | if out is not None: 118 | if out.numel() != shape_ynew[0]*shape_ynew[1]: 119 | # The output provided is of incorrect shape. 120 | # Going for a new one 121 | out = None 122 | else: 123 | ynew = out.reshape(shape_ynew) 124 | if out is None: 125 | ynew = torch.zeros(*shape_ynew, device=device) 126 | 127 | # moving everything to the desired device in case it was not there 128 | # already (not handling the case things do not fit entirely, user will 129 | # do it if required.) 130 | for name in v: 131 | v[name] = v[name].to(device) 132 | 133 | # calling searchsorted on the x values. 134 | ind = ynew.long() 135 | 136 | # expanding xnew to match the number of rows of x in case only one xnew is 137 | # provided 138 | if v['xnew'].shape[0] == 1: 139 | v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1) 140 | 141 | torch.searchsorted(v['x'].contiguous(), 142 | v['xnew'].contiguous(), out=ind) 143 | 144 | # the `-1` is because searchsorted looks for the index where the values 145 | # must be inserted to preserve order. And we want the index of the 146 | # preceeding value. 147 | ind -= 1 148 | # we clamp the index, because the number of intervals is x.shape-1, 149 | # and the left neighbour should hence be at most number of intervals 150 | # -1, i.e. number of columns in x -2 151 | ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1) 152 | 153 | # helper function to select stuff according to the found indices. 
154 | def sel(name): 155 | if is_flat[name]: 156 | return v[name].contiguous().view(-1)[ind] 157 | return torch.gather(v[name], 1, ind) 158 | 159 | # activating gradient storing for everything now 160 | enable_grad = False 161 | saved_inputs = [] 162 | for name in ['x', 'y', 'xnew']: 163 | if require_grad[name]: 164 | enable_grad = True 165 | saved_inputs += [v[name]] 166 | else: 167 | saved_inputs += [None, ] 168 | # assuming x are sorted in the dimension 1, computing the slopes for 169 | # the segments 170 | is_flat['slopes'] = is_flat['x'] 171 | # now we have found the indices of the neighbors, we start building the 172 | # output. Hence, we start also activating gradient tracking 173 | with torch.enable_grad() if enable_grad else contextlib.suppress(): 174 | v['slopes'] = ( 175 | (v['y'][:, 1:]-v['y'][:, :-1]) 176 | / 177 | (eps + (v['x'][:, 1:]-v['x'][:, :-1])) 178 | ) 179 | 180 | # now build the linear interpolation 181 | ynew = sel('y') + sel('slopes')*( 182 | v['xnew'] - sel('x')) 183 | 184 | if reshaped_xnew: 185 | ynew = ynew.view(original_xnew_shape) 186 | 187 | ctx.save_for_backward(ynew, *saved_inputs) 188 | return ynew 189 | 190 | @staticmethod 191 | def backward(ctx, grad_out): 192 | inputs = ctx.saved_tensors[1:] 193 | gradients = torch.autograd.grad( 194 | ctx.saved_tensors[0], 195 | [i for i in inputs if i is not None], 196 | grad_out, retain_graph=True) 197 | result = [None, ] * 5 198 | pos = 0 199 | for index in range(len(inputs)): 200 | if inputs[index] is not None: 201 | result[index] = gradients[pos] 202 | pos += 1 203 | return (*result,) 204 | 205 | class SWE_Pooling(nn.Module): 206 | def __init__(self, d_in, num_ref_points, num_slices): 207 | ''' 208 | Produces fixed-dimensional permutation-invariant embeddings for input sets of arbitrary size based on sliced-Wasserstein embedding. 209 | Inputs: 210 | d_in: The dimensionality of the space that each set sample belongs to 211 | num_ref_points: Number of points in the reference set 212 | num_slices: Number of slices 213 | ''' 214 | super(SWE_Pooling, self).__init__() 215 | self.d_in = d_in 216 | self.num_ref_points = num_ref_points 217 | self.num_slices = num_slices 218 | 219 | uniform_ref = torch.linspace(-1, 1, num_ref_points).unsqueeze(1).repeat(1, num_slices) 220 | self.reference = nn.Parameter(uniform_ref) 221 | 222 | self.theta = nn.utils.weight_norm(nn.Linear(d_in, num_slices, bias=False), dim=0) 223 | if num_slices <= d_in: 224 | nn.init.eye_(self.theta.weight_v) 225 | else: 226 | nn.init.normal_(self.theta.weight_v) 227 | 228 | self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data, requires_grad=False) 229 | self.theta.weight_g.requires_grad = False 230 | 231 | # weights to reduce the output embedding dimensionality 232 | self.weight = nn.Linear(num_ref_points, 1, bias=False) 233 | 234 | def forward(self, X, mask=None): 235 | ''' 236 | Calculates GSW between two empirical distributions. 
237 | Note that the number of samples is assumed to be equal 238 | (This is however not necessary and could be easily extended 239 | for empirical distributions with different number of samples) 240 | Input: 241 | X: B x N x dn tensor, containing a batch of B sets, each containing N samples in a dn-dimensional space 242 | mask [optional]: B x N binary tensor, with 1 iff the set element is valid; used for the case where set sizes are different 243 | Output: 244 | weighted_embeddings: B x num_slices tensor, containing a batch of B embeddings, each of dimension "num_slices" (i.e., number of slices) 245 | ''' 246 | 247 | B, N, _ = X.shape 248 | Xslices = self.get_slice(X) 249 | 250 | M, _ = self.reference.shape 251 | 252 | if mask is None: 253 | # serial implementation should be used if set sizes are different 254 | Xslices_sorted, Xind = torch.sort(Xslices, dim=1) 255 | 256 | if M == N: 257 | Xslices_sorted_interpolated = Xslices_sorted 258 | else: 259 | x = torch.linspace(0, 1, N + 2)[1:-1].unsqueeze(0).repeat(B * self.num_slices, 1).to(X.device) 260 | xnew = torch.linspace(0, 1, M + 2)[1:-1].unsqueeze(0).repeat(B * self.num_slices, 1).to(X.device) 261 | y = torch.transpose(Xslices_sorted, 1, 2).reshape(B * self.num_slices, -1) 262 | Xslices_sorted_interpolated = torch.transpose(Interp1d()(x, y, xnew).view(B, self.num_slices, -1), 1, 2) 263 | else: 264 | # replace invalid set elements with points to the right of the maximum element for each slice and each set (which will not impact the sorting and interpolation process) 265 | invalid_elements_mask = ~mask.bool().unsqueeze(-1).repeat(1, 1, self.num_slices) 266 | Xslices_copy = Xslices.clone() 267 | Xslices_copy[invalid_elements_mask] = -1e10 268 | 269 | top2_Xslices, _ = torch.topk(Xslices_copy, k=2, dim=1) 270 | max_Xslices = top2_Xslices[:, 0].unsqueeze(1) 271 | delta_y = - torch.diff(top2_Xslices, dim=1) 272 | 273 | Xslices_modified = Xslices.clone() 274 | 275 | Xslices_modified[invalid_elements_mask] = max_Xslices.repeat(1, N, 1)[invalid_elements_mask] 276 | 277 | delta_x = 1 / (1 + torch.sum(mask, dim=1, keepdim=True)) 278 | slope = delta_y / delta_x.unsqueeze(-1).repeat(1, 1, self.num_slices) # B x 1 x num_slices 279 | slope = slope.repeat(1, N, 1) 280 | 281 | eps = 1e-3 282 | x_shifts = eps * torch.cumsum(invalid_elements_mask, dim=1) 283 | y_shifts = slope * x_shifts 284 | Xslices_modified = Xslices_modified + y_shifts 285 | 286 | Xslices_sorted, _ = torch.sort(Xslices_modified, dim=1) 287 | 288 | x = torch.arange(1, N + 1).to(X.device) / (1 + torch.sum(mask, dim=1, keepdim=True)) # B x N 289 | 290 | invalid_elements_mask = ~mask.bool() 291 | x_copy = x.clone() 292 | x_copy[invalid_elements_mask] = -1e10 293 | max_x, _ = torch.max(x_copy, dim=1, keepdim=True) 294 | x[invalid_elements_mask] = max_x.repeat(1, N)[invalid_elements_mask] 295 | 296 | x = x.unsqueeze(1).repeat(1, self.num_slices, 1) + torch.transpose(x_shifts, 1, 2) 297 | x = x.view(-1, N) # BL x N 298 | 299 | xnew = torch.linspace(0, 1, M + 2)[1:-1].unsqueeze(0).repeat(B * self.num_slices, 1).to(X.device) 300 | y = torch.transpose(Xslices_sorted, 1, 2).reshape(B * self.num_slices, -1) 301 | Xslices_sorted_interpolated = torch.transpose(Interp1d()(x, y, xnew).view(B, self.num_slices, -1), 1, 2) 302 | 303 | Rslices = self.reference.expand(Xslices_sorted_interpolated.shape) 304 | 305 | _, Rind = torch.sort(Rslices, dim=1) 306 | embeddings = (Rslices - torch.gather(Xslices_sorted_interpolated, dim=1, index=Rind)).permute(0, 2, 1) # B x num_slices x M 307 | 308 | 
weighted_embeddings = self.weight(embeddings).sum(-1) 309 | 310 | return weighted_embeddings.view(-1, self.num_slices) 311 | 312 | def get_slice(self, X): 313 | ''' 314 | Slices samples from distribution X~P_X 315 | Input: 316 | X: B x N x dn tensor, containing a batch of B sets, each containing N samples in a dn-dimensional space 317 | ''' 318 | return self.theta(X) -------------------------------------------------------------------------------- /src/alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | import pytorch_lightning as pl 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | 10 | from .models import * 11 | from .utils import construct_label_matrices, construct_label_matrices_ones 12 | from .data_module import compute_weights 13 | from .lr_scheduler import CosineAnnealingWarmUpRestarts 14 | 15 | class CLIPModel(pl.LightningModule): 16 | def __init__(self, lightning_config, model_config, device='cuda', **kwargs): 17 | super().__init__() 18 | self.save_hyperparameters() 19 | 20 | self.ln_cfg = lightning_config 21 | self.model_config = model_config 22 | 23 | self.epitope_input_dim = model_config.epitope_input_dim 24 | self.receptor_input_dim = model_config.receptor_input_dim 25 | self.projection_dim = model_config.projection_dim 26 | self.hidden_dim = model_config.hidden_dim 27 | 28 | # loss functions: 29 | self.bceloss_logits = nn.BCEWithLogitsLoss(reduction='none') 30 | self.celoss = nn.CrossEntropyLoss(reduction='none') 31 | self.mse_weight = lightning_config.mse_weight 32 | self.epitope_weights = None 33 | 34 | # logging: 35 | self.log_iterations = None 36 | self.training_step_metrics = {} 37 | self.val_step_metrics = {} 38 | self.test_step_metrics = {} 39 | 40 | # for evaluation later: 41 | self.epitope_embeddings = [] 42 | self.receptor_embeddings = [] 43 | self.epitope_sequences = [] 44 | self.receptor_sequences = [] 45 | 46 | 47 | def forward(self, epitope_seqs, receptor_seqs, mask=False): 48 | epitope_proj = self.epitope_encoder(epitope_seqs, mask=mask) 49 | receptor_proj = self.receptor_encoder(receptor_seqs, mask=mask) 50 | return epitope_proj, receptor_proj 51 | 52 | 53 | def clip_loss_multiclass(self, epitope_features, receptor_features, label_matrix, temperature=1.0): 54 | """ 55 | Compute the multi-class CLIP loss for epitope and receptor features based on label_indices 56 | 57 | Args: 58 | epitope_features: Tensor of shape (batch_size, feature_dim) representing the epitope embeddings. 59 | receptor_features: Tensor of shape (batch_size, feature_dim) representing the receptor embeddings. 60 | label_indices: list of length (batch_size) where each element is a list of indices 61 | of the correct labels for each epitope. 62 | temperature: A scaling factor to control the sharpness of the similarity distribution. 63 | 64 | Returns: 65 | loss: A scalar tensor representing the multi-class CLIP loss. 
66 | """ 67 | 68 | # Normalize the features to unit length (each dim bs x proj_dim) 69 | epitope_features = F.normalize(epitope_features, dim=-1) 70 | receptor_features = F.normalize(receptor_features, dim=-1) 71 | 72 | # MSE Loss between the normalized features: 73 | diff_norm = torch.norm(epitope_features - receptor_features, dim=-1) 74 | mse_loss = F.mse_loss(diff_norm, torch.zeros(len(diff_norm)).to(self.device), reduction='mean') 75 | 76 | # Compute the logits (similarities) as the dot product of epitope and receptor features 77 | logits_per_epitope = epitope_features @ receptor_features.t() 78 | logits_per_receptor = receptor_features @ epitope_features.t() 79 | 80 | # Scale by temperature 81 | logits_per_epitope /= temperature 82 | logits_per_receptor /= temperature 83 | 84 | # Compute the cross-entropy loss for both epitope-to-receptor and receptor-to-epitope 85 | # epitopes_loss = self.celoss(logits_per_epitope, label_matrix) 86 | # receptor_loss = self.celoss(logits_per_receptor, label_matrix) 87 | 88 | # Compute the binary cross-entropy loss for both epitope-to-receptor and receptor-to-epitope 89 | epitopes_loss = self.bceloss_logits(logits_per_epitope, label_matrix) 90 | receptor_loss = self.bceloss_logits(logits_per_receptor, label_matrix) 91 | 92 | # multiply the loss with inverse square-rooted count weights: 93 | 94 | clip_loss = (epitopes_loss + receptor_loss) / 2.0 # shape: (batch_size) 95 | return clip_loss, mse_loss 96 | 97 | 98 | def training_step(self, batch, batch_idx): 99 | """ 100 | Training step for the CLIPBody Model 101 | """ 102 | epitope_seqs, receptor_seqs = batch 103 | 104 | epitope_proj, receptor_proj = self(epitope_seqs, receptor_seqs, mask=self.ln_cfg.mask_seqs) 105 | 106 | # label_matrix = construct_label_matrices(epitope_seqs, receptor_seqs).to(self.device) 107 | label_matrix = construct_label_matrices_ones(epitope_seqs, receptor_seqs, self.ln_cfg.include_mhc).to(self.device) 108 | 109 | # print("epitope seqs:", epitope_seqs) 110 | 111 | # construct weight matrices for the epitope sequences: 112 | if self.ln_cfg.weigh_epitope_count: 113 | weights = torch.tensor([self.epitope_weights[seq] for seq in epitope_seqs]).to(self.device) 114 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07) 115 | clip_loss = clip_loss * weights 116 | clip_loss = clip_loss.sum() 117 | else: 118 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07) 119 | clip_loss = clip_loss.mean() 120 | 121 | loss = clip_loss * (1 - self.mse_weight) + mse_loss * self.mse_weight 122 | training_metrics = { 123 | 'loss': loss, 124 | } 125 | self.training_step_metrics.setdefault('loss', []).append(loss.detach().item()) 126 | if self.ln_cfg.mse_weight > 0: 127 | self.training_step_metrics.setdefault('clip_loss', []).append(clip_loss.detach().item()) 128 | self.training_step_metrics.setdefault('mse_loss', []).append(mse_loss.detach().item()) 129 | 130 | return training_metrics 131 | 132 | def validation_step(self, batch, batch_idx): 133 | """ 134 | Validation step for the CLIPBody Model 135 | """ 136 | 137 | epitope_seqs, receptor_seqs = batch 138 | try: 139 | epitope_proj, receptor_proj = self(epitope_seqs, receptor_seqs) 140 | except: 141 | print("Error in feeding sequences") 142 | print("epitope_seqs", epitope_seqs) 143 | print("receptor_seqs", receptor_seqs) 144 | raise ValueError 145 | 146 | # label_matrix = construct_label_matrices(epitope_seqs, receptor_seqs).to(self.device) 147 
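        # Validation mirrors the training objective: build the binary epitope-match label
        # matrix, compute the symmetric BCE-based CLIP loss at temperature 0.07, and blend
        # it with the MSE term via mse_weight. When weigh_epitope_count is set, per-sample
        # weights come from compute_weights, i.e. 1/sqrt of each epitope's frequency in
        # the training split.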
| label_matrix = construct_label_matrices_ones(epitope_seqs, receptor_seqs, self.ln_cfg.include_mhc).to(self.device) 148 | 149 | # construct weight matrices for the epitope sequences: 150 | if self.ln_cfg.weigh_epitope_count and not self.ln_cfg.unique_epitopes: 151 | weights = torch.tensor([self.epitope_weights[seq] for seq in epitope_seqs]).to(self.device) 152 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07) 153 | clip_loss = clip_loss * weights 154 | clip_loss = clip_loss.sum() 155 | else: 156 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07) 157 | clip_loss = clip_loss.mean() 158 | 159 | loss = clip_loss * (1 - self.mse_weight) + mse_loss * self.mse_weight 160 | val_metrics = { 161 | 'loss': loss, 162 | } 163 | self.val_step_metrics.setdefault('loss', []).append(loss.detach().item()) 164 | if self.ln_cfg.mse_weight > 0: 165 | self.val_step_metrics.setdefault('clip_loss', []).append(clip_loss.detach().item()) 166 | self.val_step_metrics.setdefault('mse_loss', []).append(mse_loss.detach().item()) 167 | 168 | return val_metrics 169 | 170 | def test_step(self, batch, batch_idx): 171 | """ 172 | Test step for the CLIPBody Model 173 | """ 174 | epitope_seqs, receptor_seqs = batch 175 | epitope_proj, receptor_proj = self(epitope_seqs, receptor_seqs) 176 | 177 | # save the embeddings batches for evaluation later 178 | if self.ln_cfg.include_mhc: 179 | epitope_seqs_to_save = epitope_seqs[0] # extract epitope seqs from the array [epitopes_seqs, mhca_seqs, mhcb_seqs] 180 | else: 181 | epitope_seqs_to_save = epitope_seqs 182 | self.epitope_sequences.append(epitope_seqs_to_save) 183 | self.receptor_sequences.append(receptor_seqs) 184 | self.epitope_embeddings.append(epitope_proj) 185 | self.receptor_embeddings.append(receptor_proj) 186 | 187 | # label_matrix = construct_label_matrices(epitope_seqs, receptor_seqs).to(self.device) 188 | label_matrix = construct_label_matrices_ones(epitope_seqs, receptor_seqs, self.ln_cfg.include_mhc).to(self.device) 189 | 190 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07) 191 | 192 | clip_loss = clip_loss.mean() 193 | 194 | loss = clip_loss * (1 - self.mse_weight) + mse_loss * self.mse_weight 195 | test_metrics = { 196 | 'loss': loss, 197 | } 198 | self.test_step_metrics.setdefault('loss', []).append(loss.detach().item()) 199 | if self.ln_cfg.mse_weight > 0: 200 | self.test_step_metrics.setdefault('clip_loss', []).append(clip_loss.detach().item()) 201 | self.test_step_metrics.setdefault('mse_loss', []).append(mse_loss.detach().item()) 202 | 203 | return test_metrics 204 | 205 | def on_fit_start(self): 206 | # compute the weights for each epitope sequence 207 | if self.ln_cfg.weigh_epitope_count: 208 | print("Weighing the Epitopes by inverse sqrt of their counts!") 209 | self.epitope_weights = compute_weights(self.trainer.datamodule.train_dataloader().dataset.data['epitope'].tolist()) 210 | 211 | def on_train_epoch_end(self): 212 | pass 213 | 214 | def on_validation_epoch_end(self): 215 | for metric, values in self.training_step_metrics.items(): 216 | avg_metric = self.aggregate_metric(values) 217 | self.log(f'train_{metric}', avg_metric, prog_bar=False, sync_dist=True) 218 | print(f'Epoch train end: {metric}/train', avg_metric) 219 | self.training_step_metrics.clear() 220 | 221 | for metric, values in self.val_step_metrics.items(): 222 | avg_metric = self.aggregate_metric(values) 223 
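        # Validation metrics are averaged and logged here; the training metrics accumulated
        # during the epoch were already logged in the loop above, since on_train_epoch_end
        # is left as a no-op.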
| self.log(f'val_{metric}', avg_metric, prog_bar=False, sync_dist=True) 224 | print(f'Epoch validation end: {metric}/val', avg_metric) 225 | self.val_step_metrics.clear() 226 | 227 | def on_test_epoch_end(self): 228 | for metric, values in self.test_step_metrics.items(): 229 | avg_metric = self.aggregate_metric(values) 230 | # self.log(f'test_{metric}', avg_metric, prog_bar=False, sync_dist=True) 231 | print(f'Epoch test end: {metric}/test', avg_metric) 232 | self.test_step_metrics.clear() 233 | 234 | # save the embeddings as numpy arrays: 235 | if self.ln_cfg.save_embed_path: 236 | if not os.path.isdir(self.ln_cfg.save_embed_path): 237 | os.makedirs(self.ln_cfg.save_embed_path) 238 | 239 | epitope_sequences = np.concatenate(self.epitope_sequences, axis=0) 240 | receptor_sequences = np.concatenate(self.receptor_sequences, axis=1) 241 | epitope_embeddings = torch.cat(self.epitope_embeddings, dim=0).detach().cpu().numpy() 242 | receptor_embeddings = torch.cat(self.receptor_embeddings, dim=0).detach().cpu().numpy() 243 | 244 | # actually save the embeds 245 | print("Saving sequences and embeddings to disk...") 246 | np.save(self.ln_cfg.save_embed_path + '/epitope_seqs.npy', epitope_sequences) 247 | np.save(self.ln_cfg.save_embed_path + '/receptor_seqs.npy', receptor_sequences) 248 | np.save(self.ln_cfg.save_embed_path + '/epitope_embeds.npy', epitope_embeddings) 249 | np.save(self.ln_cfg.save_embed_path + '/receptor_embeds.npy', receptor_embeddings) 250 | 251 | 252 | @staticmethod 253 | def aggregate_metric(step_outputs): 254 | return np.mean(step_outputs) 255 | 256 | def configure_optimizers(self): 257 | if self.ln_cfg.regular_ft: 258 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.ln_cfg.lr, weight_decay=self.ln_cfg.weight_decay) 259 | return { 260 | "optimizer": optimizer, 261 | } 262 | 263 | if self.ln_cfg.lr_scheduler == 'plateau': 264 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.ln_cfg.lr, weight_decay=self.ln_cfg.weight_decay) 265 | scheduler_lr = ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=2, min_lr=1e-6) 266 | 267 | elif self.ln_cfg.lr_scheduler == 'cos_anneal': 268 | optimizer = torch.optim.AdamW(self.parameters(), lr=1e-6, weight_decay=self.ln_cfg.weight_decay) 269 | scheduler_lr = CosineAnnealingWarmUpRestarts(optimizer, T_0=10, T_mult=1, eta_max=self.ln_cfg.lr, T_up=2, gamma=0.7) 270 | 271 | return { 272 | "optimizer": optimizer, 273 | "lr_scheduler": { 274 | "scheduler": scheduler_lr, 275 | "interval": "epoch", 276 | "monitor": "val_loss", 277 | "frequency": 1, 278 | }, 279 | } 280 | 281 | 282 | def configure_model(self): 283 | 284 | self.epitope_encoder = EpitopeEncoderESM(self.epitope_input_dim, self.projection_dim, hidden_dim=self.hidden_dim, 285 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 286 | 287 | if self.model_config.receptor_model_name == 'ablang': 288 | self.receptor_encoder = AntibodyEncoderAbLang(self.receptor_input_dim, self.projection_dim, ln_cfg=self.ln_cfg, device=self.device) 289 | elif self.model_config.receptor_model_name == 'ablang2': 290 | self.receptor_encoder = AntibodyEncoderAbLang2(self.receptor_input_dim, self.projection_dim, ln_cfg=self.ln_cfg, device=self.device) 291 | elif self.model_config.receptor_model_name == 'antiberta2': 292 | self.receptor_encoder = AntibodyEncoderAntiberta2(self.receptor_input_dim, self.projection_dim, ln_cfg=self.ln_cfg, device=self.device) 293 | elif self.model_config.receptor_model_name == 'tcrbert': 294 | self.receptor_encoder = 
TCREncoderTCRBert(self.receptor_input_dim, self.projection_dim, 295 | hidden_dim=self.hidden_dim, ln_cfg=self.ln_cfg, device=self.device) 296 | elif self.model_config.receptor_model_name == 'tcrlang': 297 | self.receptor_encoder = TCREncoderTCRLang(self.receptor_input_dim, self.projection_dim, 298 | hidden_dim=self.hidden_dim, ln_cfg=self.ln_cfg, device=self.device) 299 | elif self.model_config.receptor_model_name in ['esm2', 'esm3']: 300 | if "catcr" in self.ln_cfg.dataset_path: 301 | self.receptor_encoder = TCREncoderESMBetaOnly(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim, 302 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 303 | else: 304 | #TODO: REPLACE THIS LINE! 305 | # self.receptor_encoder = TCREncoderESMBetaOnly(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim, 306 | # ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 307 | self.receptor_encoder = TCREncoderESM(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim, 308 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 309 | elif self.model_config.receptor_model_name == 'inhouse': 310 | self.receptor_encoder = TCREncoderInHouse(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim, 311 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 312 | elif self.model_config.receptor_model_name == 'onehot': 313 | self.epitope_encoder = EpitopeEncoderOneHot(self.epitope_input_dim, self.projection_dim, 314 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 315 | self.receptor_encoder = TCREncoderOneHot(self.receptor_input_dim, self.projection_dim, 316 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device) 317 | 318 | else: 319 | raise NotImplementedError("Such Ab Model not implemented yet. 
Please choose from existing models.") 320 | 321 | 322 | # for inference later 323 | def put_submodules_to_device(self, device): 324 | self.epitope_encoder.device = device 325 | self.receptor_encoder.device = device -------------------------------------------------------------------------------- /src/data_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import DataLoader, Dataset, random_split, Sampler 3 | from sklearn.model_selection import train_test_split 4 | import pandas as pd 5 | import torch 6 | import numpy as np 7 | from collections import defaultdict, deque 8 | import random 9 | import os 10 | 11 | from .utils import (load_iedb_data, load_iedb_data_cdr3, load_vdjdb_data_cdr3, 12 | load_vdjdb_data_pmhc, load_pird_data_cdr3, load_mixtcrpred_data, 13 | load_mixtcrpred_data_pmhc, load_catcr_data) 14 | 15 | class EpitopeReceptorDataset(Dataset): 16 | '''Returns receptor paired with epitope-only data''' 17 | def __init__(self, data): 18 | self.data = data 19 | 20 | def __len__(self): 21 | return len(self.data) 22 | 23 | def __getitem__(self, idx): 24 | # print("idx", idx) 25 | epitope_seq = self.data.iloc[idx]['epitope'] 26 | heavy_chain_seq = self.data.iloc[idx]['heavy_chain'] 27 | light_chain_seq = self.data.iloc[idx]['light_chain'] 28 | return epitope_seq, (heavy_chain_seq, light_chain_seq) 29 | 30 | class pMhcReceptorDataset(Dataset): 31 | '''Returns receptor paired with epitope+MHC data''' 32 | def __init__(self, data): 33 | self.data = data 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, idx): 39 | epitope_seq = self.data.iloc[idx]['epitope'] 40 | mhc_a = self.data.iloc[idx]['mhc.a_seq'] 41 | mhc_b = self.data.iloc[idx]['mhc.b_seq'] 42 | heavy_chain_seq = self.data.iloc[idx]['heavy_chain'] 43 | light_chain_seq = self.data.iloc[idx]['light_chain'] 44 | return (epitope_seq, mhc_a, mhc_b), (heavy_chain_seq, light_chain_seq) 45 | 46 | class EpitopeReceptorDataModule(pl.LightningDataModule): 47 | def __init__(self, tsv_file, mhc_file=None, batch_size=32, include_mhc=False, ln_cfg = None, 48 | model_config = None, split_ratio=(0.7, 0.15, 0.15), random_seed=7): 49 | super().__init__() 50 | self.tsv_file = tsv_file 51 | self.batch_size = batch_size 52 | self.model_config = model_config 53 | self.ln_cfg = ln_cfg 54 | self.include_mhc = include_mhc 55 | self.mhc_file = mhc_file 56 | if self.include_mhc: 57 | assert self.mhc_file is not None, "Must provide a file with MHC data" 58 | self.split_ratio = split_ratio 59 | 60 | self.random_seed = random_seed 61 | 62 | def prepare_data_must(self): 63 | # Read the TSV file 64 | if 'IEDB' in self.tsv_file: 65 | if self.model_config.receptor_model_name == 'ablang': 66 | self.data = load_iedb_data(self.tsv_file, replace_X=True) 67 | elif self.model_config.receptor_model_name == 'tcrlang': 68 | self.data = load_iedb_data_cdr3(self.tsv_file, replace_hashtag=True) 69 | elif self.model_config.receptor_model_name == 'tcrbert': 70 | self.data = load_iedb_data_cdr3(self.tsv_file) 71 | else: 72 | self.data = load_iedb_data(self.tsv_file) 73 | elif 'vdjdb' in self.tsv_file: 74 | if self.include_mhc: 75 | self.data = load_vdjdb_data_pmhc(self.tsv_file, self.mhc_file) 76 | else: 77 | self.data = load_vdjdb_data_cdr3(self.tsv_file) 78 | elif 'mixtcrpred' in self.tsv_file: 79 | if self.include_mhc: 80 | self.data = load_mixtcrpred_data_pmhc(self.tsv_file, self.mhc_file) 81 | else: 82 | self.data = 
load_mixtcrpred_data(self.tsv_file) 83 | elif 'pird' in self.tsv_file: 84 | self.data = load_pird_data_cdr3(self.tsv_file) 85 | elif 'catcr' in self.tsv_file: 86 | self.data = load_catcr_data(self.tsv_file) 87 | 88 | # self.train_data, self.test_data = load_catcr_data(self.tsv_file) 89 | # return 90 | else: 91 | raise ValueError(f"Can't process this tsv file: {self.tsv_file}") 92 | 93 | # Ensure the data has the correct columns 94 | assert 'epitope' in self.data.columns 95 | assert 'heavy_chain' in self.data.columns 96 | assert 'light_chain' in self.data.columns 97 | 98 | def split_data_random(self): 99 | if self.ln_cfg.unique_epitopes: 100 | # ------------------------------------------------------ 101 | # Splitting data via unique epitopes: 102 | 103 | np.random.seed(self.random_seed) 104 | 105 | # Get unique values in epitope column 106 | unique_epitopes = self.data['epitope'].unique() 107 | 108 | # Shuffle the unique values: 109 | np.random.shuffle(unique_epitopes) 110 | 111 | # Split the unique values into train, dev, and test sets 112 | train_size = int(self.split_ratio[0] * len(unique_epitopes)) 113 | dev_size = int(self.split_ratio[1] * len(unique_epitopes)) 114 | test_size = len(unique_epitopes) - train_size - dev_size 115 | 116 | train_values = unique_epitopes[:train_size]#[:100] 117 | dev_values = unique_epitopes[train_size:train_size + dev_size] 118 | test_values = unique_epitopes[train_size + dev_size:] 119 | 120 | # Create train, dev, and test dataframes 121 | # making sure that each set has a unique set of epitopes 122 | self.train_data = self.data[self.data['epitope'].isin(train_values)] 123 | self.dev_data = self.data[self.data['epitope'].isin(dev_values)] 124 | self.test_data = self.data[self.data['epitope'].isin(test_values)] 125 | 126 | elif self.ln_cfg.fewshot_ratio: 127 | self.train_data, self.dev_data = split_df_by_ratio(self.data, 0.85, random_seed=self.random_seed) 128 | if self.ln_cfg.fewshot_ratio < 1: 129 | self.train_data, _ = split_df_by_ratio(self.train_data, self.ln_cfg.fewshot_ratio, random_seed=self.random_seed) 130 | self.test_data = self.dev_data.copy() 131 | 132 | else: 133 | # ------------------------------------------------------ 134 | # Split the data into train, dev, and test sets 135 | total_size = len(self.data) 136 | train_size = int(self.split_ratio[0] * total_size) 137 | dev_size = int(self.split_ratio[1] * total_size) 138 | test_size = total_size - train_size - dev_size 139 | 140 | self.train_data, self.temp = train_test_split(self.data, test_size=0.3, random_state=self.random_seed) 141 | self.dev_data, self.test_data = train_test_split(self.temp, test_size=0.5, random_state=self.random_seed) 142 | 143 | # # oversample here: 144 | # if self.ln_cfg.oversample: 145 | # self.train_data = upsample_epitopes(self.train_data, 'epitope') 146 | 147 | # Reset the index of the dataframes 148 | self.train_data = self.train_data.reset_index(drop=True) 149 | self.dev_data = self.dev_data.reset_index(drop=True) 150 | self.test_data = self.test_data.reset_index(drop=True) 151 | 152 | 153 | def setup(self, stage=None): 154 | self.prepare_data_must() 155 | 156 | if "catcr" in self.tsv_file: 157 | # # copy the test data into dev: 158 | # self.dev_data = self.test_data.copy() 159 | self.split_data_random() 160 | else: 161 | self.split_data_random() 162 | 163 | if self.ln_cfg.save_embed_path: 164 | self.save_datasplit(self.ln_cfg.save_embed_path) 165 | 166 | def train_dataloader(self): 167 | if self.include_mhc: 168 | train_dataset = 
pMhcReceptorDataset(self.train_data) 169 | else: 170 | train_dataset = EpitopeReceptorDataset(self.train_data) 171 | 172 | if self.ln_cfg.oversample: 173 | train_sampler = OversampleSampler(self.train_data, self.batch_size) 174 | return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=False, sampler=train_sampler, 175 | num_workers=4, persistent_workers=True) 176 | else: 177 | return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, 178 | num_workers=4, pin_memory=True, persistent_workers=True) 179 | 180 | def val_dataloader(self): 181 | if self.include_mhc: 182 | dev_dataset = pMhcReceptorDataset(self.dev_data) 183 | else: 184 | dev_dataset = EpitopeReceptorDataset(self.dev_data) 185 | # return DataLoader(dev_dataset, batch_size=self.batch_size, shuffle=False, 186 | # num_workers=4, pin_memory=True, persistent_workers=True) 187 | return DataLoader(dev_dataset, batch_size=self.batch_size, shuffle=False, 188 | num_workers=4, persistent_workers=True) 189 | 190 | 191 | def test_dataloader(self): 192 | if self.include_mhc: 193 | test_dataset = pMhcReceptorDataset(self.test_data) 194 | else: 195 | test_dataset = EpitopeReceptorDataset(self.test_data) 196 | return DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, 197 | num_workers=4, pin_memory=True, persistent_workers=True) 198 | 199 | 200 | def save_datasplit(self, savepath): 201 | ''' 202 | Save the pandas dataframes of train/dev/test splits to a specified path 203 | ''' 204 | 205 | if not os.path.isdir(savepath): 206 | os.makedirs(savepath) 207 | 208 | train_path = os.path.join(savepath, 'train.tsv') 209 | dev_path = os.path.join(savepath, 'dev.tsv') 210 | test_path = os.path.join(savepath, 'test.tsv') 211 | 212 | self.train_data.to_csv(train_path, sep='\t', index=False) 213 | self.dev_data.to_csv(dev_path, sep='\t', index=False) 214 | self.test_data.to_csv(test_path, sep='\t', index=False) 215 | 216 | 217 | # class OversampleSampler(Sampler): 218 | # def __init__(self, df, batch_size): 219 | # self.df = df 220 | # self.batch_size = batch_size 221 | # self.indices = self.generate_indices() 222 | 223 | # def generate_indices(self): 224 | # epitope_counts = self.df['epitope'].value_counts() 225 | # max_count = epitope_counts.max() 226 | 227 | # # Group indices by epitope and shuffle them 228 | # epitope_index_dict = { 229 | # epitope: np.random.permutation(self.df[self.df['epitope'] == epitope].index.tolist() * int( max_count // count )).tolist() 230 | # for epitope, count in epitope_counts.items() 231 | # } 232 | 233 | # # Generate batches with as distinct epitopes as possible 234 | # batched_indices = [] 235 | # while any(epitope_index_dict.values()): 236 | # batch = [] 237 | # available_epitopes = [epitope for epitope, indices in epitope_index_dict.items() if indices] 238 | # np.random.shuffle(available_epitopes) 239 | # selected_epitopes = available_epitopes[:self.batch_size] 240 | 241 | # for epitope in selected_epitopes: 242 | # if epitope_index_dict[epitope]: 243 | # batch.append(epitope_index_dict[epitope].pop()) 244 | 245 | # # Fill the remaining batch size with other available indices if needed 246 | # if len(batch) < self.batch_size: 247 | # remaining_epitopes = [epitope for epitope, indices in epitope_index_dict.items() if indices] 248 | # np.random.shuffle(remaining_epitopes) 249 | # for epitope in remaining_epitopes: 250 | # if len(batch) >= self.batch_size: 251 | # break 252 | # if epitope_index_dict[epitope]: 253 | # batch.append(epitope_index_dict[epitope].pop()) 254 | 255 | # 
np.random.shuffle(batch) 256 | # batched_indices.append(batch) 257 | 258 | # # shuffle the order of minibatches as well 259 | # np.random.shuffle(batched_indices) 260 | # batched_indices = sum(batched_indices, []) 261 | 262 | # return np.array(batched_indices) 263 | 264 | # def __iter__(self): 265 | # return iter(self.indices) 266 | 267 | # def __len__(self): 268 | # return len(self.indices) 269 | 270 | 271 | class OversampleSampler(Sampler): 272 | def __init__(self, df, batch_size): 273 | self.df = df 274 | self.indices = self.generate_indices() 275 | 276 | # print("DF size: ", self.df.shape) 277 | # print("Oversampled indices:", self.indices[:1000]) 278 | 279 | def generate_indices(self): 280 | epitope_counts = self.df['epitope'].value_counts() 281 | max_count = epitope_counts.max() 282 | 283 | oversample_indices = [] 284 | for epitope, count in epitope_counts.items(): 285 | epitope_indices = self.df[self.df['epitope'] == epitope].index.tolist() 286 | oversample_ratio = max_count // count # int( np.sqrt(max_count // count) ) 287 | oversample_indices.extend(epitope_indices * oversample_ratio) 288 | 289 | # return oversample_indices 290 | 291 | # Shuffle the oversampled indices to ensure randomness 292 | np.random.shuffle(oversample_indices) 293 | return np.array(oversample_indices) 294 | 295 | def __iter__(self): 296 | # # Shuffle the oversampled indices to ensure randomness 297 | # np.random.shuffle(self.indices) 298 | # return iter(np.array(self.indices)) 299 | return iter(self.indices) 300 | 301 | def __len__(self): 302 | return len(self.indices) 303 | 304 | 305 | class UniqueValueSampler(Sampler): 306 | def __init__(self, dataframe, batch_size, seed=42): 307 | self.dataframe = dataframe 308 | self.batch_size = batch_size 309 | self.unique_values = list(dataframe['epitope'].unique()) 310 | self.original_indices_by_value = defaultdict(list) 311 | self.indices_by_value = defaultdict(list) 312 | 313 | for idx, value in enumerate(dataframe['epitope']): 314 | self.original_indices_by_value[value].append(idx) 315 | 316 | self.seed = seed 317 | self.reset_indices() 318 | 319 | def reset_indices(self): 320 | # Reset the indices for each unique value from the original indices 321 | self.indices_by_value = {value: indices[:] for value, indices in self.original_indices_by_value.items()} 322 | 323 | def __iter__(self): 324 | random.seed(self.seed) 325 | shuffled_unique_values = self.unique_values[:] 326 | random.shuffle(shuffled_unique_values) 327 | 328 | batches = [] 329 | current_batch = [] 330 | used_values = set() 331 | 332 | for value in shuffled_unique_values: 333 | if value not in used_values and self.indices_by_value[value]: 334 | index = self.indices_by_value[value].pop(0) 335 | current_batch.append(index) 336 | used_values.add(value) 337 | 338 | if len(current_batch) == self.batch_size: 339 | batches.append(current_batch) 340 | current_batch = [] 341 | used_values.clear() 342 | 343 | # Add the last batch if it contains any items 344 | if current_batch: 345 | batches.append(current_batch) 346 | 347 | # Ensure we cover all indices, even if they don't form a full batch 348 | remaining_indices = [idx for value in shuffled_unique_values for idx in self.indices_by_value[value]] 349 | for i in range(0, len(remaining_indices), self.batch_size): 350 | batches.append(remaining_indices[i:i+self.batch_size]) 351 | 352 | # Flatten the list of batches to a list of indices 353 | flattened_batches = [idx for batch in batches for idx in batch] 354 | 355 | self.reset_indices() # Reset indices for the next 
epoch 356 | return iter(flattened_batches) 357 | 358 | def __len__(self): 359 | return len(self.dataframe) 360 | 361 | 362 | def compute_weights(epitope_seqs): 363 | ''' 364 | given a list of redundant epitope sequences, count the number of times each unique epitope appears 365 | and compute the inverse square-rooted count weights for each epitope 366 | and save them into a dictionary 367 | ''' 368 | epitope_weights = {} 369 | for seq in epitope_seqs: 370 | if seq in epitope_weights: 371 | epitope_weights[seq] += 1 372 | else: 373 | epitope_weights[seq] = 1 374 | 375 | # compute the inverse square-rooted count weights 376 | for seq in epitope_weights: 377 | epitope_weights[seq] = np.sqrt(1 / epitope_weights[seq]) 378 | 379 | return epitope_weights 380 | 381 | 382 | def upsample_epitopes(df: pd.DataFrame, epitope_column: str) -> pd.DataFrame: 383 | """ 384 | Upsample the DataFrame so that each unique epitope is repeated by the ratio 385 | of max count to its current count. 386 | 387 | Parameters: 388 | - df: pandas DataFrame containing the data. 389 | - epitope_column: name of the column containing the epitope identifiers. 390 | 391 | Returns: 392 | - upsampled_df: pandas DataFrame with epitopes upsampled by the calculated ratio. 393 | """ 394 | # Step 1: Count the number of entries for each epitope 395 | epitope_counts = df[epitope_column].value_counts() 396 | 397 | # Step 2: Find the maximum count 398 | max_count = epitope_counts.max() 399 | 400 | # Step 3: Function to upsample each group by the ratio 401 | def upsample(group): 402 | # Calculate the number of repetitions needed for each epitope group 403 | num_repeats = max_count // len(group) 404 | # Repeat each group by the calculated number of repeats 405 | return group.loc[group.index.repeat(num_repeats)] 406 | 407 | # Step 4: Apply the upsample function to each epitope group 408 | upsampled_df = df.groupby(epitope_column, group_keys=False).apply(upsample) 409 | 410 | return upsampled_df 411 | 412 | def split_df_by_ratio(df, r, random_seed=14): 413 | # Create two empty lists to hold the dataframes for the two sets 414 | df_1_list = [] 415 | df_2_list = [] 416 | 417 | # Group the dataframe by the epitope 418 | grouped = df.groupby('epitope') 419 | 420 | # Iterate through each group 421 | for epitope, group in grouped: 422 | # Shuffle the group 423 | shuffled_group = group.sample(frac=1, random_state=random_seed) 424 | 425 | # Determine the split index 426 | split_idx = int(len(shuffled_group) * r) 427 | 428 | # Split the group into two parts based on the ratio 429 | df_1_list.append(shuffled_group.iloc[:split_idx]) 430 | df_2_list.append(shuffled_group.iloc[split_idx:]) 431 | 432 | # Concatenate all the individual dataframes to create the final dataframes 433 | df_1 = pd.concat(df_1_list).reset_index(drop=True) 434 | df_2 = pd.concat(df_2_list).reset_index(drop=True) 435 | 436 | return df_1, df_2 -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import re 5 | import json 6 | import functools 7 | import os 8 | 9 | # adapted from the HuggingFace repo for AbLang 10 | def get_sequence_embeddings(encoded_input, model_output, is_sep=True, is_cls=True, epitope_mask=None): 11 | if isinstance(model_output, dict): 12 | output_last_h_state = model_output['last_hidden_state'] 13 | else: 14 | output_last_h_state = model_output.last_hidden_state 15 | 16 | 
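    # This helper performs attention-masked mean pooling: SEP/CLS positions (and, when
    # epitope_mask is given, non-epitope positions) are zeroed out in the mask, and the
    # token embeddings are then averaged as
    #   sum_t(h_t * m_t) / clamp(sum_t m_t, min=1e-9)
    # for each sequence in the batch.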
mask = encoded_input['attention_mask'].float() 17 | if is_sep: 18 | d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # dict of sep tokens 19 | # make sep token invisible 20 | for i in d: 21 | mask[i, d[i]] = 0 22 | if is_cls: 23 | mask[:, 0] = 0.0 # make cls token invisible 24 | if epitope_mask is not None: 25 | mask = mask * epitope_mask # make non-epitope regions invisible 26 | mask = mask.unsqueeze(-1).expand(output_last_h_state.size()) 27 | sum_embeddings = torch.sum(output_last_h_state * mask, 1) 28 | sum_mask = torch.clamp(mask.sum(1), min=1e-9) 29 | return sum_embeddings / sum_mask 30 | 31 | def get_attention_mask(encoded_input, is_sep=True, is_cls=True): 32 | mask = encoded_input['attention_mask'].float() 33 | if is_sep: 34 | d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # dict of sep tokens 35 | # make sep token invisible 36 | for i in d: 37 | mask[i, d[i]] = 0 38 | if is_cls: 39 | mask[:, 0] = 0.0 # make cls token invisible 40 | return mask 41 | 42 | # load CATCR dataset (contains epitope/CDR3-B data from VDJdb, IEDB, McPAS-TCR): 43 | def load_catcr_data(tsv_path): 44 | df_catcr = pd.read_csv(tsv_path, delimiter=',') 45 | 46 | # rename the TCR column to light_chain: 47 | df_catcr = df_catcr.rename(columns={'CDR3_B':'light_chain', 'EPITOPE':'epitope'}) 48 | 49 | # no alpha chain, so create an empty heavy chain column: 50 | df_catcr['heavy_chain'] = "" 51 | 52 | # reset index 53 | df_catcr = df_catcr.reset_index(drop=True) 54 | 55 | return df_catcr 56 | 57 | # load CATCR dataset (contains epitope/CDR3-B data from VDJdb, IEDB, McPAS-TCR): 58 | def load_catcr_data_presplit(tsv_path): 59 | df_catcr_train = pd.read_csv(os.path.join(tsv_path, "train.csv"), delimiter=',') 60 | df_catcr_test = pd.read_csv(os.path.join(tsv_path, "test.csv"), delimiter=',') 61 | 62 | # rename the TCR column to light_chain: 63 | df_catcr_train = df_catcr_train.rename(columns={'CDR3_B':'light_chain', 64 | 'EPITOPE':'epitope'}) 65 | df_catcr_test = df_catcr_test.rename(columns={'CDR3_B':'light_chain', 66 | 'EPITOPE':'epitope'}) 67 | 68 | # no alpha chain, so create an empty heavy chain column: 69 | df_catcr_train['heavy_chain'] = "" 70 | df_catcr_test['heavy_chain'] = "" 71 | 72 | # reset index 73 | df_catcr_train = df_catcr_train.reset_index(drop=True) 74 | df_catcr_test = df_catcr_test.reset_index(drop=True) 75 | 76 | return df_catcr_train, df_catcr_test 77 | 78 | # load MixTCRPred dataset (contains data from VDJdb, IEDB, McPAS, and 10x Genomics) 79 | def load_mixtcrpred_data(tsv_path): 80 | df_mix = pd.read_csv(tsv_path, delimiter=',') 81 | 82 | # drop rows where Epitope name is not a peptide sequence: 83 | df_mix = df_mix[df_mix['epitope'].str.isupper()] 84 | df_mix = df_mix[df_mix['epitope'].apply(is_alpha_only)] 85 | 86 | # drop rows whose TCR sequences are missing: 87 | df_mix = df_mix.dropna(subset=['cdr3_TRA', 'cdr3_TRB']) 88 | 89 | # rename the TCR columns to heavy_chain and light_chain: 90 | df_mix = df_mix.rename(columns={'cdr3_TRA':'heavy_chain', 'cdr3_TRB':'light_chain'}) 91 | 92 | # drop duplicates of epitope-TRA-TRB triplets: 93 | df_mix = df_mix.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) # drops ~2700 entries 94 | 95 | # only restrict it to human data: 96 | df_mix = df_mix.loc[df_mix["species"] == "HomoSapiens"] 97 | 98 | # reset index 99 | df_mix = df_mix.reset_index(drop=True) 100 | 101 | return df_mix 102 | 103 | def load_mixtcrpred_data_pmhc(tsv_path, mhc_map_path): 104 | df_mhc_map = pd.read_csv(mhc_map_path, delimiter='\t', index_col=0) # format 
where index is two-part name like "A*01:01" and the column label is "max_sequence" containing the sequence with signal peptide removed 105 | # raise NotImplementedError("mhc_map needs to be updated to include all the mouse MHCs if we are including mouse data") 106 | 107 | df_mix = load_mixtcrpred_data(tsv_path) 108 | 109 | def clean_mhc_allele(allele): 110 | ''' mixtcrpred will sometimes have MHC names formatted like "HLA-DRB1:01". 111 | In these cases, the colon must be replaced with an asterisk''' 112 | if allele != "B2M" and "-" not in allele: # from splitting HLA-__A/__B entries, some HLAs have lost their prefix. Add it back. H2-* alleles are fine 113 | allele = "HLA-" + allele 114 | 115 | pattern = re.compile(r'(HLA-[A-Za-z0-9]+):([:0-9]+)') 116 | return pattern.sub(r'\1*\2', allele) 117 | 118 | # Create new columns (will change their values below) 119 | df_mix["mhc.a"] = df_mix["MHC"] 120 | df_mix["mhc.b"] = df_mix["MHC"] 121 | 122 | # make sure mouse names are consistent with map_mhc_allele's canonical alleles 123 | for i, row in df_mix.iterrows(): 124 | if row.MHC_class == "MHCI": 125 | df_mix.at[i, "mhc.a"] = row.MHC 126 | df_mix.at[i, "mhc.b"] = "B2M" 127 | else: # MHCII 128 | # Cases that need to be handled 129 | # Counter({'H2-IAb': 3674, 130 | # 'HLA-DPB1*04:01': 388, 131 | # 'HLA-DRB1:01': 71, 132 | # 'HLA-DQA1:02/DQB1*06:02': 46, 133 | # 'HLA-DRB1*04:05': 31, 134 | # 'H2-IEk': 25, 135 | # 'H2-Kb': 24, 136 | # 'HLA-DRA:01/DRB1:01': 21, 137 | # 'HLA-DRB1*07:01': 20, 138 | # 'HLA-DRB1*15:01': 15, 139 | # 'HLA-DQA1*05:01/DQB1*02:01': 14, 140 | # 'HLA-DRB1*04:01': 14, 141 | # 'HLA-DQA': 13, 142 | # 'H-2q': 12, 143 | # 'HLA-DRB1*11:01': 10, 144 | # 'HLA-DQ2': 10, 145 | # 'HLA-DRA:01': 10}) 146 | if "/" in row.MHC: 147 | df_mix.at[i, "mhc.a"] = row.MHC.split("/")[0] 148 | df_mix.at[i, "mhc.b"] = row.MHC.split("/")[1] 149 | else: 150 | # inconsistencies in mixtcrpred: there are 24 examples of 'H2-Kb' and 12 examples of 'H-2q' 151 | # which are labeled as MHC II even though they are MHC I alleles. Switch them to MHC I 152 | if row.MHC == "H2-Kb": 153 | df_mix.at[i, "mhc.a"] = "H2-Kb" 154 | df_mix.at[i, "mhc.b"] = "B2M" 155 | elif row.MHC == "H-2q": 156 | # also switch formatting to H2- nomenclature 157 | df_mix.at[i, "mhc.a"] = "H2-Q" 158 | df_mix.at[i, "mhc.b"] = "B2M" 159 | else: 160 | if row.MHC == "H2-IAb": # switch this nomenclature to H2-A 161 | df_mix.at[i, "mhc.a"] = "H2-AA" 162 | df_mix.at[i, "mhc.b"] = "H2-AB" 163 | elif row.MHC == "H2-IEk": # add A and B chains for this allele 164 | df_mix.at[i, "mhc.a"] = "H2-IEkA" 165 | df_mix.at[i, "mhc.b"] = "H2-IEkB" 166 | elif row.MHC == "HLA-DQ2": 167 | df_mix.at[i, "mhc.a"] = "HLA-DQA1" 168 | df_mix.at[i, "mhc.b"] = "HLA-DQB1" 169 | elif row.MHC == "HLA-DQA": 170 | df_mix.at[i, "mhc.a"] = "HLA-DQA" 171 | df_mix.at[i, "mhc.b"] = "HLA-DQB" 172 | elif row.MHC == "HLA-DRA:01": 173 | df_mix.at[i, "mhc.a"] = "HLA-DRA:01" 174 | df_mix.at[i, "mhc.b"] = "HLA-DRB" 175 | else: 176 | # remainder are all 'HLA-D*B[:*]...' alleles 177 | # extract the allele name 178 | pattern = re.compile(r'HLA-([A-Za-z]+)[0-9]*[:*].*') 179 | allele_name = pattern.search(row.MHC).group(1) # e.g. 
extracts DRB from HLA-DRB1:01 or HLA-DRB1*11:01 180 | mhca_name = allele_name.replace("B", "A") 181 | df_mix.at[i, "mhc.a"] = f"HLA-{mhca_name}" 182 | df_mix.at[i, "mhc.b"] = row.MHC 183 | 184 | 185 | df_mix["mhc.a"] = df_mix["mhc.a"].apply(clean_mhc_allele) 186 | df_mix["mhc.b"] = df_mix["mhc.b"].apply(clean_mhc_allele) 187 | 188 | df_mix["mhc.a_seq"] = df_mix["mhc.a"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map)) 189 | df_mix["mhc.b_seq"] = df_mix["mhc.b"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map)) 190 | 191 | 192 | 193 | return df_mix 194 | 195 | 196 | # load the IEDB dataset: 197 | def load_iedb_data(tsv_path, replace_X=False, remove_X=False, use_anarci=False): 198 | df_iedb = pd.read_csv(tsv_path, delimiter='\t') 199 | 200 | # drop rows where Epitope name is not a peptide sequence: 201 | df_iedb = df_iedb[df_iedb['Epitope - Name'].str.isupper()] 202 | df_iedb = df_iedb[df_iedb['Epitope - Name'].apply(is_alpha_only)] 203 | 204 | # drop rows whose Ab HL sequences missing: 205 | df_iedb = df_iedb.dropna(subset=['Chain 1 - Protein Sequence', 'Chain 2 - Protein Sequence']) 206 | 207 | # drop rows whose CDRs are missing: 208 | cdr_columns = ['Chain 1 - CDR3 Calculated', 'Chain 1 - CDR2 Calculated', 'Chain 1 - CDR1 Calculated', 209 | 'Chain 2 - CDR3 Calculated', 'Chain 2 - CDR2 Calculated', 'Chain 2 - CDR1 Calculated'] 210 | df_iedb = df_iedb.dropna(subset=cdr_columns) 211 | 212 | if use_anarci: 213 | from anarci import run_anarci 214 | # run ANARCI to get the Fv region of the sequences: 215 | print("running anarci on sequences...") 216 | for col_id in ['Chain 1 - Protein Sequence', 'Chain 2 - Protein Sequence']: 217 | seqs = df_iedb[col_id].str.upper() 218 | seqs_ = [(str(i), s) for i, s in enumerate(seqs)] 219 | anarci_results = run_anarci(seqs_) 220 | start_end_pairs = [(anarci_results[2][i][0]['query_start'], anarci_results[2][i][0]['query_end']) for i in range(len(seqs_))] 221 | seqs = [seq[a:b] for seq, (a,b) in zip(seqs, start_end_pairs)] 222 | df_iedb[col_id] = seqs 223 | 224 | df_iedb = df_iedb.reset_index(drop=True) 225 | # FOR FUTURE USERS: IF PERFORMING BCR CALCULATIONS, SAVE df_iedb TO A CSV FILE 226 | # e.g. 
df_iedb.to_csv("path/to/iedb_data_with_anarci.csv", index=False) 227 | print("done running anarci!") 228 | 229 | # change column names: 230 | df_iedb = df_iedb.rename(columns={'Epitope - Name': 'epitope', 231 | 'Chain 1 - Protein Sequence': 'heavy_chain', 232 | 'Chain 2 - Protein Sequence': 'light_chain', 233 | 'Chain 1 - CDR3 Calculated': 'heavy_chain_cdr3', 234 | 'Chain 1 - CDR2 Calculated': 'heavy_chain_cdr2', 235 | 'Chain 1 - CDR1 Calculated': 'heavy_chain_cdr1', 236 | 'Chain 2 - CDR3 Calculated': 'light_chain_cdr3', 237 | 'Chain 2 - CDR2 Calculated': 'light_chain_cdr2', 238 | 'Chain 2 - CDR1 Calculated': 'light_chain_cdr1'}) 239 | 240 | if replace_X: 241 | # replace X's with [MASK] in the sequences for AbLang: 242 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.replace('X', '[MASK]') 243 | df_iedb['light_chain'] = df_iedb['light_chain'].str.replace('X', '[MASK]') 244 | 245 | if remove_X: 246 | # remove rows with 'X' in the sequences: 247 | df_iedb = df_iedb[~df_iedb['heavy_chain'].str.contains('X')] 248 | df_iedb = df_iedb[~df_iedb['light_chain'].str.contains('X')] 249 | 250 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair): 251 | df_iedb = df_iedb[df_iedb['heavy_chain'] != df_iedb['light_chain']] 252 | 253 | # reset index 254 | df_iedb = df_iedb.reset_index(drop=True) 255 | 256 | return df_iedb 257 | 258 | def load_iedb_data_cdr3(tsv_path, replace_hashtag=False): 259 | df_iedb = pd.read_csv(tsv_path, delimiter='\t') 260 | 261 | # drop rows where Epitope name is not a peptide sequence: 262 | df_iedb = df_iedb[df_iedb['Epitope - Name'].str.isupper()] 263 | df_iedb = df_iedb[df_iedb['Epitope - Name'].apply(is_alpha_only)] 264 | 265 | # drop rows whose Ab HL sequences missing: 266 | df_iedb = df_iedb.dropna(subset=['Chain 1 - CDR3 Curated', 'Chain 2 - CDR3 Curated']) 267 | 268 | # change column names: 269 | df_iedb = df_iedb.rename(columns={'Epitope - Name': 'epitope', 270 | 'Chain 1 - CDR3 Curated': 'heavy_chain', 271 | 'Chain 2 - CDR3 Curated': 'light_chain',}) 272 | 273 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair): 274 | df_iedb = df_iedb[df_iedb['heavy_chain'] != df_iedb['light_chain']] 275 | 276 | # drop duplicates of epitope-TRA-TRB triplets: 277 | df_iedb = df_iedb.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) 278 | 279 | if replace_hashtag: 280 | # replace #'s with X in the sequences for TCRLang: 281 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.replace('#', 'X') 282 | df_iedb['light_chain'] = df_iedb['light_chain'].str.replace('#', 'X') 283 | 284 | # make the AA's upper case in the alpha and beta chains: 285 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.upper() 286 | df_iedb['light_chain'] = df_iedb['light_chain'].str.upper() 287 | 288 | # strip the peptides with whitespace: 289 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.strip() 290 | df_iedb['light_chain'] = df_iedb['light_chain'].str.strip() 291 | 292 | # reset index 293 | df_iedb = df_iedb.reset_index(drop=True) 294 | 295 | return df_iedb 296 | 297 | 298 | def load_vdjdb_data_cdr3(tsv_path): 299 | print("path name:", tsv_path) 300 | df_vdj = pd.read_csv(tsv_path, delimiter='\t') 301 | 302 | # Get subset of columsn we are interested in 303 | df_vdj = df_vdj[["cdr3.alpha", "cdr3.beta", "species", "mhc.a", "mhc.b", "mhc.class", "antigen.epitope", "cdr3fix.alpha", "cdr3fix.beta"]] 304 | 305 | # subset to only paired data (both alpha and beta chain CDR3s 
are known) 306 | df_vdj = df_vdj.dropna(subset=["cdr3.alpha", "cdr3.beta"]) 307 | 308 | # Extract the fixed CDR3 sequences and use those as the ground truth CDR3 sequence 309 | # https://github.com/antigenomics/vdjdb-db?tab=readme-ov-file#cdr3-sequence-fixing 310 | # There is always a fixed value for every existing CDR3, but sometimes the "fixed" value is the same as the empirical one 311 | df_vdj["heavy_chain"] = df_vdj["cdr3fix.alpha"].apply(json.loads).apply(vdjdb_extract_fixed_cdr3) 312 | df_vdj["light_chain"] = df_vdj["cdr3fix.beta"].apply(json.loads).apply(vdjdb_extract_fixed_cdr3) 313 | df_vdj = df_vdj.rename(columns={'antigen.epitope':'epitope'}) 314 | 315 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair): 316 | df_vdj = df_vdj[df_vdj["heavy_chain"] != df_vdj["light_chain"]] # only removes 1 entry 317 | 318 | # drop duplicates of epitope-TRA-TRB triplets: 319 | df_vdj = df_vdj.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) # drops ~2700 entries 320 | 321 | # make the AA's upper case in the alpha and beta chains (shouldn't be necessary, but just to be safe) 322 | df_vdj['heavy_chain'] = df_vdj['heavy_chain'].str.upper() 323 | df_vdj['light_chain'] = df_vdj['light_chain'].str.upper() 324 | 325 | df_vdj = df_vdj.loc[df_vdj["species"] == "HomoSapiens"] # Before this point, species counts are {'HomoSapiens': 29556, 'MusMusculus': 2264} 326 | 327 | # reset index 328 | df_vdj = df_vdj.reset_index(drop=True) 329 | 330 | return df_vdj 331 | 332 | def load_vdjdb_data_pmhc(tsv_path, mhc_map_path): 333 | df_mhc_map = pd.read_csv(mhc_map_path, delimiter='\t', index_col=0) # format where index is two-part name like "A*01:01" and the column label is "max_sequence" containing the sequence with signal peptide removed 334 | 335 | df_vdj = load_vdjdb_data_cdr3(tsv_path) 336 | 337 | # Map the MHC allele names to their sequences 338 | df_vdj["mhc.a_seq"] = df_vdj["mhc.a"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map)) 339 | df_vdj["mhc.b_seq"] = df_vdj["mhc.b"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map)) 340 | 341 | return df_vdj 342 | 343 | 344 | 345 | # Helper function to extract fixed CDR3 from the VDJdb "cdr3fix.[alpha/beta]" column 346 | def vdjdb_extract_fixed_cdr3(obj): 347 | return obj["cdr3"] 348 | 349 | # Helper function to map a given MHC allele name to its sequence 350 | def map_mhc_allele(allele_name, df_mhc_map): 351 | ''' 352 | allele_name should be in the original VDJ format, such as "HLA-A*03:01" 353 | 354 | For allele_names that only specify type and no subtype (e.g. "HLA-B*08"), will map to subtype 01 (e.g. 
"HLA-B*08:01") 355 | 356 | For allele names not found in the MHC map, will map to canonical allele as specified here: https://www.ebi.ac.uk/ipd/imgt/hla/alignment/help/references/ 357 | ''' 358 | 359 | canonical_alleles = { # only explicitly listed subset found in vdjdb human data 360 | "HLA-A": "HLA-A*01:01", 361 | "HLA-B": "HLA-B*07:02", 362 | "HLA-C": "HLA-C*01:02", 363 | "HLA-E": "HLA-E*01:01", 364 | "HLA-DRA": "HLA-DRA*01:01", 365 | "HLA-DRB": "HLA-DRB1*01:01", # not in imgt/hla, but found in mixtcrpred so mapping to DRB1 366 | "HLA-DRB1": "HLA-DRB1*01:01", 367 | "HLA-DRB3": "HLA-DRB3*01:01", 368 | "HLA-DRB5": "HLA-DRB5*01:01", 369 | "HLA-DQA": "HLA-DQA1*01:01", # not in imgt/hla, but found in vdjdb so mapping to DQA1 370 | "HLA-DQA1": "HLA-DQA1*01:01", 371 | "HLA-DQB": "HLA-DQB1*05:01", # not in imgt/hla but found in mixtcrpred so mapping to DQB1 372 | "HLA-DQB1": "HLA-DQB1*05:01", 373 | "HLA-DPA": "HLA-DPA1*01:03", # not in imgt/hla but found in vdjdb so mappin gto DPA1 374 | "HLA-DPA1": "HLA-DPA1*01:03", 375 | "HLA-DPB": "HLA-DPB1*01:01", # not in imgt/hla but found in vdjdb so mappin gto DPB1 376 | "HLA-DPB1": "HLA-DPB1*01:01", 377 | # "DQ2": "HLA-DQA2*01:01", # found in mixtcrpred, unsure which DQ allele it is referring to, but only DQA has a canonical allele for 2. DQB only has canonical allele for 1 378 | } 379 | 380 | if allele_name == "B2M": # Handles beta chain placeholder for Class I MHCs that do not have a beta chain 381 | return "B2M" 382 | 383 | hla_id = allele_name 384 | if hla_id.startswith("HLA-"): 385 | components = hla_id.split(":") 386 | num_parts = len(components) 387 | if num_parts > 2: 388 | hla_id = ":".join(components[:2]) 389 | elif num_parts == 1: # either is of form HLA-A*01 or HLA-A 390 | pattern = re.compile(r'(HLA-[A-Za-z0-9]+)\*([0-9]+)') 391 | if pattern.match(hla_id): 392 | hla_id = ":".join([components[0], "01"]) 393 | else: 394 | pass # leave it as is 395 | 396 | if hla_id not in df_mhc_map.index: 397 | hla_gene = hla_id.split("*")[0] 398 | if hla_gene in canonical_alleles: 399 | hla_id = canonical_alleles[hla_gene] 400 | else: 401 | raise Exception(f"Could not find MHC allele {allele_name} in the MHC map and no canonical allele for {hla_gene} specified.") 402 | else: # handles mouse H2 alleles 403 | hla_id = allele_name 404 | 405 | 406 | allele_sequence = df_mhc_map.loc[hla_id, "max_sequence"] 407 | return allele_sequence 408 | 409 | 410 | 411 | def load_pird_data_cdr3(csv_path): 412 | ''' 413 | Load the PIRD dataset from a CSV file in latin-1 encoding (default from database download) and return a pandas dataframe. 
414 | ''' 415 | df_pird = pd.read_csv(csv_path, encoding='latin-1') 416 | df_pird = df_pird.replace('-', np.nan) 417 | 418 | # get subset of columns we are interested in: 419 | df_pird = df_pird[['Antigen.sequence', 'HLA', 'CDR3.alpha.aa', 'CDR3.beta.aa']] 420 | 421 | # subset to only paired data (both alpha and beta chain CDR3s are known) 422 | df_pird = df_pird.dropna(subset=["Antigen.sequence", "CDR3.alpha.aa", "CDR3.beta.aa"]) 423 | 424 | # drop rows where Epitope name is not a peptide sequence: 425 | df_pird = df_pird[df_pird['Antigen.sequence'].str.isupper()] 426 | df_pird = df_pird[df_pird['Antigen.sequence'].apply(is_alpha_only)] 427 | 428 | # change column names: 429 | df_pird = df_pird.rename(columns={'Antigen.sequence': 'epitope', 430 | 'CDR3.alpha.aa': 'heavy_chain', 431 | 'CDR3.beta.aa': 'light_chain'}) 432 | 433 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair): 434 | df_pird = df_pird[df_pird["heavy_chain"] != df_pird["light_chain"]] # only removes 1 entry 435 | 436 | # drop duplicates of epitope-TRA-TRB triplets: 437 | df_pird = df_pird.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) # drops 82 entries 438 | 439 | # make the AA's upper case in the alpha and beta chains (shouldn't be necessary, but just to be safe) 440 | df_pird['heavy_chain'] = df_pird['heavy_chain'].str.upper() 441 | df_pird['light_chain'] = df_pird['light_chain'].str.upper() 442 | 443 | # reset index 444 | df_pird = df_pird.reset_index(drop=True) 445 | 446 | return df_pird 447 | 448 | # Function to check if the string contains only alphabetic characters 449 | def is_alpha_only(s): 450 | return s.isalpha() 451 | 452 | def insert_spaces(sequence): 453 | # Regular expression to match single amino acids or special tokens like '[UNK]' 454 | pattern = re.compile(r'\[.*?\]|.') 455 | 456 | # Find all matches and join them with a space 457 | spaced_sequence = ' '.join(pattern.findall(sequence)) 458 | 459 | return spaced_sequence 460 | 461 | 462 | def construct_label_matrices(epitope_seqs, receptor_seqs, include_mhc): 463 | if include_mhc: 464 | epitope_seqs = epitope_seqs[0] # extract epitope seqs from the array [epitopes_seqs, mhca_seqs, mhcb_seqs] 465 | bs = len(epitope_seqs) 466 | 467 | # Create a 2D tensor filled with zeros 468 | label_matrix = torch.zeros((bs, bs), dtype=torch.float32) 469 | # Construct the label matrix 470 | for i, correct_ep in enumerate(epitope_seqs): 471 | count = epitope_seqs.count(correct_ep) 472 | for j, ep in enumerate(epitope_seqs): 473 | if ep == correct_ep: 474 | label_matrix[i, j] = 1.0 / count 475 | 476 | return label_matrix 477 | 478 | def construct_label_matrices_ones(epitope_seqs, receptor_seqs, include_mhc): 479 | if include_mhc: 480 | epitope_seqs = epitope_seqs[0] # extract epitope seqs from the array [epitopes_seqs, mhca_seqs, mhcb_seqs] 481 | bs = len(epitope_seqs) 482 | 483 | # Create a 2D tensor filled with zeros 484 | label_matrix = torch.zeros((bs, bs), dtype=torch.float32) 485 | # Construct the label matrix 486 | for i, correct_ep in enumerate(epitope_seqs): 487 | for j, ep in enumerate(epitope_seqs): 488 | if ep == correct_ep: 489 | label_matrix[i, j] = 1.0 490 | 491 | return label_matrix 492 | 493 | def apply_masking_seq(sequences, mask_token='.', mask_regions=True, p=0.15): 494 | ''' 495 | mask_regions: True or List[np.array(dtype=bool)] 496 | - if True, all amino acids will be considered 497 | - if List, only amino acids with True values in the list will be considered (i.e. 
mask for regions to mask) 498 | 499 | For each sequence string in the sequences list, apply masking by changing the 500 | amino acid with the mask_token with a certain probability. 501 | ''' 502 | 503 | if mask_regions is True: # convert True value into all-True arrays every sequence 504 | mask_regions = [np.ones(len(seq), dtype=bool) for seq in sequences] 505 | 506 | masked_sequences = [] 507 | mask_indices = [] 508 | for n, seq in enumerate(sequences): 509 | masked_seq = '' 510 | # seq_mask_indices = [] 511 | mask_count = 0 512 | for i, aa in enumerate(seq): 513 | if mask_regions[n][i] and torch.rand(1) < p and mask_count < sum(mask_regions[n]) - 1: 514 | masked_seq += mask_token 515 | mask_indices.append([n, i]) 516 | mask_count += 1 517 | else: 518 | masked_seq += aa 519 | masked_sequences.append(masked_seq) 520 | 521 | return masked_sequences, mask_indices -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import re 5 | 6 | from .utils import get_sequence_embeddings, insert_spaces, get_attention_mask, apply_masking_seq 7 | from .swe_pooling import SWE_Pooling 8 | 9 | class EpitopeEncoderESM(nn.Module): 10 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config, hidden_dim=1024, device='cpu'): 11 | super().__init__() 12 | 13 | self.ln_config = ln_cfg 14 | self.model_config = model_config 15 | self.projection_dim = projection_dim 16 | 17 | if self.model_config.receptor_model_name == 'esm3': 18 | from .lora import setup_peft_esm3 19 | from .configs import peft_config_esm3 20 | 21 | # load the LoRA adapted ESM-3 Model here: 22 | self.esm_lora, self.esm_tokenizer = setup_peft_esm3(peft_config_esm3, ln_cfg.no_lora) 23 | else: 24 | from .lora import setup_peft_esm2 25 | from .configs import peft_config_esm2 26 | 27 | # load the LoRA adapted ESM-2 Model here: 28 | self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora, ln_cfg.regular_ft) 29 | 30 | # For ESM2, we need linker to represent multimers 31 | self.linker_size = 25 32 | self.gly_linker = 'G'*self.linker_size 33 | 34 | if self.projection_dim: 35 | if hidden_dim: 36 | print("Using multi-layer projection head") 37 | self.proj_head = nn.Sequential( 38 | nn.Linear(input_dim, hidden_dim), 39 | nn.LayerNorm(hidden_dim), 40 | nn.LeakyReLU(), 41 | nn.Dropout(p=0.3), 42 | nn.Linear(hidden_dim, projection_dim), 43 | ) 44 | # Initialize the projection head weights 45 | nn.init.kaiming_uniform_(self.proj_head[0].weight) 46 | nn.init.kaiming_uniform_(self.proj_head[-1].weight) 47 | else: 48 | print("Using single-layer projection head") 49 | self.proj_head = nn.Sequential( 50 | nn.Linear(input_dim, projection_dim), 51 | nn.LayerNorm(projection_dim), 52 | ) 53 | # Initialize the projection head weights 54 | nn.init.kaiming_uniform_(self.proj_head[0].weight) 55 | else: 56 | print("NOT using projection head") 57 | 58 | if self.ln_config.swe_pooling: 59 | self.swe_pooling = SWE_Pooling(d_in=input_dim, num_ref_points=512, num_slices=projection_dim) 60 | 61 | self.proj_head = nn.Sequential( 62 | nn.Linear(projection_dim, projection_dim // 2), 63 | nn.LayerNorm(projection_dim // 2), 64 | ) 65 | 66 | self.device = device 67 | 68 | 69 | def forward(self, x, mask): 70 | seqs_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob) 71 | 72 | if self.model_config.receptor_model_name == 'esm3': 73 | outputs = 
self.esm_lora(sequence_tokens=seqs_tokens['input_ids']).embeddings 74 | else: 75 | outputs = self.esm_lora(**seqs_tokens) 76 | 77 | if self.ln_config.swe_pooling: 78 | assert self.ln_config.include_mhc == False, "SWE pooling not supported for MHC sequences yet" # TODO: implement for MHC sequences 79 | # for SWE pooling: 80 | attn_mask = get_attention_mask(seqs_tokens, is_sep=False, is_cls=False) 81 | # attn_mask = get_attention_mask(seqs_tokens) 82 | if isinstance(outputs, dict): 83 | outputs = outputs['last_hidden_state'] 84 | elif self.model_config.receptor_model_name == 'esm3': 85 | pass 86 | else: 87 | outputs = outputs.last_hidden_state 88 | seq_embeds = self.swe_pooling(outputs, attn_mask) 89 | else: 90 | # for regular mean pooling 91 | epitope_mask = None 92 | if self.ln_config.include_mhc: 93 | epitope_seqs, mhca_seqs, mhcb_seqs = x 94 | epitope_mask = torch.zeros_like(seqs_tokens['attention_mask']) 95 | for i, (seq, mhcA, mhcB) in enumerate(zip(epitope_seqs, mhca_seqs, mhcb_seqs)): 96 | # assumes no special tokens 97 | epitope_mask[i, len(mhcA) + self.linker_size : len(mhcA) + self.linker_size + len(seq)] = 1 98 | 99 | if self.model_config.receptor_model_name == 'esm3': 100 | outputs = {'last_hidden_state': outputs} 101 | seq_embeds = get_sequence_embeddings(seqs_tokens, outputs, is_sep=False, is_cls=False, epitope_mask=epitope_mask) 102 | 103 | if self.projection_dim: 104 | return self.proj_head(seq_embeds) 105 | else: 106 | return seq_embeds 107 | 108 | def process_seqs(self, inputs, mask, mask_prob=0.15): 109 | ''' 110 | input: list of epitope sequences or epitope-mhc array (3,N) where N is the number of samples 111 | 112 | if self.include_mhc = True, expecting input to be list containing tuples of epitope and MHC sequences in form 113 | (epitope_seq, mhc.a_seq, mhc.b_seq) 114 | 115 | if self.include_mhc = False, expecting input to be list of strings of epitope sequences 116 | ''' 117 | if self.ln_config.include_mhc: 118 | epitope_seqs, mhca_seqs, mhcb_seqs = inputs 119 | 120 | if self.ln_config.mhc_groove_only: 121 | # keep only A1+A2 domains (roundly AA 0-180) for class 1 MHCs, and A1+B1 domains (rougly AA 0-90) for class 2 MHCs 122 | for i, (mhcA, mhcB) in enumerate(zip(mhca_seqs, mhcb_seqs)): 123 | if mhcB == "B2M": 124 | mhca_seqs[i] = mhcA[:180] 125 | else: 126 | mhca_seqs[i] = mhcA[:90] 127 | mhcb_seqs[i] = mhcB[:90] 128 | 129 | # Create the pMHC sequence in the order [mhcA ..G.. epitope ..G.. 
mhcB] 130 | seqs = [ 131 | ( 132 | f"{mhcA}{self.gly_linker}{seq}{self.gly_linker}{mhcB}" 133 | if mhcB != "B2M" else 134 | f"{mhcA}{self.gly_linker}{seq}" 135 | ) 136 | for seq, mhcA, mhcB in zip(epitope_seqs, mhca_seqs, mhcb_seqs) 137 | ] 138 | 139 | # marking where the Glycine linker starts 140 | # linker between mhcA and epitope 141 | attn_starts = [(i, len(mhcA)) for i, mhcA in enumerate(mhca_seqs)] 142 | # linker between epitope and mhcB 143 | attn_starts.extend([ 144 | (i, len(mhcA) + self.linker_size + len(seq)) 145 | for i, (seq, mhcA, mhcB) in enumerate(zip(epitope_seqs, mhca_seqs, mhcb_seqs)) if mhcB != "B2M" 146 | ]) 147 | else: 148 | seqs = inputs 149 | 150 | # removing special tokens since epitopes are protein fragments (peptides) 151 | seqs_tokens = self.esm_tokenizer(seqs, return_tensors="pt", add_special_tokens=False, padding=True) 152 | 153 | if mask: 154 | if self.ln_config.include_mhc: 155 | mask_regions = [np.zeros(len(seq), dtype=bool) for seq in seqs] 156 | for i, (seq, mhcA, mhcB) in enumerate(zip(epitope_seqs, mhca_seqs, mhcb_seqs)): 157 | # always include epitope sequence for random masking 158 | epitope_offset = len(mhcA) + self.linker_size 159 | mask_regions[i][epitope_offset : epitope_offset + len(seq)] = True 160 | # if class I MHC, only apply random masks to A1+A2 domains (roughly AAs 0-180) 161 | if mhcB == "B2M": 162 | mask_regions[i][0 : min(180, len(mhcA))] = True 163 | 164 | # if class II MHC, only apply random masks to A1+B1 domains (roughly AAs 0-90 for each) 165 | else: 166 | mask_regions[i][0 : min(90, len(mhcA))] = True 167 | beta_offset = len(mhcA) + self.linker_size + len(seq) + self.linker_size 168 | mask_regions[i][beta_offset : min(beta_offset+90, beta_offset+len(mhcB))] = True 169 | else: 170 | mask_regions = True # seqs is just the epitope, so all tokens can be masked 171 | 172 | # masking the sequences for training 173 | seqs, attn_mask_indices = apply_masking_seq(seqs, mask_token='', mask_regions=mask_regions, p=mask_prob) 174 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long) 175 | if len(indices_tensor) > 0: 176 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0. 177 | 178 | # if necessary, masking the linker region 179 | if self.ln_config.include_mhc: 180 | for i, start in attn_starts: 181 | seqs_tokens['attention_mask'][i, start:start+self.linker_size] = 0. 182 | 183 | return seqs_tokens.to(self.device) 184 | 185 | class EpitopeEncoderOneHot(nn.Module): 186 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config, device='cpu'): 187 | super().__init__() 188 | 189 | self.ln_config = ln_cfg 190 | self.projection_dim = projection_dim 191 | 192 | if self.projection_dim: 193 | print("Using single-layer projection head") 194 | self.proj_head = nn.Sequential( 195 | nn.Linear(input_dim, projection_dim), 196 | nn.LayerNorm(projection_dim), 197 | ) 198 | else: 199 | assert False, "Projection head must be used with one-hot encoding!" 
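# The vocabulary defined below covers the 20 standard amino acids plus 'X' for unknown
# residues (21 symbols in total), so `input_dim` for this one-hot encoder is expected to
# equal len(self.amino_acid_to_index); the projection head above consumes the raw one-hot
# channels directly before the length-masked averaging in forward().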
200 | 201 | # Define the amino acid to index mapping 202 | self.amino_acid_to_index = { 203 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 204 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 205 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 206 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, 207 | 'X': 20 # Unknown amino acid 208 | } 209 | 210 | self.device = device 211 | 212 | def forward(self, x, mask): 213 | seqs = x 214 | seqs_onehot = self.create_padded_one_hot_tensor(seqs, len(self.amino_acid_to_index)) 215 | 216 | proj_output = self.proj_head(seqs_onehot) 217 | 218 | # average the projected embeddings by seq length: 219 | seq_lens = torch.sum(seqs_onehot, dim=(1, 2)) 220 | # Create a mask with shape (batch_size, max_seq_length) 221 | seq_mask = torch.arange(proj_output.size(1)).unsqueeze(0).to(self.device) < seq_lens.unsqueeze(-1) 222 | seq_mask = seq_mask.unsqueeze(2) # Shape (batch_size, max_seq_length, 1) 223 | # Sum the embeddings across the sequence length dimension using the mask 224 | masked_embeddings = proj_output * seq_mask 225 | sum_embeddings = masked_embeddings.sum(dim=1) 226 | 227 | # Divide by the true sequence lengths to get the average 228 | avg_embeddings = sum_embeddings / seq_lens.unsqueeze(1)#.to(embeddings.device) 229 | 230 | return avg_embeddings 231 | 232 | 233 | # @staticmethod 234 | def encode_amino_acid_sequence(self, sequence): 235 | """ Convert an amino acid sequence to a list of indices. """ 236 | return [self.amino_acid_to_index[aa] for aa in sequence] 237 | 238 | # @staticmethod 239 | def one_hot_encode_sequence(self, sequence, vocab_size): 240 | """ One-hot encode a single sequence. """ 241 | encoding = np.zeros((len(sequence), vocab_size), dtype=int) 242 | for idx, char in enumerate(sequence): 243 | encoding[idx, char] = 1 244 | return encoding 245 | 246 | # @staticmethod 247 | def pad_sequences(self, encoded_sequences, max_length): 248 | """ Pad the encoded sequences to the maximum length. """ 249 | padded_sequences = [] 250 | for seq in encoded_sequences: 251 | padded_seq = np.pad(seq, ((0, max_length - len(seq)), (0, 0)), mode='constant', constant_values=0) 252 | padded_sequences.append(padded_seq) 253 | return np.array(padded_sequences) 254 | 255 | # @staticmethod 256 | def create_padded_one_hot_tensor(self, sequences, vocab_size): 257 | """ Convert a batch of sequences to a padded one-hot encoding tensor. 
""" 258 | # Encode and one-hot encode each sequence 259 | encoded_sequences = [self.one_hot_encode_sequence(self.encode_amino_acid_sequence(seq), vocab_size) for seq in sequences] 260 | 261 | # Determine the maximum sequence length 262 | max_length = max(len(seq) for seq in sequences) 263 | 264 | # Pad the sequences 265 | padded_sequences = self.pad_sequences(encoded_sequences, max_length) 266 | 267 | # Convert to a PyTorch tensor 268 | padded_tensor = torch.tensor(padded_sequences, dtype=torch.float32) 269 | 270 | return padded_tensor.to(self.device) 271 | 272 | 273 | class TCREncoderTCRBert(nn.Module): 274 | def __init__(self, input_dim, projection_dim, ln_cfg, hidden_dim=1024, device='cpu'): 275 | super().__init__() 276 | from .lora import setup_peft_tcrbert 277 | from .configs import peft_config_tcrbert 278 | 279 | self.tcrbert_tra_lora, self.tcrbert_tra_tokenizer = setup_peft_tcrbert(peft_config_tcrbert, no_lora=ln_cfg.no_lora, regular_ft=ln_cfg.regular_ft) 280 | self.tcrbert_trb_lora, self.tcrbert_trb_tokenizer = setup_peft_tcrbert(peft_config_tcrbert, no_lora=ln_cfg.no_lora, regular_ft=ln_cfg.regular_ft) 281 | 282 | self.ln_config = ln_cfg 283 | 284 | if hidden_dim: 285 | print("Using multi-layer projection head") 286 | self.proj_head = nn.Sequential( 287 | nn.Linear(input_dim, hidden_dim), 288 | nn.LayerNorm(hidden_dim), 289 | nn.LeakyReLU(), 290 | nn.Dropout(p=0.3), 291 | nn.Linear(hidden_dim, projection_dim), 292 | ) 293 | # Initialize the projection head weights 294 | nn.init.kaiming_uniform_(self.proj_head[0].weight) 295 | nn.init.kaiming_uniform_(self.proj_head[-1].weight) 296 | else: 297 | print("Using single-layer projection head") 298 | self.proj_head = nn.Sequential( 299 | nn.Linear(input_dim, projection_dim), 300 | nn.LayerNorm(projection_dim), 301 | ) 302 | # Initialize the projection head weights 303 | nn.init.kaiming_uniform_(self.proj_head[0].weight) 304 | 305 | if self.ln_config.swe_pooling: 306 | self.swe_pooling_a = SWE_Pooling(d_in=input_dim // 2, num_ref_points=256, num_slices=projection_dim // 2) 307 | self.swe_pooling_b = SWE_Pooling(d_in=input_dim // 2, num_ref_points=256, num_slices=projection_dim // 2) 308 | 309 | self.proj_head = nn.Sequential( 310 | nn.Linear(projection_dim, projection_dim // 2), 311 | nn.LayerNorm(projection_dim // 2), 312 | ) 313 | 314 | self.device = device 315 | 316 | def forward(self, x, mask): 317 | tra_tokens, trb_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob) 318 | 319 | # feed to TCRBERT 320 | rescoding_tra = self.tcrbert_tra_lora(**tra_tokens) 321 | rescoding_trb = self.tcrbert_trb_lora(**trb_tokens) 322 | 323 | if self.ln_config.swe_pooling: 324 | # for SWE pooling: 325 | attn_mask_a = get_attention_mask(tra_tokens) 326 | if isinstance(rescoding_tra, dict): 327 | rescoding_tra = rescoding_tra['last_hidden_state'] 328 | else: 329 | rescoding_tra = rescoding_tra.last_hidden_state 330 | tra_outputs = self.swe_pooling_a(rescoding_tra, attn_mask_a) 331 | 332 | attn_mask_b = get_attention_mask(trb_tokens) 333 | if isinstance(rescoding_trb, dict): 334 | rescoding_trb = rescoding_trb['last_hidden_state'] 335 | else: 336 | rescoding_trb = rescoding_trb.last_hidden_state 337 | trb_outputs = self.swe_pooling_b(rescoding_trb, attn_mask_b) 338 | else: 339 | # for regular mean pooling 340 | tra_outputs = get_sequence_embeddings(tra_tokens, rescoding_tra) 341 | trb_outputs = get_sequence_embeddings(trb_tokens, rescoding_trb) 342 | 343 | tcr_embeds = torch.cat((tra_outputs, trb_outputs), dim=-1) 344 | 345 | return 
self.proj_head(tcr_embeds) 346 | 347 | 348 | def process_seqs(self, seqs, mask, mask_prob=0.15): 349 | tra_seqs_, trb_seqs_ = seqs 350 | 351 | # insert spaces between residues for correct formatting: 352 | tra_seqs = [insert_spaces(seq) for seq in tra_seqs_] 353 | trb_seqs = [insert_spaces(seq) for seq in trb_seqs_] 354 | 355 | tra_tokens = self.tcrbert_tra_tokenizer(tra_seqs, return_tensors="pt", padding=True) 356 | trb_tokens = self.tcrbert_trb_tokenizer(trb_seqs, return_tensors="pt", padding=True) 357 | 358 | if mask: 359 | # masking the sequences for training 360 | tra_seqs_, attn_mask_indices = apply_masking_seq(tra_seqs_, p=mask_prob) 361 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long) 362 | if len(indices_tensor) > 0: 363 | tra_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1] + 1] = 0. # +1 to account for the CLS token 364 | 365 | trb_seqs_, attn_mask_indices = apply_masking_seq(trb_seqs_, p=mask_prob) 366 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long) 367 | if len(indices_tensor) > 0: 368 | trb_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1] + 1] = 0. # +1 to account for the CLS token 369 | 370 | # print("TCR Seqs:", tra_seqs_) 371 | # print("seqs_tokens:", tra_tokens['attention_mask']) 372 | 373 | return tra_tokens.to(self.device), trb_tokens.to(self.device) 374 | 375 | 376 | class TCREncoderTCRLang(nn.Module): 377 | def __init__(self, input_dim, projection_dim, ln_cfg, hidden_dim=1024, device='cpu'): 378 | super().__init__() 379 | from .lora import setup_peft_ablang2 380 | from .configs import peft_config_ablang2 381 | 382 | self.ablang2_lora, self.ablang2_tokenizer = setup_peft_ablang2(peft_config_ablang2, receptor_type='TCR', device=device, no_lora=ln_cfg.no_lora) 383 | self.padding_idx = 21 384 | self.mask_token = 23 385 | self.sep_token_id = 25 386 | 387 | self.ln_config = ln_cfg 388 | 389 | if hidden_dim: 390 | print("Using multi-layer projection head") 391 | self.proj_head = nn.Sequential( 392 | nn.Linear(input_dim, hidden_dim), 393 | nn.LayerNorm(hidden_dim), 394 | nn.LeakyReLU(), 395 | nn.Dropout(p=0.5), 396 | nn.Linear(hidden_dim, projection_dim), 397 | ) 398 | else: 399 | print("Using single-layer projection head") 400 | self.proj_head = nn.Sequential( 401 | nn.Linear(input_dim, projection_dim), 402 | nn.LayerNorm(projection_dim), 403 | ) 404 | 405 | if self.ln_config.swe_pooling: 406 | self.swe_pooling = SWE_Pooling(d_in=input_dim, num_ref_points=512, num_slices=projection_dim) 407 | 408 | self.proj_head = nn.Sequential( 409 | nn.Linear(projection_dim, projection_dim // 2), 410 | nn.LayerNorm(projection_dim // 2), 411 | ) 412 | 413 | self.device = device 414 | 415 | def forward(self, x, mask): 416 | seq_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob) 417 | 418 | # print("seq tokens:", seq_tokens) 419 | 420 | # feed to TCRLang 421 | rescoding = self.ablang2_lora(seq_tokens) 422 | 423 | # process TCRLang outputs 424 | seq_inputs = {'attention_mask': ~((seq_tokens == self.padding_idx) | (seq_tokens == self.mask_token))} 425 | model_output = {'last_hidden_state': rescoding.last_hidden_states} 426 | 427 | if self.ln_config.swe_pooling: 428 | # for SWE pooling: 429 | attn_mask = seq_inputs['attention_mask'] 430 | model_embed = model_output['last_hidden_state'] 431 | 432 | seq_outputs = self.swe_pooling(model_embed, attn_mask) 433 | 434 | else: 435 | # for regular mean pooling 436 | seq_outputs = get_sequence_embeddings(seq_inputs, model_output, is_sep=False, 
is_cls=False) 437 | 438 | return self.proj_head(seq_outputs) 439 | 440 | def process_seqs(self, seqs, mask, mask_prob=0.15): 441 | H_seqs, L_seqs = seqs 442 | 443 | # format the seq strings according to the TCRLang convention (the B chain comes first, so we swap the H and L order): 444 | ab_seqs = [f"{L_seqs[i]}|{H_seqs[i]}" for i in range(len(H_seqs))] 445 | 446 | seqs_tokens = self.ablang2_tokenizer(ab_seqs, pad=True, w_extra_tkns=False, device=self.device) 447 | 448 | if mask: 449 | # masking the sequences for training 450 | ab_seqs, attn_mask_indices = apply_masking_seq(ab_seqs, mask_token='*', p=mask_prob) 451 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long) 452 | if len(indices_tensor) > 0: 453 | seqs_tokens[indices_tensor[:, 0], indices_tensor[:, 1]] = self.mask_token 454 | 455 | # leave the SEP tokens ('|') unmasked!! 456 | for i, l_seq in enumerate(L_seqs): 457 | seqs_tokens[i, len(l_seq)] = self.sep_token_id 458 | 459 | return seqs_tokens 460 | 461 | class TCREncoderESM(nn.Module): 462 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config=None, hidden_dim=1024, device='cpu'): 463 | super().__init__() 464 | from .lora import setup_peft_esm2 465 | from .configs import peft_config_esm2 466 | 467 | self.ln_config = ln_cfg 468 | self.model_config = model_config 469 | self.projection_dim = projection_dim 470 | 471 | if self.model_config.receptor_model_name == 'esm3': 472 | from .lora import setup_peft_esm3 473 | from .configs import peft_config_esm3 474 | 475 | # load the LoRA adapted ESM-3 Model here: 476 | self.esm_lora, self.esm_tokenizer = setup_peft_esm3(peft_config_esm3, ln_cfg.no_lora) 477 | else: 478 | from .lora import setup_peft_esm2 479 | from .configs import peft_config_esm2 480 | 481 | # load the LoRA adapted ESM-2 Model here: 482 | self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora, ln_cfg.regular_ft) 483 | 484 | if self.projection_dim: 485 | if hidden_dim: 486 | print("Using multi-layer projection head") 487 | self.proj_head = nn.Sequential( 488 | nn.Linear(input_dim, hidden_dim), 489 | nn.LayerNorm(hidden_dim), 490 | nn.LeakyReLU(), 491 | nn.Dropout(p=0.5), 492 | nn.Linear(hidden_dim, projection_dim), 493 | ) 494 | else: 495 | print("Using single-layer projection head") 496 | self.proj_head = nn.Sequential( 497 | nn.Linear(input_dim, projection_dim), 498 | nn.LayerNorm(projection_dim), 499 | ) 500 | else: 501 | print("NOT using projection head") 502 | 503 | # for ESM-2, we need a linker to represent multimers 504 | self.linker_size = 25 505 | self.gly_linker = 'G'*self.linker_size 506 | self.gly_idx = 6 # according to: https://huggingface.co/facebook/esm2_t33_650M_UR50D/blob/main/vocab.txt 507 | 508 | self.device = device 509 | 510 | def forward(self, x, mask): 511 | seqs_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob) 512 | 513 | if self.model_config.receptor_model_name == 'esm3': 514 | outputs = self.esm_lora(sequence_tokens=seqs_tokens['input_ids']).embeddings 515 | outputs = {'last_hidden_state': outputs} 516 | else: 517 | outputs = self.esm_lora(**seqs_tokens) 518 | 519 | seq_embeds = get_sequence_embeddings(seqs_tokens, outputs, is_sep=False, is_cls=False) 520 | 521 | if self.projection_dim: 522 | return self.proj_head(seq_embeds) 523 | else: 524 | return seq_embeds 525 | 526 | def process_seqs(self, seqs, mask, mask_prob=0.15): 527 | ''' 528 | seqs: tuple of (TRA, TRB) lists of TCR alpha- and beta-chain sequences 529 | ''' 530 | tra_seqs, trb_seqs = seqs 531 | seqs = [f"{tra_seqs[i]}{self.gly_linker}{trb_seqs[i]}" for i in 
range(len(tra_seqs))] 532 | mask_regions = [[True]*len(seqa)+[False]*self.linker_size+[True]*len(seqb) for seqa, seqb in zip(tra_seqs, trb_seqs)] 533 | 534 | # marking where the Glycine linker starts 535 | attn_starts = [len(alpha_chain) for alpha_chain in tra_seqs] 536 | 537 | # removing special tokens since the linker-joined TCR chains are treated as a single protein fragment 538 | seqs_tokens = self.esm_tokenizer(seqs, return_tensors="pt", add_special_tokens=False, padding=True) 539 | 540 | if mask: 541 | # masking the sequences for training 542 | seqs, attn_mask_indices = apply_masking_seq(seqs, mask_token='', mask_regions=mask_regions, p=mask_prob) 543 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long) 544 | if len(indices_tensor) > 0: 545 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0. 546 | 547 | # remove mask tokens on linker region: 548 | for i in range(len(attn_starts)): 549 | seqs_tokens['input_ids'][i, attn_starts[i]:attn_starts[i]+self.linker_size] = self.gly_idx 550 | 551 | 552 | # attention masking the linker region 553 | for i in range(len(attn_starts)): 554 | seqs_tokens['attention_mask'][i, attn_starts[i]:attn_starts[i]+self.linker_size] = 0. 555 | 556 | # print("TCR Seqs:", seqs) 557 | # print("seqs_tokens:", seqs_tokens['attention_mask']) 558 | 559 | return seqs_tokens.to(self.device) 560 | 561 | 562 | class TCREncoderESMBetaOnly(nn.Module): 563 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config=None, hidden_dim=1024, device='cpu'): 564 | super().__init__() 565 | from .lora import setup_peft_esm2 566 | from .configs import peft_config_esm2 567 | 568 | # # load the LoRA adapted ESM Model here: 569 | # self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora) 570 | 571 | self.ln_config = ln_cfg 572 | self.model_config = model_config 573 | self.projection_dim = projection_dim 574 | 575 | if self.model_config.receptor_model_name == 'esm3': 576 | from .lora import setup_peft_esm3 577 | from .configs import peft_config_esm3 578 | 579 | # load the LoRA adapted ESM-3 Model here: 580 | self.esm_lora, self.esm_tokenizer = setup_peft_esm3(peft_config_esm3, ln_cfg.no_lora) 581 | else: 582 | from .lora import setup_peft_esm2 583 | from .configs import peft_config_esm2 584 | 585 | # load the LoRA adapted ESM-2 Model here: 586 | self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora) 587 | 588 | if self.projection_dim: 589 | if hidden_dim: 590 | print("Using multi-layer projection head") 591 | self.proj_head = nn.Sequential( 592 | nn.Linear(input_dim, hidden_dim), 593 | nn.LayerNorm(hidden_dim), 594 | nn.LeakyReLU(), 595 | nn.Dropout(p=0.5), 596 | nn.Linear(hidden_dim, projection_dim), 597 | ) 598 | else: 599 | print("Using single-layer projection head") 600 | self.proj_head = nn.Sequential( 601 | nn.Linear(input_dim, projection_dim), 602 | nn.LayerNorm(projection_dim), 603 | ) 604 | else: 605 | print("NOT using projection head") 606 | 607 | self.device = device 608 | 609 | def forward(self, x, mask): 610 | seqs_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob) 611 | 612 | if self.model_config.receptor_model_name == 'esm3': 613 | outputs = self.esm_lora(sequence_tokens=seqs_tokens['input_ids']).embeddings 614 | outputs = {'last_hidden_state': outputs} 615 | else: 616 | outputs = self.esm_lora(**seqs_tokens) 617 | 618 | seq_embeds = get_sequence_embeddings(seqs_tokens, outputs, is_sep=False, is_cls=False) 619 | 620 | if self.projection_dim: 621 | return 
self.proj_head(seq_embeds) 622 | else: 623 | return seq_embeds 624 | 625 | def process_seqs(self, seqs, mask, mask_prob=0.15): 626 | ''' 627 | seqs: tuple of (TRA, TRB) sequence lists; only the beta (TRB) chains are used 628 | ''' 629 | tra_seqs, trb_seqs = seqs 630 | seqs = trb_seqs 631 | 632 | # removing special tokens since the TCR beta-chain sequences are treated as protein fragments 633 | seqs_tokens = self.esm_tokenizer(seqs, return_tensors="pt", add_special_tokens=False, padding=True) 634 | 635 | if mask: 636 | # masking the sequences for training 637 | seqs, attn_mask_indices = apply_masking_seq(seqs, mask_token='', p=mask_prob) 638 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long) 639 | if len(indices_tensor) > 0: 640 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0. 641 | 642 | return seqs_tokens.to(self.device) 643 | 644 | 645 | class TCREncoderInHouse(nn.Module): 646 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config=None, hidden_dim=1024, device='cpu'): 647 | super().__init__() 648 | from .lora import setup_peft_inhouse 649 | from .configs import peft_config_inhouse 650 | import os 651 | 652 | model_ckpt_path = os.getenv('INHOUSE_MODEL_CKPT_PATH') 653 | 654 | self.inhouse_lora, self.inhouse_tokenizer = setup_peft_inhouse(peft_config_inhouse, ln_cfg.no_lora, model_ckpt_path=model_ckpt_path) 655 | 656 | self.ln_config = ln_cfg 657 | self.model_config = model_config 658 | self.projection_dim = projection_dim 659 | 660 | if self.projection_dim: 661 | if hidden_dim: 662 | print("Using multi-layer projection head") 663 | self.proj_head = nn.Sequential( 664 | nn.Linear(input_dim, hidden_dim), 665 | nn.LayerNorm(hidden_dim), 666 | nn.LeakyReLU(), 667 | nn.Dropout(p=0.5), 668 | nn.Linear(hidden_dim, projection_dim), 669 | ) 670 | else: 671 | print("Using single-layer projection head") 672 | self.proj_head = nn.Sequential( 673 | nn.Linear(input_dim, projection_dim), 674 | nn.LayerNorm(projection_dim), 675 | ) 676 | else: 677 | print("NOT using projection head") 678 | 679 | self.device = device 680 | 681 | def forward(self, x, mask): 682 | seq_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob) 683 | 684 | # feed to InHouse Model 685 | # print("seq tokens input_ids:", seq_tokens['input_ids']) 686 | # print("seq tokens attention_mask:", seq_tokens['attention_mask']) 687 | seq_outputs, _ = self.inhouse_lora(seq_tokens["input_ids"], seq_tokens["attention_mask"]) 688 | 689 | # print("seq outputs:", seq_outputs) 690 | 691 | return self.proj_head(seq_outputs) 692 | 693 | def process_seqs(self, seqs, mask, mask_prob=0.15): 694 | tra_seqs, trb_seqs = seqs 695 | 696 | if mask: 697 | tra_seqs_, tra_masks = apply_masking_seq(tra_seqs, mask_token='', p=mask_prob) 698 | trb_seqs_, trb_masks = apply_masking_seq(trb_seqs, mask_token='', p=mask_prob) 699 | 700 | # adjust the tra_masks and trb_masks to the correct indices: 701 | tra_masks = [(n, 1+i) for (n, i) in tra_masks] 702 | trb_masks = [(n, 1+len(tra_seqs[n])+2+i) for (n, i) in trb_masks] 703 | 704 | indices_tensor = torch.tensor(tra_masks + trb_masks, dtype=torch.long) 705 | 706 | tra_seqs, trb_seqs = tra_seqs_, trb_seqs_ 707 | 708 | # format the seq strings according to the InHouse model's expected input: 709 | ab_seqs = [self.apply_special_token_formatting(tra_seqs[i], trb_seqs[i]) for i in range(len(tra_seqs))] 710 | 711 | seqs_tokens = self.inhouse_tokenizer(ab_seqs, return_tensors="pt", add_special_tokens=False, padding=True) 712 | 713 | # adjust the attention mask 714 | if mask and len(indices_tensor) > 0: 715 | 
seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0. 716 | 717 | return seqs_tokens.to(self.device) 718 | 719 | def apply_special_token_formatting(self, alpha, beta): 720 | ''' 721 | Apply RoBERTa-style formatting to the input: 722 | <s>seq1</s></s>seq2</s> 723 | ''' 724 | return f"{self.inhouse_tokenizer.cls_token}{alpha}{self.inhouse_tokenizer.eos_token}{self.inhouse_tokenizer.eos_token}{beta}{self.inhouse_tokenizer.eos_token}" 725 | 726 | class TCREncoderOneHot(nn.Module): 727 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config, device='cpu'): 728 | super().__init__() 729 | 730 | self.ln_config = ln_cfg 731 | self.projection_dim = projection_dim 732 | 733 | if self.projection_dim: 734 | print("Using single-layer projection head") 735 | self.proj_head = nn.Sequential( 736 | nn.Linear(input_dim, projection_dim), 737 | nn.LayerNorm(projection_dim), 738 | ) 739 | else: 740 | assert False, "Projection head must be used with one-hot encoding!" 741 | 742 | # Define the amino acid to index mapping 743 | self.amino_acid_to_index = { 744 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 745 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 746 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 747 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, 748 | 'X': 20 # Unknown amino acid 749 | } 750 | 751 | self.device = device 752 | 753 | def forward(self, x, mask): 754 | seqs = [seqa + seqb for seqa, seqb in zip(x[0], x[1])] 755 | seqs_onehot = self.create_padded_one_hot_tensor(seqs, len(self.amino_acid_to_index)) 756 | 757 | proj_output = self.proj_head(seqs_onehot) 758 | 759 | # average the projected embeddings by seq length: 760 | seq_lens = torch.sum(seqs_onehot, dim=(1, 2)) 761 | # Create a mask with shape (batch_size, max_seq_length) 762 | seq_mask = torch.arange(proj_output.size(1)).unsqueeze(0).to(self.device) < seq_lens.unsqueeze(-1) 763 | seq_mask = seq_mask.unsqueeze(2) # Shape (batch_size, max_seq_length, 1) 764 | # Sum the embeddings across the sequence length dimension using the mask 765 | masked_embeddings = proj_output * seq_mask 766 | sum_embeddings = masked_embeddings.sum(dim=1) 767 | 768 | # Divide by the true sequence lengths to get the average 769 | avg_embeddings = sum_embeddings / seq_lens.unsqueeze(1)#.to(embeddings.device) 770 | 771 | return avg_embeddings 772 | 773 | 774 | # @staticmethod 775 | def encode_amino_acid_sequence(self, sequence): 776 | """ Convert an amino acid sequence to a list of indices. """ 777 | return [self.amino_acid_to_index[aa] for aa in sequence] 778 | 779 | # @staticmethod 780 | def one_hot_encode_sequence(self, sequence, vocab_size): 781 | """ One-hot encode a single sequence. """ 782 | encoding = np.zeros((len(sequence), vocab_size), dtype=int) 783 | for idx, char in enumerate(sequence): 784 | encoding[idx, char] = 1 785 | return encoding 786 | 787 | # @staticmethod 788 | def pad_sequences(self, encoded_sequences, max_length): 789 | """ Pad the encoded sequences to the maximum length. """ 790 | padded_sequences = [] 791 | for seq in encoded_sequences: 792 | padded_seq = np.pad(seq, ((0, max_length - len(seq)), (0, 0)), mode='constant', constant_values=0) 793 | padded_sequences.append(padded_seq) 794 | return np.array(padded_sequences) 795 | 796 | # @staticmethod 797 | def create_padded_one_hot_tensor(self, sequences, vocab_size): 798 | """ Convert a batch of sequences to a padded one-hot encoding tensor. 
""" 799 | # Encode and one-hot encode each sequence 800 | encoded_sequences = [self.one_hot_encode_sequence(self.encode_amino_acid_sequence(seq), vocab_size) for seq in sequences] 801 | 802 | # Determine the maximum sequence length 803 | max_length = max(len(seq) for seq in sequences) 804 | 805 | # Pad the sequences 806 | padded_sequences = self.pad_sequences(encoded_sequences, max_length) 807 | 808 | # Convert to a PyTorch tensor 809 | padded_tensor = torch.tensor(padded_sequences, dtype=torch.float32) 810 | 811 | return padded_tensor.to(self.device) --------------------------------------------------------------------------------