├── .gitignore
├── figs
│   └── architecture_v2.PNG
├── __init__.py
├── src
│   ├── __init__.py
│   ├── logging.py
│   ├── lr_scheduler.py
│   ├── models_ab.py
│   ├── configs.py
│   ├── lora.py
│   ├── training.py
│   ├── swe_pooling.py
│   ├── alignment.py
│   ├── data_module.py
│   ├── utils.py
│   └── models.py
├── setup.py
├── LICENSE
├── README.md
└── environment.yml
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
--------------------------------------------------------------------------------
/figs/architecture_v2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kundajelab/ImmuneCLIP/HEAD/figs/architecture_v2.PNG
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # This file can be empty, or you can use it to define what gets imported
2 | # when you do `from project import *`
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # This file can be empty, or you can use it to define what gets imported
2 | # when you do `from project import *`
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='ImmunoAlign',
5 | version='0.0.1',
6 | packages=find_packages(),
7 | install_requires=[
8 | 'torch',
9 | 'pytorch-lightning',
10 | # add other dependencies here
11 | ],
12 | )
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Kundaje Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/logging.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
4 | from typing import List
5 | from argparse import Namespace
6 | from dataclasses import is_dataclass
7 |
8 | def call_backs(output_dir, save_ckpts=False):
9 | lrmonitor_callback = LearningRateMonitor(logging_interval='step')
10 |
11 | # checkpoint callback: keep the top-2 checkpoints by validation loss
12 | checkpoint_callback = ModelCheckpoint(
13 | monitor='val_loss',
14 | dirpath=os.path.join(output_dir, 'checkpoints'),
15 | filename="model-{epoch:03d}-{val_loss:.4f}",
16 | save_top_k=2,
17 | mode='min',
18 | save_last=True,
19 | )
20 |
21 | if save_ckpts:
22 | callbacks = [checkpoint_callback, lrmonitor_callback]
23 | else:
24 | callbacks = [lrmonitor_callback]
25 |
26 | return callbacks
27 |
28 | def combine_args_and_configs(args: Namespace, dataclasses: List):
29 | if not isinstance(args, dict):
30 | args = vars(args).items()
31 | else:
32 | args = args.items()
33 | for name, value in args:
34 | if value is not None:
35 | for obj in dataclasses:
36 | if is_dataclass(obj) and hasattr(obj, name):
37 | print("overwriting default", name, value)
38 | setattr(obj, name, value)
--------------------------------------------------------------------------------
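
For reference, here is a minimal sketch (not part of the repository) of how `combine_args_and_configs` above merges parsed CLI arguments into dataclass configs; the `DemoConfig` dataclass and the argument values are hypothetical.

```
from argparse import Namespace
from dataclasses import dataclass

from src.logging import combine_args_and_configs


@dataclass
class DemoConfig:
    lr: float = 1e-4
    batch_size: int = 4


cfg = DemoConfig()
# None values are skipped, so batch_size keeps its default
args = Namespace(lr=1e-3, batch_size=None)

combine_args_and_configs(args, [cfg])
print(cfg.lr, cfg.batch_size)  # 0.001 4
```
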
/README.md:
--------------------------------------------------------------------------------
1 | # ImmuneCLIP
2 | Code for paper "Sequence-based TCR-Peptide Representations Using Cross-Epitope Contrastive Fine-tuning of Protein Language Models." [[paper](https://link.springer.com/chapter/10.1007/978-3-031-90252-9_3)] (_RECOMB 2025_)
3 |
4 | 🚧 This repository is under active construction. 🚧
5 |
6 | ## Model Components:
7 |
8 |
9 | * Epitope Encoder
10 | * PEFT-adapted Protein Language Models (e.g. ESM-2, ESM-3)
11 | * Default: Using LoRA (rank = 8) on last 8 transformer layers
12 | * Projection layer: FC linear layer (dim $d_{e} \rightarrow d_p$)
13 | * $d_{e}$ is the original PLM dimension, $d_p$ is the projection dimension
14 |
15 | * Receptor Encoder
16 | * PEFT-adapted Protein Language Models (e.g. ESM-2, ESM-3) **or** BCR/TCR Language Models (e.g. AbLang, TCR-BERT, etc.)
17 | * Default: Using LoRA (rank = 8) on last 4 transformer layers
18 | * Projection layer: FC linear layer (dim $d_{r} \rightarrow d_p$)
19 | * $d_{r}$ is the original receptor LM dimension, $d_p$ is the projection dimension
20 |
21 | ## Dataset:
22 | * MixTCRpred Dataset ([repo](https://github.com/GfellerLab/MixTCRpred/tree/main))
23 | * Contains curated mixture of TCR-pMHC sequence data from IEDB, VDJdb, 10x Genomics, and McPAS-TCR
24 |
25 | ## Pre-trained Weights:
26 | * The pre-trained weights for ImmuneCLIP are deposited at [Zenodo](https://zenodo.org/records/14962685)
27 |
28 | ## CLI:
29 | ### Environment Variables
30 | To run this application, set the following environment variable:
31 | ```
32 | WANDB_OUTPUT_DIR=
33 | ```
34 |
35 | Additionally, if training on top of a custom in-house TCR model, the following path must also be set:
36 | ```
37 | INHOUSE_MODEL_CKPT_PATH=
38 | ```
39 |
40 |
41 | ### Training
42 | ```
43 | # go to root directory of the repo, and then run:
44 | python -m src.training --run-id [RUN_ID] --dataset-path [PATH_TO_DATASET] --stage fit --max-epochs 100 \
45 | --receptor-model-name [esm2|tcrlang|tcrbert] --projection-dim 512 --gpus-used [GPU_IDX] --lr 1e-3 \
46 | --batch-size 8 --output-dir [CHECKPOINTS_OUTPUT_DIR] [--mask-seqs]
47 | ```
48 |
49 | ### Evaluation
50 | ```
51 | # currently, running the model in the test stage embeds the test-set epitope/receptor pairs with the fine-tuned model and saves them.
52 | python -m src.training --run-id [RUN_ID] --dataset-path [PATH_TO_DATASET] --stage test --from-checkpoint [CHECKPOINT_PATH] \
53 | --projection-dim 512 --receptor-model-name [esm2|tcrlang|tcrbert] --gpus-used [GPU_IDX] --batch-size 8 \
54 | --save-embed-path [PATH_FOR_SAVING_EMBEDS]
55 | ```
56 |
--------------------------------------------------------------------------------
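
As a rough illustration of the architecture described in the README's Model Components section, the snippet below builds the two linear projection heads and a temperature-scaled similarity matrix. This is a conceptual sketch, not the repository's implementation; the dimensions are example values consistent with `src/configs.py`, and the random tensors stand in for pooled PLM embeddings.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

d_e, d_r, d_p = 1280, 1536, 512  # example dims; actual values depend on the chosen PLMs

epitope_proj = nn.Linear(d_e, d_p)
receptor_proj = nn.Linear(d_r, d_p)

epitope_emb = torch.randn(8, d_e)    # stand-in for pooled epitope PLM embeddings
receptor_emb = torch.randn(8, d_r)   # stand-in for pooled receptor LM embeddings

z_e = F.normalize(epitope_proj(epitope_emb), dim=-1)
z_r = F.normalize(receptor_proj(receptor_emb), dim=-1)

temperature = 0.07
logits = z_e @ z_r.t() / temperature  # 8 x 8 pairwise similarity matrix
```
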
/src/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 | class CosineAnnealingWarmUpRestarts(_LRScheduler):
5 | def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
6 | if T_0 <= 0 or not isinstance(T_0, int):
7 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
8 | if T_mult < 1 or not isinstance(T_mult, int):
9 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
10 | if T_up < 0 or not isinstance(T_up, int):
11 | raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
12 | self.T_0 = T_0
13 | self.T_mult = T_mult
14 | self.base_eta_max = eta_max
15 | self.eta_max = eta_max
16 | self.T_up = T_up
17 | self.T_i = T_0
18 | self.gamma = gamma
19 | self.cycle = 0
20 | self.T_cur = last_epoch
21 | super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
22 |
23 | def get_lr(self):
24 | if self.T_cur == -1:
25 | return self.base_lrs
26 | elif self.T_cur < self.T_up:
27 | return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
28 | else:
29 | return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
30 | for base_lr in self.base_lrs]
31 |
32 | def step(self, epoch=None):
33 | if epoch is None:
34 | epoch = self.last_epoch + 1
35 | self.T_cur = self.T_cur + 1
36 | if self.T_cur >= self.T_i:
37 | self.cycle += 1
38 | self.T_cur = self.T_cur - self.T_i
39 | self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
40 | else:
41 | if epoch >= self.T_0:
42 | if self.T_mult == 1:
43 | self.T_cur = epoch % self.T_0
44 | self.cycle = epoch // self.T_0
45 | else:
46 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
47 | self.cycle = n
48 | self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
49 | self.T_i = self.T_0 * self.T_mult ** (n)
50 | else:
51 | self.T_i = self.T_0
52 | self.T_cur = epoch
53 |
54 | self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
55 | self.last_epoch = math.floor(epoch)
56 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
57 | param_group['lr'] = lr
--------------------------------------------------------------------------------
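
A minimal usage sketch for `CosineAnnealingWarmUpRestarts` above, assuming the scheduler is stepped once per optimizer step; the model and hyperparameter values are illustrative, not the repository's defaults.

```
import torch

from src.lr_scheduler import CosineAnnealingWarmUpRestarts

model = torch.nn.Linear(4, 2)
# start near zero; the scheduler warms the lr up to eta_max over T_up steps,
# then cosine-anneals it back toward the base lr over the rest of each T_0-step cycle
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8)
scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=100, T_mult=1, eta_max=1e-3, T_up=10, gamma=0.5)

for step in range(300):
    optimizer.step()
    scheduler.step()
```
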
/environment.yml:
--------------------------------------------------------------------------------
1 | name: immuneclip
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1
6 | - _openmp_mutex=5.1
7 | - bzip2=1.0.8
8 | - ca-certificates=2024.3.11
9 | - ld_impl_linux-64=2.38
10 | - libffi=3.4.4
11 | - libgcc-ng=11.2.0
12 | - libgomp=11.2.0
13 | - libstdcxx-ng=11.2.0
14 | - libuuid=1.41.5
15 | - ncurses=6.4
16 | - openssl=3.0.14
17 | - pip=24.0
18 | - python=3.10.14
19 | - readline=8.2
20 | - setuptools=69.5.1
21 | - sqlite=3.45.3
22 | - tk=8.6.14
23 | - wheel=0.43.0
24 | - xz=5.4.6
25 | - zlib=1.2.13
26 | - pip:
27 | - git+https://github.com/oxpig/AbLang.git
28 | - git+https://github.com/oxpig/AbLang2.git
29 | - accelerate==0.32.1
30 | - aiohttp==3.9.5
31 | - aiosignal==1.3.1
32 | - anyio==4.4.0
33 | - argon2-cffi==23.1.0
34 | - argon2-cffi-bindings==21.2.0
35 | - arrow==1.3.0
36 | - asttokens==2.4.1
37 | - async-lru==2.0.4
38 | - async-timeout==4.0.3
39 | - attrs==23.2.0
40 | - babel==2.15.0
41 | - beautifulsoup4==4.12.3
42 | - biopython==1.84
43 | - biotite==0.41.2
44 | - bleach==6.1.0
45 | - brotli==1.1.0
46 | - certifi==2024.7.4
47 | - cffi==1.16.0
48 | - charset-normalizer==3.3.2
49 | - click==8.1.7
50 | - comm==0.2.2
51 | - debugpy==1.8.2
52 | - decorator==5.1.1
53 | - defusedxml==0.7.1
54 | - docker-pycreds==0.4.0
55 | - einops==0.8.0
56 | - esm==3.0.0
57 | - exceptiongroup==1.2.1
58 | - executing==2.0.1
59 | - fastjsonschema==2.20.0
60 | - filelock==3.15.4
61 | - fqdn==1.5.1
62 | - frozenlist==1.4.1
63 | - fsspec==2024.6.1
64 | - gitdb==4.0.11
65 | - gitpython==3.1.43
66 | - h11==0.14.0
67 | - httpcore==1.0.5
68 | - httpx==0.27.0
69 | - huggingface-hub==0.23.4
70 | - idna==3.7
71 | - ipykernel==6.29.5
72 | - ipython==8.26.0
73 | - ipywidgets==8.1.3
74 | - isoduration==20.11.0
75 | - jedi==0.19.1
76 | - jinja2==3.1.4
77 | - joblib==1.4.2
78 | - json5==0.9.25
79 | - jsonpointer==3.0.0
80 | - jsonschema==4.23.0
81 | - jsonschema-specifications==2023.12.1
82 | - jupyter-client==8.6.2
83 | - jupyter-core==5.7.2
84 | - jupyter-events==0.10.0
85 | - jupyter-lsp==2.2.5
86 | - jupyter-server==2.14.1
87 | - jupyter-server-terminals==0.5.3
88 | - jupyterlab==4.2.3
89 | - jupyterlab-pygments==0.3.0
90 | - jupyterlab-server==2.27.2
91 | - jupyterlab-widgets==3.0.11
92 | - lightning==2.3.3
93 | - lightning-utilities==0.11.3.post0
94 | - llvmlite==0.43.0
95 | - markupsafe==2.1.5
96 | - matplotlib-inline==0.1.7
97 | - mistune==3.0.2
98 | - mpmath==1.3.0
99 | - msgpack==1.0.8
100 | - msgpack-numpy==0.4.8
101 | - multidict==6.0.5
102 | - nbclient==0.10.0
103 | - nbconvert==7.16.4
104 | - nbformat==5.10.4
105 | - nest-asyncio==1.6.0
106 | - networkx==3.3
107 | - notebook-shim==0.2.4
108 | - numba==0.60.0
109 | - numpy==1.26.4
110 | - nvidia-cublas-cu12==12.1.3.1
111 | - nvidia-cuda-cupti-cu12==12.1.105
112 | - nvidia-cuda-nvrtc-cu12==12.1.105
113 | - nvidia-cuda-runtime-cu12==12.1.105
114 | - nvidia-cudnn-cu12==8.9.2.26
115 | - nvidia-cufft-cu12==11.0.2.54
116 | - nvidia-curand-cu12==10.3.2.106
117 | - nvidia-cusolver-cu12==11.4.5.107
118 | - nvidia-cusparse-cu12==12.1.0.106
119 | - nvidia-nccl-cu12==2.20.5
120 | - nvidia-nvjitlink-cu12==12.5.82
121 | - nvidia-nvtx-cu12==12.1.105
122 | - overrides==7.7.0
123 | - packaging==24.1
124 | - pandas==2.2.2
125 | - pandocfilters==1.5.1
126 | - parso==0.8.4
127 | - peft==0.11.1
128 | - pexpect==4.9.0
129 | - pillow==10.4.0
130 | - platformdirs==4.2.2
131 | - prometheus-client==0.20.0
132 | - prompt-toolkit==3.0.47
133 | - protobuf==5.27.2
134 | - psutil==6.0.0
135 | - ptyprocess==0.7.0
136 | - pure-eval==0.2.2
137 | - pycparser==2.22
138 | - pygments==2.18.0
139 | - python-dateutil==2.9.0.post0
140 | - python-json-logger==2.0.7
141 | - pytorch-lightning==2.3.3
142 | - pytz==2024.1
143 | - pyyaml==6.0.1
144 | - pyzmq==26.0.3
145 | - referencing==0.35.1
146 | - regex==2024.5.15
147 | - requests==2.32.3
148 | - rfc3339-validator==0.1.4
149 | - rfc3986-validator==0.1.1
150 | - rotary-embedding-torch==0.6.4
151 | - rpds-py==0.19.0
152 | - safetensors==0.4.3
153 | - scikit-learn==1.5.1
154 | - scipy==1.14.0
155 | - send2trash==1.8.3
156 | - sentry-sdk==2.8.0
157 | - setproctitle==1.3.3
158 | - six==1.16.0
159 | - smmap==5.0.1
160 | - sniffio==1.3.1
161 | - soupsieve==2.5
162 | - stack-data==0.6.3
163 | - sympy==1.13.0
164 | - terminado==0.18.1
165 | - threadpoolctl==3.5.0
166 | - tinycss2==1.3.0
167 | - tokenizers==0.19.1
168 | - tomli==2.0.1
169 | - torch==2.3.1
170 | - torchmetrics==1.4.0.post0
171 | - torchtext==0.18.0
172 | - torchvision==0.18.1
173 | - tornado==6.4.1
174 | - tqdm==4.66.4
175 | - traitlets==5.14.3
176 | - transformers==4.42.3
177 | - triton==2.3.1
178 | - types-python-dateutil==2.9.0.20240316
179 | - typing-extensions==4.12.2
180 | - tzdata==2024.1
181 | - uri-template==1.3.0
182 | - urllib3==2.2.2
183 | - wandb==0.17.4
184 | - wcwidth==0.2.13
185 | - webcolors==24.6.0
186 | - webencodings==0.5.1
187 | - websocket-client==1.8.0
188 | - widgetsnbextension==4.0.11
189 | - yarl==1.9.4
190 |
--------------------------------------------------------------------------------
/src/models_ab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import re
5 |
6 | from .utils import get_sequence_embeddings, insert_spaces, get_attention_mask, apply_masking_seq
7 | from .swe_pooling import SWE_Pooling
8 |
9 | class AntibodyEncoderAbLang(nn.Module):
10 | def __init__(self, input_dim, projection_dim, ln_cfg, device='cpu'):
11 | super().__init__()
12 | from .lora import setup_peft_ablang
13 | from .configs import peft_config_ablang
14 |
15 | # load the LoRA adapted AbLang HL Models here:
16 | self.ablang_H_lora, self.ablang_H_tokenizer = setup_peft_ablang(peft_config_ablang, chain='H')
17 | self.ablang_L_lora, self.ablang_L_tokenizer = setup_peft_ablang(peft_config_ablang, chain='L')
18 |
19 | self.proj_head = nn.Sequential(
20 | nn.Linear(input_dim, projection_dim),
21 | nn.LayerNorm(projection_dim),
22 | )
23 |
24 | self.device = device
25 |
26 | def forward(self, x):
27 | H_seqs, L_seqs = x
28 | H_seqs_tokens = self.process_seqs(H_seqs, chain='H')
29 | L_seqs_tokens = self.process_seqs(L_seqs, chain='L')
30 |
31 | try:
32 | H_outputs = self.ablang_H_lora(**H_seqs_tokens)
33 | except:
34 | print("Error in feeding H sequences")
35 |
36 | print("H seq:", H_seqs)
37 | print("H seq tokens max:", torch.max(H_seqs_tokens['input_ids']))
38 | print("self.ablang_H_lora: ", self.ablang_H_lora)
39 |
40 | raise ValueError
41 |
42 | try:
43 | L_outputs = self.ablang_L_lora(**L_seqs_tokens)
44 | except:
45 | print("Error in feeding L sequences")
46 |
47 | print("L seq:", L_seqs)
48 | print("L seq tokens max:", torch.max(L_seqs_tokens['input_ids']))
49 | print("self.ablang_L_lora: ", self.ablang_L_lora)
50 | raise ValueError
51 |
52 | H_outputs = get_sequence_embeddings(H_seqs_tokens, H_outputs)
53 | L_outputs = get_sequence_embeddings(L_seqs_tokens, L_outputs)
54 |
55 | Ab_seq_embeds = torch.cat((H_outputs, L_outputs), dim=-1)
56 |
57 | return self.proj_head(Ab_seq_embeds)
58 |
59 | def process_seqs(self, seqs, chain):
60 | '''
61 | seqs: tuple of sequences
62 | '''
63 |
64 | # format the seq strings according to AbLang:
65 | seqs = [insert_spaces(seq) for seq in seqs]
66 | # seqs = [' '.join(seq) for seq in seqs]
67 |
68 | if chain == 'H':
69 | seqs_tokens = self.ablang_H_tokenizer(seqs, return_tensors="pt", padding=True)
70 | else:
71 | seqs_tokens = self.ablang_L_tokenizer(seqs, return_tensors="pt", padding=True)
72 |
73 | return seqs_tokens.to(self.device)
74 |
75 |
76 | class AntibodyEncoderAbLang2(nn.Module):
77 | def __init__(self, input_dim, projection_dim, ln_cfg, device='cpu'):
78 | super().__init__()
79 | from .lora import setup_peft_ablang2
80 | from .configs import peft_config_ablang2
81 |
82 | self.ablang2_lora, self.ablang2_tokenizer = setup_peft_ablang2(peft_config_ablang2, receptor_type='BCR', device=device, no_lora=ln_cfg.no_lora)
83 | self.padding_idx = 21
84 |
85 | self.proj_head = nn.Sequential(
86 | nn.Linear(input_dim, projection_dim),
87 | nn.LayerNorm(projection_dim),
88 | )
89 |
90 | self.device = device
91 |
92 | def forward(self, x):
93 | seq_tokens = self.process_seqs(x)
94 |
95 | # print("seq tokens:", seq_tokens)
96 |
97 | # feed to AbLang2
98 | rescoding = self.ablang2_lora(seq_tokens)
99 |
100 | # process AbLang2 outputs
101 | seq_inputs = {'attention_mask': ~(seq_tokens == self.padding_idx)}
102 | model_output = {'last_hidden_state': rescoding.last_hidden_states}
103 |
104 | seq_outputs = get_sequence_embeddings(seq_inputs, model_output, is_sep=False, is_cls=False)
105 |
106 | return self.proj_head(seq_outputs)
107 |
108 | def process_seqs(self, seqs):
109 | H_seqs, L_seqs = seqs
110 |
111 | # format the seq strings according to AbLang2:
112 | ab_seqs = [f"{H_seqs[i]}|{L_seqs[i]}" for i in range(len(H_seqs))]
113 |
114 | seqs_tokens = self.ablang2_tokenizer(ab_seqs, pad=True, w_extra_tkns=False, device=self.device)
115 |
116 | return seqs_tokens
117 |
118 |
119 | class AntibodyEncoderAntiberta2(nn.Module):
120 | def __init__(self, input_dim, projection_dim, ln_cfg, device='cpu'):
121 | super().__init__()
122 | from .lora import setup_peft_aberta2
123 | from .configs import peft_config_aberta2
124 |
125 | self.aberta2_lora, self.aberta2_tokenizer = setup_peft_aberta2(peft_config_aberta2)
126 |
127 | self.proj_head = nn.Sequential(
128 | nn.Linear(input_dim, projection_dim),
129 | nn.LayerNorm(projection_dim),
130 | )
131 |
132 | self.device = device
133 |
134 | def forward(self, x):
135 | seq_tokens = self.process_seqs(x)
136 |
137 | try:
138 | # feed to AntiBERTa
139 | rescoding = self.aberta2_lora(**seq_tokens)
140 | except:
141 | print("seqs:", x)
142 | print("seq tokens:", seq_tokens)
143 | raise ValueError
144 |
145 | seq_embeds = get_sequence_embeddings(seq_tokens, rescoding)
146 |
147 | return self.proj_head(seq_embeds)
148 |
149 | def process_seqs(self, seqs):
150 | H_seqs, L_seqs = seqs
151 |
152 | # format the seq strings according to Antiberta2:
153 | ab_seqs = [f"{insert_spaces(H_seqs[i])} [SEP] {insert_spaces(L_seqs[i])}" for i in range(len(H_seqs))]
154 |
155 | seqs_tokens = self.aberta2_tokenizer(ab_seqs, return_tensors="pt", padding=True)
156 |
157 | return seqs_tokens.to(self.device)
158 |
--------------------------------------------------------------------------------
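
A hypothetical usage sketch for the encoders above, instantiating `AntibodyEncoderAbLang2` with the matching defaults from `src/configs.py`. It assumes the `ablang2` package can download its pretrained weights; the toy sequences and the expected output shape are illustrative only.

```
from src.configs import get_lightning_config, get_projection_config
from src.models_ab import AntibodyEncoderAbLang2

ln_cfg = get_lightning_config()
model_cfg = get_projection_config('ablang2')

encoder = AntibodyEncoderAbLang2(
    input_dim=model_cfg.receptor_input_dim,   # 480 for AbLang2
    projection_dim=model_cfg.projection_dim,  # 512
    ln_cfg=ln_cfg,
    device='cpu',
)

# toy paired heavy/light chain fragments (illustrative only)
heavy = ('EVQLVESGGGLVQPGGSLRLSCAAS',)
light = ('DIQMTQSPSSLSASVGDRVTITC',)
embeddings = encoder((heavy, light))  # expected shape: (1, projection_dim)
```
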
/src/configs.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import os
3 |
4 | from peft import LoraConfig, TaskType
5 |
6 | output_dir_path = os.getenv('WANDB_OUTPUT_DIR')
7 |
8 | @dataclass
9 | class LightningConfig:
10 | max_epochs: int = 10
11 | lr: float = 1e-4
12 | weight_decay: float = 0.01
13 | batch_size: int = 4
14 | num_workers_train: int = 8
15 | torch_device: str = 'gpu'
16 | dataset_path: str = None
17 | output_dir: str = output_dir_path
18 | include_mhc: bool = False
19 | mhc_groove_only: bool = False
20 | unique_epitopes: bool = False
21 | mask_seqs: bool = False
22 | mask_prob: float = 0.15
23 | swe_pooling: bool = False
24 | save_embed_path: str = None
25 | no_lora: bool = False
26 | mse_weight: float = 0.
27 | weigh_epitope_count: bool = False
28 | oversample: bool = False
29 | regular_ft: bool = False
30 | fewshot_ratio: float = None
31 | lr_scheduler: str = 'cos_anneal'
32 |
33 |
34 | @dataclass
35 | class EncoderProjectionConfigAbLang:
36 | epitope_input_dim: int = 1280
37 | receptor_input_dim: int = 1536
38 | projection_dim: int = 512
39 | temperature: float = 0.07
40 | receptor_model_name: str = 'ablang'
41 |
42 | @dataclass
43 | class EncoderProjectionConfigAbLang2:
44 | epitope_input_dim: int = 1280
45 | receptor_input_dim: int = 480
46 | projection_dim: int = 512
47 | temperature: float = 0.07
48 | receptor_model_name: str = 'ablang2'
49 |
50 | @dataclass
51 | class EncoderProjectionConfigAntiberta2:
52 | epitope_input_dim: int = 1280
53 | receptor_input_dim: int = 1024
54 | projection_dim: int = 512
55 | temperature: float = 0.07
56 | receptor_model_name: str = 'antiberta2'
57 |
58 | @dataclass
59 | class EncoderProjectionConfigTCRBert:
60 | epitope_input_dim: int = 1280
61 | receptor_input_dim: int = 1536
62 | hidden_dim: int = None
63 | projection_dim: int = 512
64 | temperature: float = 0.07
65 | receptor_model_name: str = 'tcrbert'
66 |
67 | @dataclass
68 | class EncoderProjectionConfigTCRLang:
69 | epitope_input_dim: int = 1280
70 | receptor_input_dim: int = 480
71 | hidden_dim: int = None
72 | projection_dim: int = 512
73 | temperature: float = 0.07
74 | receptor_model_name: str = 'tcrlang'
75 |
76 | @dataclass
77 | class EncoderProjectionConfigESM2:
78 | epitope_input_dim: int = 1280
79 | receptor_input_dim: int = 1280
80 | hidden_dim: int = None
81 | projection_dim: int = None
82 | temperature: float = 0.07
83 | receptor_model_name: str = 'esm2'
84 |
85 | @dataclass
86 | class EncoderProjectionConfigESM3:
87 | epitope_input_dim: int = 1536
88 | receptor_input_dim: int = 1536
89 | hidden_dim: int = None
90 | projection_dim: int = None
91 | temperature: float = 0.07
92 | receptor_model_name: str = 'esm3'
93 |
94 | @dataclass
95 | class EncoderProjectionConfigInHouse:
96 | epitope_input_dim: int = 1280
97 | receptor_input_dim: int = 768
98 | hidden_dim: int = None
99 | projection_dim: int = None
100 | temperature: float = 0.07
101 | receptor_model_name: str = 'inhouse'
102 |
103 | @dataclass
104 | class EncoderProjectionConfigOneHot:
105 | epitope_input_dim: int = 21
106 | receptor_input_dim: int = 21
107 | hidden_dim: int = None
108 | projection_dim: int = None
109 | temperature: float = 0.07
110 | receptor_model_name: str = 'inhouse'
111 |
112 | # --------------------------------------------------
113 | # PEFT configs:
114 | peft_config_esm2 = LoraConfig(
115 | r=8,
116 | lora_alpha=32,
117 | lora_dropout=0.1,
118 | bias='none',
119 | layers_to_transform=[32, 31, 30, 29, 28, 27, 26, 25],
120 | task_type=TaskType.FEATURE_EXTRACTION,
121 | target_modules=['attention.self.key', 'attention.self.value']
122 | )
123 |
124 | peft_config_esm3 = LoraConfig(
125 | r=8,
126 | lora_alpha=32,
127 | lora_dropout=0.1,
128 | bias='none',
129 | layers_to_transform=[47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36],
130 | task_type=TaskType.FEATURE_EXTRACTION,
131 | target_modules=['attn.layernorm_qkv.1']
132 | )
133 |
134 | peft_config_ablang = LoraConfig(
135 | r=8,
136 | lora_alpha=32,
137 | lora_dropout=0.1,
138 | bias='none',
139 | layers_to_transform=[11, 10, 9, 8],
140 | task_type=TaskType.FEATURE_EXTRACTION,
141 | target_modules=['attention.self.query', 'attention.self.value']
142 | )
143 |
144 | peft_config_ablang2 = LoraConfig(
145 | r=8,
146 | lora_alpha=32,
147 | lora_dropout=0.1,
148 | bias='none',
149 | # layers_to_transform=[11, 10, 9, 8],
150 | task_type=TaskType.FEATURE_EXTRACTION,
151 | target_modules=".*(8|9|10|11).*[kv]_proj$"
152 | )
153 |
154 | peft_config_aberta2 = LoraConfig(
155 | r=8,
156 | lora_alpha=32,
157 | lora_dropout=0.1,
158 | bias='none',
159 | layers_to_transform=[15, 14, 13, 12],
160 | task_type=TaskType.FEATURE_EXTRACTION,
161 | target_modules=["attention.self.query", "attention.self.value"]
162 | )
163 |
164 | peft_config_tcrbert = LoraConfig(
165 | r=8,
166 | lora_alpha=32,
167 | lora_dropout=0.1,
168 | bias='none',
169 | layers_to_transform=[11, 10, 9, 8],
170 | task_type=TaskType.FEATURE_EXTRACTION,
171 | target_modules=["attention.self.key", "attention.self.value"]
172 | )
173 |
174 | peft_config_inhouse = LoraConfig(
175 | r=8,
176 | lora_alpha=32,
177 | lora_dropout=0.1,
178 | bias='none',
179 | layers_to_transform=[11, 10, 9, 8],
180 | task_type=TaskType.FEATURE_EXTRACTION,
181 | target_modules=["attention.self.key", "attention.self.value"]
182 | )
183 | # --------------------------------------------------
184 |
185 |
186 | def get_lightning_config(name='default'):
187 | if name == 'default':
188 | return LightningConfig()
189 |
190 | def get_projection_config(name='ablang'):
191 | if name == 'ablang':
192 | return EncoderProjectionConfigAbLang()
193 | elif name == 'ablang2':
194 | return EncoderProjectionConfigAbLang2()
195 | elif name == 'antiberta2':
196 | return EncoderProjectionConfigAntiberta2()
197 | elif name == 'tcrbert':
198 | return EncoderProjectionConfigTCRBert()
199 | elif name == 'tcrlang':
200 | return EncoderProjectionConfigTCRLang()
201 | elif name == 'esm2':
202 | return EncoderProjectionConfigESM2()
203 | elif name == 'esm3':
204 | return EncoderProjectionConfigESM3()
205 | elif name == 'inhouse':
206 | return EncoderProjectionConfigInHouse()
207 | elif name == 'onehot':
208 | return EncoderProjectionConfigOneHot()
209 | else:
210 | raise ValueError(f"Invalid model name: {name}")
211 |
212 |
213 | def build_lora_config(rank=4, alpha=32, dropout=0.1, bias='none', layers_to_transform=None,
214 | task_type=TaskType.FEATURE_EXTRACTION, target_modules=None):
215 |
216 | return LoraConfig(
217 | r=rank,
218 | lora_alpha=alpha,
219 | lora_dropout=dropout,
220 | bias=bias,
221 | layers_to_transform=layers_to_transform,
222 | task_type=task_type,
223 | target_modules=target_modules
224 | )
225 |
--------------------------------------------------------------------------------
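
An illustrative sketch of the helper functions above; the layer indices and target module names passed to `build_lora_config` are hypothetical choices, not values used by the repository.

```
from src.configs import get_projection_config, build_lora_config

model_cfg = get_projection_config('esm2')
print(model_cfg.epitope_input_dim, model_cfg.receptor_model_name)  # 1280 esm2

custom_lora_cfg = build_lora_config(
    rank=8,
    alpha=32,
    dropout=0.1,
    layers_to_transform=[32, 31, 30, 29],  # hypothetical: adapt only the last four layers
    target_modules=['attention.self.key', 'attention.self.value'],
)
```
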
/src/lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from peft import get_peft_model
5 |
6 |
7 | def setup_peft_esm2(peft_config, no_lora = False, regular_ft=False):
8 |
9 | from transformers import EsmModel, EsmTokenizer
10 |
11 | # Load the pretrained ESM-2 model
12 | esm_model = EsmModel.from_pretrained('facebook/esm2_t33_650M_UR50D')
13 | esm_tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')
14 |
15 | if regular_ft:
16 | return esm_model, esm_tokenizer
17 |
18 | # Apply LoRA to the model
19 | peft_lm = get_peft_model(esm_model, peft_config)
20 |
21 | #NOT APPLYING LoRA to the model:
22 | if no_lora:
23 | for name, param in esm_model.named_parameters():
24 | param.requires_grad = False
25 | return esm_model, esm_tokenizer
26 |
27 | # freeze all the layers except the LoRA adapter matrices
28 | for name, param in peft_lm.named_parameters():
29 | if "lora" in name:
30 | param.requires_grad = True
31 | else:
32 | param.requires_grad = False
33 |
34 | return peft_lm, esm_tokenizer
35 |
36 | def setup_peft_esm3(peft_config, no_lora = False):
37 |
38 | from esm.models.esm3 import ESM3
39 | from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
40 | from esm.tokenization import EsmSequenceTokenizer
41 |
42 | # Load the pretrained ESM-3 model
43 | esm3_model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1")
44 | esm3_tokenizer = EsmSequenceTokenizer()
45 |
46 | #NOT APPLYING LoRA to the model:
47 | if no_lora:
48 | for name, param in esm3_model.named_parameters():
49 | param.requires_grad = False
50 | return esm3_model, esm3_tokenizer
51 |
52 | # Apply LoRA to the model
53 | peft_lm = get_peft_model(esm3_model, peft_config)
54 |
55 | # freeze all the layers except the LoRA adapter matrices
56 | for name, param in esm3_model.named_parameters():
57 | if "lora" in name:
58 | param.requires_grad = True
59 | else:
60 | param.requires_grad = False
61 |
62 | return esm3_model, esm3_tokenizer
63 |
64 | def setup_peft_ablang(peft_config, chain="H"):
65 |
66 | from transformers import AutoTokenizer, AutoModelForMaskedLM
67 |
68 | if chain == "H":
69 | # Load the pretrained AbLang H model
70 | ablang_tokenizer = AutoTokenizer.from_pretrained("qilowoq/AbLang_heavy", trust_remote_code=True)
71 | ablang_model = AutoModelForMaskedLM.from_pretrained("qilowoq/AbLang_heavy", trust_remote_code=True)
72 |
73 | if chain == "L":
74 | # Load the pretrained AbLang L model
75 | ablang_tokenizer = AutoTokenizer.from_pretrained("qilowoq/AbLang_light", trust_remote_code=True)
76 | ablang_model = AutoModelForMaskedLM.from_pretrained("qilowoq/AbLang_light", trust_remote_code=True)
77 |
78 | # take out the decoder layer, which we don't need
79 | ablang_model = ablang_model.roberta
80 |
81 | # Apply LoRA to the model
82 | peft_lm = get_peft_model(ablang_model, peft_config)
83 |
84 | # freeze all the layers except the LoRA adapter matrices
85 | for name, param in peft_lm.named_parameters():
86 | if "lora" in name:
87 | param.requires_grad = True
88 | else:
89 | param.requires_grad = False
90 |
91 | return peft_lm, ablang_tokenizer
92 |
93 | def setup_peft_ablang2(peft_config, receptor_type='BCR', device='cpu', no_lora=False):
94 | import ablang2
95 |
96 | # Load the pretrained AbLang2 model
97 | if receptor_type == 'TCR':
98 | ablang2_module = ablang2.pretrained(model_to_use='tcrlang-paired', random_init=False, device=device)
99 | elif receptor_type == 'BCR':
100 | ablang2_module = ablang2.pretrained(model_to_use='ablang2-paired', random_init=False, device=device)
101 | else:
102 | raise ValueError(f"Receptor type {receptor_type} not supported")
103 | ablang2_model = ablang2_module.AbRep
104 |
105 | # NOT APPLYING LoRA to the model:
106 | if no_lora:
107 | for name, param in ablang2_model.named_parameters():
108 | param.requires_grad = False
109 | return ablang2_model, ablang2_module.tokenizer
110 |
111 | # Apply LoRA to the model
112 | peft_lm = get_peft_model(ablang2_model, peft_config)
113 |
114 | # freeze all the layers except the LoRA adapter matrices
115 | lora_count = 0
116 | for name, param in ablang2_model.named_parameters():
117 | if "lora" in name:
118 | lora_count += 1
119 | param.requires_grad = True
120 | else:
121 | param.requires_grad = False
122 | assert lora_count >= 4 # make sure we have LoRA adapter matrices
123 |
124 | return ablang2_model, ablang2_module.tokenizer
125 |
126 | def setup_peft_aberta2(peft_config):
127 | from transformers import (
128 | RoFormerForMaskedLM,
129 | RoFormerTokenizer,
130 | )
131 |
132 | # Load the pretrained Aberta2 model
133 | aberta2_model = RoFormerForMaskedLM.from_pretrained("alchemab/antiberta2")
134 | aberta2_tokenizer = RoFormerTokenizer.from_pretrained("alchemab/antiberta2")
135 |
136 | # only take the RoFormer module:
137 | aberta2_model = aberta2_model.roformer
138 |
139 | # Apply LoRA to the model
140 | peft_lm = get_peft_model(aberta2_model, peft_config)
141 |
142 | # freeze all the layers except the LoRA adapter matrices
143 | for name, param in peft_lm.named_parameters():
144 | if "lora" in name:
145 | param.requires_grad = True
146 | else:
147 | param.requires_grad = False
148 |
149 | return peft_lm, aberta2_tokenizer
150 |
151 | def setup_peft_tcrbert(peft_config, no_lora=False, regular_ft=False):
152 | from transformers import (
153 | BertModel,
154 | AutoTokenizer,
155 | )
156 |
157 | # Load the pretrained TCRBert model
158 | tcrbert_model = BertModel.from_pretrained("wukevin/tcr-bert-mlm-only")
159 | tcrbert_tokenizer = AutoTokenizer.from_pretrained("wukevin/tcr-bert-mlm-only", trust_remote_code=True)
160 |
161 | if regular_ft:
162 | return tcrbert_model, tcrbert_tokenizer
163 |
164 | # Apply LoRA to the model
165 | peft_lm = get_peft_model(tcrbert_model, peft_config)
166 |
167 | # NOT APPLYING LoRA to the model:
168 | if no_lora:
169 | for name, param in tcrbert_model.named_parameters():
170 | param.requires_grad = False
171 | return tcrbert_model, tcrbert_tokenizer
172 |
173 | # freeze all the layers except the LoRA adapter matrices
174 | for name, param in peft_lm.named_parameters():
175 | if "lora" in name:
176 | param.requires_grad = True
177 | else:
178 | param.requires_grad = False
179 |
180 | return peft_lm, tcrbert_tokenizer
181 |
182 | def setup_peft_inhouse(peft_config, no_lora=False, model_ckpt_path=None):
183 | from .pretrain.model import CdrBERT, getCdrTokenizer, MODEL_CONFIG
184 |
185 | # load the in-house TCR model:
186 | inhouse_tokenizer = getCdrTokenizer()
187 | inhouse_model = CdrBERT(MODEL_CONFIG, inhouse_tokenizer)
188 | inhouse_ckpt = torch.load(model_ckpt_path)
189 | # Remove "model." prefix from keys. Artifact of Pytorch Lightning
190 | new_state_dict = {}
191 | for key, value in inhouse_ckpt['state_dict'].items():
192 | new_key = key.replace("model.", "")
193 | new_state_dict[new_key] = value
194 | inhouse_model.load_state_dict(new_state_dict)
195 |
196 | # Apply LoRA to the model
197 | peft_lm = get_peft_model(inhouse_model, peft_config)
198 |
199 | # NOT APPLYING LoRA to the model:
200 | if no_lora:
201 | for name, param in inhouse_model.named_parameters():
202 | param.requires_grad = False
203 | return inhouse_model, inhouse_tokenizer
204 |
205 | # freeze all the layers except the LoRA adapter matrices
206 | for name, param in peft_lm.named_parameters():
207 | if "lora" in name:
208 | param.requires_grad = True
209 | else:
210 | param.requires_grad = False
211 |
212 | return peft_lm, inhouse_tokenizer
--------------------------------------------------------------------------------
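
A small sanity-check sketch (not part of the repository) for the `setup_peft_*` helpers above, confirming that only LoRA adapter tensors remain trainable after wrapping; it assumes network access to download `facebook/esm2_t33_650M_UR50D` via Hugging Face.

```
from src.configs import peft_config_esm2
from src.lora import setup_peft_esm2

peft_lm, tokenizer = setup_peft_esm2(peft_config_esm2)

trainable = [name for name, p in peft_lm.named_parameters() if p.requires_grad]
total = sum(1 for _ in peft_lm.named_parameters())
assert all('lora' in name for name in trainable), "unexpected trainable parameters"
print(f"{len(trainable)} trainable LoRA tensors out of {total} parameter tensors")
```
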
/src/training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from pytorch_lightning import Trainer
4 | from lightning.pytorch.accelerators import find_usable_cuda_devices
5 | from pytorch_lightning.utilities import rank_zero_only
6 | import pytorch_lightning.loggers as log
7 | from lightning.pytorch.strategies import DDPStrategy
8 |
9 | import argparse
10 | import os
11 | import wandb
12 |
13 | from .alignment import CLIPModel
14 | from .data_module import EpitopeReceptorDataModule
15 | from .configs import get_lightning_config, get_projection_config
16 | from .logging import call_backs, combine_args_and_configs
17 |
18 |
19 | def setup_parser():
20 |
21 | # Command line interface arguments and parsing
22 | parser = argparse.ArgumentParser(description='argument parser for training')
23 | parser.add_argument('--batch-size', type=int, default=4, help='Batch size for training')
24 | parser.add_argument('--grad-accum', type=int, default=1, help='Number of gradient accumulation steps')
25 | parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
26 | parser.add_argument('--weight-decay', type=float, default=0.01, help='Weight decay parameter for AdamW algorithm')
27 | parser.add_argument('--random-seed', type=int, default=14, help='Random seed for reproducibility')
28 |
29 | # WandB configuration:
30 | parser.add_argument('--entity', type=str, default='lordim', help='entity name')
31 | parser.add_argument('--project', type=str, default='clip_antibody', help='project name')
32 | parser.add_argument('--group', type=str, default='clipbody_test', help='group name')
33 | parser.add_argument('--run-id', type=str, default='clipbody_test', help='run id')
34 | parser.add_argument('--use-wandb', default=False, action='store_true', help='use wandb for logging')
35 |
36 | # Training and Data configuration
37 | parser.add_argument('--receptor-model-name', type=str, default='ablang', help='name of the receptor foundation model')
38 | parser.add_argument('--receptor-type', type=str, default='TCR', help='Is the receptor BCR or TCR')
39 | parser.add_argument('--include-mhc', default=False, action='store_true', help='include MHC sequences alongside epitope in the training data')
40 | parser.add_argument('--mhc-groove-only', default=False, action='store_true', help='only include A1-A2 domains for class I MHC, A1-B1 domains for class II MHC')
41 | parser.add_argument('--unique-epitopes', default=False, action='store_true', help='split the data based on unique epitopes')
42 | parser.add_argument('--no-lora', default=False, action='store_true', help='do not use LoRA adapter matrices for the models')
43 | parser.add_argument('--regular-ft', default=False, action='store_true', help='use regular fine-tuning')
44 | parser.add_argument('--mask-seqs', default=False, action='store_true', help='mask the sequences for training')
45 | parser.add_argument('--mask-prob', type=float, default=0.15, help='probability of masking a residue')
46 | parser.add_argument('--mse-weight', type=float, default=0., help='weight for the MSE loss')
47 | parser.add_argument('--weigh-epitope-count', default=False, action='store_true', help='weight the epitope count in the clip loss')
48 | parser.add_argument('--swe-pooling', default=False, action='store_true', help='use SWE pooling for sequence embeddings')
49 | parser.add_argument('--hidden-dim', type=int, default=None, help='dimension of the hidden layer')
50 | parser.add_argument('--projection-dim', type=int, default=None, help='dimension of the projection layer')
51 | parser.add_argument('--lightning-config-name', type=str, default='default')
52 | parser.add_argument('--dataset-path', type=str, required=True, help='path to the dataset')
53 | parser.add_argument('--mhc-path', type=str, default=None, help='path to file with MHC sequence info. Required if --include-mhc is set to True')
54 | parser.add_argument('--oversample', default=False, action='store_true', help='oversample the epitopes with few receptor data')
55 | parser.add_argument('--fewshot-ratio', type=float, default=None, help='ratio of few-shot data to the total data')
56 | parser.add_argument('--lr-scheduler', type=str, default='cos_anneal', help='learning rate scheduler')
57 |
58 | # PyTorch Lightning configuration
59 | parser.add_argument('--torch-device', type=str, default='gpu')
60 | parser.add_argument('--output-dir', type=str, required=True, help='wandb and checkpoint output')
61 | parser.add_argument('--num-gpus', type=int, default=1, help='number of GPUs to use')
62 | parser.add_argument('--max-epochs', type=int, default = 1, required=False)
63 | parser.add_argument('--gpus-used', type=int, nargs='+', required=False, help='which GPUs used for env variable CUDA_VISIBLE_DEVICES')
64 | parser.add_argument('--stage', type=str, default='fit', help='stage of training')
65 | parser.add_argument('--check-val-every-n-epoch', type=int, default=1, help='check validation every n epochs')
66 | parser.add_argument('--val-check-interval', type=float, default=1.0, help='validation check interval')
67 | parser.add_argument('--save-ckpts', default=False, action='store_true', help='save checkpoints')
68 | parser.add_argument('--from-checkpoint', type=str, default=None, help='path to checkpoint')
69 | parser.add_argument('--save-embed-path', type=str, default=None, help='path to save embeddings for eval')
70 |
71 | args = parser.parse_args()
72 |
73 | return args
74 |
75 |
76 |
77 | if __name__ == '__main__':
78 | # utilizing Tensor Cores:
79 | torch.set_float32_matmul_precision('high')
80 |
81 | # setting up the environment variables:
82 | os.environ["TOKENIZERS_PARALLELISM"] = "true" # resolving tokenizers parallelism issue
83 |
84 | args = setup_parser()
85 |
86 | # retrieve the configs:
87 | lightning_config = get_lightning_config()
88 | model_config = get_projection_config(args.receptor_model_name)
89 |
90 | # update configs based on input arguments:
91 | combine_args_and_configs(args, [lightning_config, model_config])
92 |
93 | # setup callbacks:
94 | if args.stage == 'fit' and args.output_dir is not None:
95 | if not os.path.exists(os.path.join(args.output_dir, args.run_id)):
96 | os.makedirs(os.path.join(args.output_dir, args.run_id))
97 |
98 | output_dir = os.path.join(args.output_dir, args.run_id)
99 |
100 | # get callbacks
101 | callbacks = call_backs(output_dir, args.save_ckpts)
102 | else:
103 | callbacks = None
104 |
105 | # construct PyTorch Lightning Module:
106 | if args.receptor_type == 'TCR':
107 | print("Using TCR data!")
108 | else:
109 | print("Using BCR data!")
110 | tsv_file_path = args.dataset_path
111 | mhc_file_path = args.mhc_path
112 |
113 | if args.mask_seqs:
114 | print("WARNING: Partially masking sequence residues during training")
115 |
116 | if args.unique_epitopes:
117 | print("WARNING: Splitting data based on unique epitopes")
118 |
119 | pl_datamodule = EpitopeReceptorDataModule(tsv_file_path, mhc_file=mhc_file_path, ln_cfg=lightning_config,
120 | batch_size=lightning_config.batch_size, include_mhc=lightning_config.include_mhc,
121 | model_config=model_config, random_seed=args.random_seed)
122 |
123 | # construct the CLIP model:
124 | clip_model = CLIPModel(lightning_config, model_config)
125 |
126 | if rank_zero_only.rank == 0:
127 | # initialize wandb:
128 | if args.use_wandb:
129 | run = wandb.init(project=args.project,
130 | entity=args.entity,
131 | group=args.group,
132 | dir=output_dir,
133 | name=args.run_id,
134 | id=args.run_id,
135 | resume=True if args.from_checkpoint else None,
136 | )
137 |
138 | run_output_dir = run.dir
139 | wandb_logger = log.WandbLogger(save_dir=run_output_dir, log_model=False)
140 | wandb_logger.watch(clip_model)
141 |
142 | if args.gpus_used and len(args.gpus_used) > 1:
143 | strat = 'ddp'
144 | if args.regular_ft:
145 | strat = 'ddp_find_unused_parameters_true'
146 | else:
147 | strat = 'auto'
148 | # build PyTorch Lightning Trainer:
149 | trainer = Trainer(max_epochs=args.max_epochs,
150 | logger=wandb_logger if args.use_wandb else None,
151 | accelerator=args.torch_device,
152 | devices=args.gpus_used if args.gpus_used else 1, #TODO: smooth CPU/GPU conversion
153 | enable_progress_bar=True,
154 | callbacks=callbacks if callbacks is not None else None,
155 | accumulate_grad_batches=args.grad_accum,
156 | reload_dataloaders_every_n_epochs=1 if args.oversample else 0,
157 | strategy=strat,
158 | )
159 | else:
160 | strat = 'ddp'
161 | if args.regular_ft:
162 | strat = 'ddp_find_unused_parameters_true'
163 | # build PyTorch Lightning Trainer:
164 | trainer = Trainer(max_epochs=args.max_epochs,
165 | logger=None,
166 | accelerator=args.torch_device,
167 | devices=args.gpus_used if args.gpus_used else 1, #TODO: smooth CPU/GPU conversion
168 | enable_progress_bar=True,
169 | callbacks=callbacks if callbacks is not None else None,
170 | accumulate_grad_batches=args.grad_accum,
171 | reload_dataloaders_every_n_epochs=1 if args.oversample else 0,
172 | strategy=strat,
173 | )
174 |
175 | # run the model:
176 | if args.stage == 'fit':
177 | print('Start Training...')
178 | trainer.fit(model=clip_model, datamodule=pl_datamodule, ckpt_path=args.from_checkpoint)
179 | else:
180 | print("**********************")
181 | print("* Inference Mode... *")
182 | print("**********************")
183 | trainer.test(model=clip_model, datamodule=pl_datamodule, ckpt_path=args.from_checkpoint)
184 |
--------------------------------------------------------------------------------
/src/swe_pooling.py:
--------------------------------------------------------------------------------
1 | '''
2 | Contents of this file are from the open source code for
3 |
4 | NaderiAlizadeh, Navid, and Rohit Singh.
5 | Aggregating Residue-Level Protein Language Model Embeddings with Optimal Transport.
6 | bioRxiv (2024): 2024-01.
7 |
8 | MIT License
9 |
10 | Copyright (c) 2024 Navid NaderiAlizadeh and Rohit Singh
11 |
12 | Permission is hereby granted, free of charge, to any person obtaining a copy
13 | of this software and associated documentation files (the "Software"), to deal
14 | in the Software without restriction, including without limitation the rights
15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | copies of the Software, and to permit persons to whom the Software is
17 | furnished to do so, subject to the following conditions:
18 |
19 | The above copyright notice and this permission notice shall be included in all
20 | copies or substantial portions of the Software.
21 |
22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | SOFTWARE.
29 | '''
30 |
31 | from types import SimpleNamespace
32 |
33 | import os
34 | import pickle as pk
35 | from functools import lru_cache
36 |
37 | import torch
38 | import torch.nn as nn
39 | from torch.nn.utils.rnn import pad_sequence
40 | from torch.utils.data import DataLoader, Dataset
41 | from tqdm import tqdm
42 |
43 | import contextlib
44 |
45 | class Interp1d(torch.autograd.Function):
46 | def __call__(self, x, y, xnew, out=None):
47 | return self.forward(x, y, xnew, out)
48 |
49 | def forward(ctx, x, y, xnew, out=None):
50 | """
51 | Linear 1D interpolation on the GPU for Pytorch.
52 | This function returns interpolated values of a set of 1-D functions at
53 | the desired query points `xnew`.
54 | This function is working similarly to Matlab™ or scipy functions with
55 | the `linear` interpolation mode on, except that it parallelises over
56 | any number of desired interpolation problems.
57 | The code will run on GPU if all the tensors provided are on a cuda
58 | device.
59 | Parameters
60 | ----------
61 | x : (N, ) or (D, N) Pytorch Tensor
62 | A 1-D or 2-D tensor of real values.
63 | y : (N,) or (D, N) Pytorch Tensor
64 | A 1-D or 2-D tensor of real values. The length of `y` along its
65 | last dimension must be the same as that of `x`
66 | xnew : (P,) or (D, P) Pytorch Tensor
67 | A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
68 | _both_ `x` and `y` are 1-D. Otherwise, its length along the first
69 | dimension must be the same as that of whichever `x` and `y` is 2-D.
70 | out : Pytorch Tensor, same shape as `xnew`
71 | Tensor for the output. If None: allocated automatically.
72 | """
73 | # making the vectors at least 2D
74 | is_flat = {}
75 | require_grad = {}
76 | v = {}
77 | device = []
78 | eps = torch.finfo(y.dtype).eps
79 | for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items():
80 | assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\
81 | 'at most 2-D.'
82 | if len(vec.shape) == 1:
83 | v[name] = vec[None, :]
84 | else:
85 | v[name] = vec
86 | is_flat[name] = v[name].shape[0] == 1
87 | require_grad[name] = vec.requires_grad
88 | device = list(set(device + [str(vec.device)]))
89 | assert len(device) == 1, 'All parameters must be on the same device.'
90 | device = device[0]
91 |
92 | # Checking for the dimensions
93 | assert (v['x'].shape[1] == v['y'].shape[1]
94 | and (
95 | v['x'].shape[0] == v['y'].shape[0]
96 | or v['x'].shape[0] == 1
97 | or v['y'].shape[0] == 1
98 | )
99 | ), ("x and y must have the same number of columns, and either "
100 | "the same number of row or one of them having only one "
101 | "row.")
102 |
103 | reshaped_xnew = False
104 | if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1)
105 | and (v['xnew'].shape[0] > 1)):
106 | # if there is only one row for both x and y, there is no need to
107 | # loop over the rows of xnew because they will all have to face the
108 | # same interpolation problem. We should just stack them together to
109 | # call interp1d and put them back in place afterwards.
110 | original_xnew_shape = v['xnew'].shape
111 | v['xnew'] = v['xnew'].contiguous().view(1, -1)
112 | reshaped_xnew = True
113 |
114 | # identify the dimensions of output and check if the one provided is ok
115 | D = max(v['x'].shape[0], v['xnew'].shape[0])
116 | shape_ynew = (D, v['xnew'].shape[-1])
117 | if out is not None:
118 | if out.numel() != shape_ynew[0]*shape_ynew[1]:
119 | # The output provided is of incorrect shape.
120 | # Going for a new one
121 | out = None
122 | else:
123 | ynew = out.reshape(shape_ynew)
124 | if out is None:
125 | ynew = torch.zeros(*shape_ynew, device=device)
126 |
127 | # moving everything to the desired device in case it was not there
128 | # already (not handling the case things do not fit entirely, user will
129 | # do it if required.)
130 | for name in v:
131 | v[name] = v[name].to(device)
132 |
133 | # calling searchsorted on the x values.
134 | ind = ynew.long()
135 |
136 | # expanding xnew to match the number of rows of x in case only one xnew is
137 | # provided
138 | if v['xnew'].shape[0] == 1:
139 | v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1)
140 |
141 | torch.searchsorted(v['x'].contiguous(),
142 | v['xnew'].contiguous(), out=ind)
143 |
144 | # the `-1` is because searchsorted looks for the index where the values
145 | # must be inserted to preserve order. And we want the index of the
146 | # preceeding value.
147 | ind -= 1
148 | # we clamp the index, because the number of intervals is x.shape-1,
149 | # and the left neighbour should hence be at most number of intervals
150 | # -1, i.e. number of columns in x -2
151 | ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1)
152 |
153 | # helper function to select stuff according to the found indices.
154 | def sel(name):
155 | if is_flat[name]:
156 | return v[name].contiguous().view(-1)[ind]
157 | return torch.gather(v[name], 1, ind)
158 |
159 | # activating gradient storing for everything now
160 | enable_grad = False
161 | saved_inputs = []
162 | for name in ['x', 'y', 'xnew']:
163 | if require_grad[name]:
164 | enable_grad = True
165 | saved_inputs += [v[name]]
166 | else:
167 | saved_inputs += [None, ]
168 | # assuming x are sorted in the dimension 1, computing the slopes for
169 | # the segments
170 | is_flat['slopes'] = is_flat['x']
171 | # now we have found the indices of the neighbors, we start building the
172 | # output. Hence, we start also activating gradient tracking
173 | with torch.enable_grad() if enable_grad else contextlib.suppress():
174 | v['slopes'] = (
175 | (v['y'][:, 1:]-v['y'][:, :-1])
176 | /
177 | (eps + (v['x'][:, 1:]-v['x'][:, :-1]))
178 | )
179 |
180 | # now build the linear interpolation
181 | ynew = sel('y') + sel('slopes')*(
182 | v['xnew'] - sel('x'))
183 |
184 | if reshaped_xnew:
185 | ynew = ynew.view(original_xnew_shape)
186 |
187 | ctx.save_for_backward(ynew, *saved_inputs)
188 | return ynew
189 |
190 | @staticmethod
191 | def backward(ctx, grad_out):
192 | inputs = ctx.saved_tensors[1:]
193 | gradients = torch.autograd.grad(
194 | ctx.saved_tensors[0],
195 | [i for i in inputs if i is not None],
196 | grad_out, retain_graph=True)
197 | result = [None, ] * 5
198 | pos = 0
199 | for index in range(len(inputs)):
200 | if inputs[index] is not None:
201 | result[index] = gradients[pos]
202 | pos += 1
203 | return (*result,)
204 |
205 | class SWE_Pooling(nn.Module):
206 | def __init__(self, d_in, num_ref_points, num_slices):
207 | '''
208 | Produces fixed-dimensional permutation-invariant embeddings for input sets of arbitrary size based on sliced-Wasserstein embedding.
209 | Inputs:
210 | d_in: The dimensionality of the space that each set sample belongs to
211 | num_ref_points: Number of points in the reference set
212 | num_slices: Number of slices
213 | '''
214 | super(SWE_Pooling, self).__init__()
215 | self.d_in = d_in
216 | self.num_ref_points = num_ref_points
217 | self.num_slices = num_slices
218 |
219 | uniform_ref = torch.linspace(-1, 1, num_ref_points).unsqueeze(1).repeat(1, num_slices)
220 | self.reference = nn.Parameter(uniform_ref)
221 |
222 | self.theta = nn.utils.weight_norm(nn.Linear(d_in, num_slices, bias=False), dim=0)
223 | if num_slices <= d_in:
224 | nn.init.eye_(self.theta.weight_v)
225 | else:
226 | nn.init.normal_(self.theta.weight_v)
227 |
228 | self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data, requires_grad=False)
229 | self.theta.weight_g.requires_grad = False
230 |
231 | # weights to reduce the output embedding dimensionality
232 | self.weight = nn.Linear(num_ref_points, 1, bias=False)
233 |
234 | def forward(self, X, mask=None):
235 | '''
236 | Calculates GSW between two empirical distributions.
237 | Note that the number of samples is assumed to be equal
238 | (This is however not necessary and could be easily extended
239 | for empirical distributions with different number of samples)
240 | Input:
241 | X: B x N x dn tensor, containing a batch of B sets, each containing N samples in a dn-dimensional space
242 | mask [optional]: B x N binary tensor, with 1 iff the set element is valid; used for the case where set sizes are different
243 | Output:
244 | weighted_embeddings: B x num_slices tensor, containing a batch of B embeddings, each of dimension "num_slices" (i.e., number of slices)
245 | '''
246 |
247 | B, N, _ = X.shape
248 | Xslices = self.get_slice(X)
249 |
250 | M, _ = self.reference.shape
251 |
252 | if mask is None:
253 | # serial implementation should be used if set sizes are different
254 | Xslices_sorted, Xind = torch.sort(Xslices, dim=1)
255 |
256 | if M == N:
257 | Xslices_sorted_interpolated = Xslices_sorted
258 | else:
259 | x = torch.linspace(0, 1, N + 2)[1:-1].unsqueeze(0).repeat(B * self.num_slices, 1).to(X.device)
260 | xnew = torch.linspace(0, 1, M + 2)[1:-1].unsqueeze(0).repeat(B * self.num_slices, 1).to(X.device)
261 | y = torch.transpose(Xslices_sorted, 1, 2).reshape(B * self.num_slices, -1)
262 | Xslices_sorted_interpolated = torch.transpose(Interp1d()(x, y, xnew).view(B, self.num_slices, -1), 1, 2)
263 | else:
264 | # replace invalid set elements with points to the right of the maximum element for each slice and each set (which will not impact the sorting and interpolation process)
265 | invalid_elements_mask = ~mask.bool().unsqueeze(-1).repeat(1, 1, self.num_slices)
266 | Xslices_copy = Xslices.clone()
267 | Xslices_copy[invalid_elements_mask] = -1e10
268 |
269 | top2_Xslices, _ = torch.topk(Xslices_copy, k=2, dim=1)
270 | max_Xslices = top2_Xslices[:, 0].unsqueeze(1)
271 | delta_y = - torch.diff(top2_Xslices, dim=1)
272 |
273 | Xslices_modified = Xslices.clone()
274 |
275 | Xslices_modified[invalid_elements_mask] = max_Xslices.repeat(1, N, 1)[invalid_elements_mask]
276 |
277 | delta_x = 1 / (1 + torch.sum(mask, dim=1, keepdim=True))
278 | slope = delta_y / delta_x.unsqueeze(-1).repeat(1, 1, self.num_slices) # B x 1 x num_slices
279 | slope = slope.repeat(1, N, 1)
280 |
281 | eps = 1e-3
282 | x_shifts = eps * torch.cumsum(invalid_elements_mask, dim=1)
283 | y_shifts = slope * x_shifts
284 | Xslices_modified = Xslices_modified + y_shifts
285 |
286 | Xslices_sorted, _ = torch.sort(Xslices_modified, dim=1)
287 |
288 | x = torch.arange(1, N + 1).to(X.device) / (1 + torch.sum(mask, dim=1, keepdim=True)) # B x N
289 |
290 | invalid_elements_mask = ~mask.bool()
291 | x_copy = x.clone()
292 | x_copy[invalid_elements_mask] = -1e10
293 | max_x, _ = torch.max(x_copy, dim=1, keepdim=True)
294 | x[invalid_elements_mask] = max_x.repeat(1, N)[invalid_elements_mask]
295 |
296 | x = x.unsqueeze(1).repeat(1, self.num_slices, 1) + torch.transpose(x_shifts, 1, 2)
297 | x = x.view(-1, N) # BL x N
298 |
299 | xnew = torch.linspace(0, 1, M + 2)[1:-1].unsqueeze(0).repeat(B * self.num_slices, 1).to(X.device)
300 | y = torch.transpose(Xslices_sorted, 1, 2).reshape(B * self.num_slices, -1)
301 | Xslices_sorted_interpolated = torch.transpose(Interp1d()(x, y, xnew).view(B, self.num_slices, -1), 1, 2)
302 |
303 | Rslices = self.reference.expand(Xslices_sorted_interpolated.shape)
304 |
305 | _, Rind = torch.sort(Rslices, dim=1)
306 | embeddings = (Rslices - torch.gather(Xslices_sorted_interpolated, dim=1, index=Rind)).permute(0, 2, 1) # B x num_slices x M
307 |
308 | weighted_embeddings = self.weight(embeddings).sum(-1)
309 |
310 | return weighted_embeddings.view(-1, self.num_slices)
311 |
312 | def get_slice(self, X):
313 | '''
314 | Slices samples from distribution X~P_X
315 | Input:
316 | X: B x N x dn tensor, containing a batch of B sets, each containing N samples in a dn-dimensional space
317 | '''
318 | return self.theta(X)
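319 |
320 | # Illustrative usage sketch (not part of the original module): shapes follow the
321 | # docstrings above and the constructor call in models.py; the concrete numbers
322 | # below are arbitrary placeholders.
323 | #
324 | #   pool = SWE_Pooling(d_in=1280, num_ref_points=512, num_slices=256)
325 | #   X = torch.randn(8, 100, 1280)      # B x N x d_in token embeddings
326 | #   mask = torch.ones(8, 100)          # B x N, 1 for valid positions
327 | #   mask[:, 60:] = 0                   # e.g. the last 40 positions are padding
328 | #   emb = pool(X, mask)                # B x num_slices set embedding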
--------------------------------------------------------------------------------
/src/alignment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 |
7 | import pytorch_lightning as pl
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 |
10 | from .models import *
11 | from .utils import construct_label_matrices, construct_label_matrices_ones
12 | from .data_module import compute_weights
13 | from .lr_scheduler import CosineAnnealingWarmUpRestarts
14 |
15 | class CLIPModel(pl.LightningModule):
16 | def __init__(self, lightning_config, model_config, device='cuda', **kwargs):
17 | super().__init__()
18 | self.save_hyperparameters()
19 |
20 | self.ln_cfg = lightning_config
21 | self.model_config = model_config
22 |
23 | self.epitope_input_dim = model_config.epitope_input_dim
24 | self.receptor_input_dim = model_config.receptor_input_dim
25 | self.projection_dim = model_config.projection_dim
26 | self.hidden_dim = model_config.hidden_dim
27 |
28 | # loss functions:
29 | self.bceloss_logits = nn.BCEWithLogitsLoss(reduction='none')
30 | self.celoss = nn.CrossEntropyLoss(reduction='none')
31 | self.mse_weight = lightning_config.mse_weight
32 | self.epitope_weights = None
33 |
34 | # logging:
35 | self.log_iterations = None
36 | self.training_step_metrics = {}
37 | self.val_step_metrics = {}
38 | self.test_step_metrics = {}
39 |
40 | # for evaluation later:
41 | self.epitope_embeddings = []
42 | self.receptor_embeddings = []
43 | self.epitope_sequences = []
44 | self.receptor_sequences = []
45 |
46 |
47 | def forward(self, epitope_seqs, receptor_seqs, mask=False):
48 | epitope_proj = self.epitope_encoder(epitope_seqs, mask=mask)
49 | receptor_proj = self.receptor_encoder(receptor_seqs, mask=mask)
50 | return epitope_proj, receptor_proj
51 |
52 |
53 | def clip_loss_multiclass(self, epitope_features, receptor_features, label_matrix, temperature=1.0):
54 | """
55 |         Compute the multi-class CLIP loss for epitope and receptor features based on label_matrix
56 |
57 | Args:
58 | epitope_features: Tensor of shape (batch_size, feature_dim) representing the epitope embeddings.
59 | receptor_features: Tensor of shape (batch_size, feature_dim) representing the receptor embeddings.
60 |             label_matrix: Tensor of shape (batch_size, batch_size) where entry (i, j) is 1.0 if
61 |                 receptor j binds the same epitope as example i, and 0.0 otherwise.
62 | temperature: A scaling factor to control the sharpness of the similarity distribution.
63 |
64 | Returns:
65 |             clip_loss: the unreduced (batch_size x batch_size) CLIP loss; mse_loss: the scalar MSE between matched, normalized embeddings.
66 | """
67 |
68 | # Normalize the features to unit length (each dim bs x proj_dim)
69 | epitope_features = F.normalize(epitope_features, dim=-1)
70 | receptor_features = F.normalize(receptor_features, dim=-1)
71 |
72 | # MSE Loss between the normalized features:
73 | diff_norm = torch.norm(epitope_features - receptor_features, dim=-1)
74 | mse_loss = F.mse_loss(diff_norm, torch.zeros(len(diff_norm)).to(self.device), reduction='mean')
75 |
76 | # Compute the logits (similarities) as the dot product of epitope and receptor features
77 | logits_per_epitope = epitope_features @ receptor_features.t()
78 | logits_per_receptor = receptor_features @ epitope_features.t()
79 |
80 | # Scale by temperature
81 | logits_per_epitope /= temperature
82 | logits_per_receptor /= temperature
83 |
84 | # Compute the cross-entropy loss for both epitope-to-receptor and receptor-to-epitope
85 | # epitopes_loss = self.celoss(logits_per_epitope, label_matrix)
86 | # receptor_loss = self.celoss(logits_per_receptor, label_matrix)
87 |
88 | # Compute the binary cross-entropy loss for both epitope-to-receptor and receptor-to-epitope
89 | epitopes_loss = self.bceloss_logits(logits_per_epitope, label_matrix)
90 | receptor_loss = self.bceloss_logits(logits_per_receptor, label_matrix)
91 |
92 |         # (the inverse-sqrt epitope-count weighting, when enabled, is applied to this unreduced loss in the training/validation steps)
93 |
94 |         clip_loss = (epitopes_loss + receptor_loss) / 2.0 # shape: (batch_size, batch_size) since reduction='none'
95 | return clip_loss, mse_loss
96 |
97 |
98 | def training_step(self, batch, batch_idx):
99 | """
100 | Training step for the CLIPBody Model
101 | """
102 | epitope_seqs, receptor_seqs = batch
103 |
104 | epitope_proj, receptor_proj = self(epitope_seqs, receptor_seqs, mask=self.ln_cfg.mask_seqs)
105 |
106 | # label_matrix = construct_label_matrices(epitope_seqs, receptor_seqs).to(self.device)
107 | label_matrix = construct_label_matrices_ones(epitope_seqs, receptor_seqs, self.ln_cfg.include_mhc).to(self.device)
108 |
109 | # print("epitope seqs:", epitope_seqs)
110 |
111 | # construct weight matrices for the epitope sequences:
112 | if self.ln_cfg.weigh_epitope_count:
113 | weights = torch.tensor([self.epitope_weights[seq] for seq in epitope_seqs]).to(self.device)
114 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07)
115 | clip_loss = clip_loss * weights
116 | clip_loss = clip_loss.sum()
117 | else:
118 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07)
119 | clip_loss = clip_loss.mean()
120 |
121 | loss = clip_loss * (1 - self.mse_weight) + mse_loss * self.mse_weight
122 | training_metrics = {
123 | 'loss': loss,
124 | }
125 | self.training_step_metrics.setdefault('loss', []).append(loss.detach().item())
126 | if self.ln_cfg.mse_weight > 0:
127 | self.training_step_metrics.setdefault('clip_loss', []).append(clip_loss.detach().item())
128 | self.training_step_metrics.setdefault('mse_loss', []).append(mse_loss.detach().item())
129 |
130 | return training_metrics
131 |
132 | def validation_step(self, batch, batch_idx):
133 | """
134 | Validation step for the CLIPBody Model
135 | """
136 |
137 | epitope_seqs, receptor_seqs = batch
138 | try:
139 | epitope_proj, receptor_proj = self(epitope_seqs, receptor_seqs)
140 |         except Exception:
141 | print("Error in feeding sequences")
142 | print("epitope_seqs", epitope_seqs)
143 | print("receptor_seqs", receptor_seqs)
144 |             raise
145 |
146 | # label_matrix = construct_label_matrices(epitope_seqs, receptor_seqs).to(self.device)
147 | label_matrix = construct_label_matrices_ones(epitope_seqs, receptor_seqs, self.ln_cfg.include_mhc).to(self.device)
148 |
149 | # construct weight matrices for the epitope sequences:
150 | if self.ln_cfg.weigh_epitope_count and not self.ln_cfg.unique_epitopes:
151 | weights = torch.tensor([self.epitope_weights[seq] for seq in epitope_seqs]).to(self.device)
152 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07)
153 | clip_loss = clip_loss * weights
154 | clip_loss = clip_loss.sum()
155 | else:
156 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07)
157 | clip_loss = clip_loss.mean()
158 |
159 | loss = clip_loss * (1 - self.mse_weight) + mse_loss * self.mse_weight
160 | val_metrics = {
161 | 'loss': loss,
162 | }
163 | self.val_step_metrics.setdefault('loss', []).append(loss.detach().item())
164 | if self.ln_cfg.mse_weight > 0:
165 | self.val_step_metrics.setdefault('clip_loss', []).append(clip_loss.detach().item())
166 | self.val_step_metrics.setdefault('mse_loss', []).append(mse_loss.detach().item())
167 |
168 | return val_metrics
169 |
170 | def test_step(self, batch, batch_idx):
171 | """
172 | Test step for the CLIPBody Model
173 | """
174 | epitope_seqs, receptor_seqs = batch
175 | epitope_proj, receptor_proj = self(epitope_seqs, receptor_seqs)
176 |
177 | # save the embeddings batches for evaluation later
178 | if self.ln_cfg.include_mhc:
179 |             epitope_seqs_to_save = epitope_seqs[0] # extract epitope seqs from the array [epitope_seqs, mhca_seqs, mhcb_seqs]
180 | else:
181 | epitope_seqs_to_save = epitope_seqs
182 | self.epitope_sequences.append(epitope_seqs_to_save)
183 | self.receptor_sequences.append(receptor_seqs)
184 | self.epitope_embeddings.append(epitope_proj)
185 | self.receptor_embeddings.append(receptor_proj)
186 |
187 | # label_matrix = construct_label_matrices(epitope_seqs, receptor_seqs).to(self.device)
188 | label_matrix = construct_label_matrices_ones(epitope_seqs, receptor_seqs, self.ln_cfg.include_mhc).to(self.device)
189 |
190 | clip_loss, mse_loss = self.clip_loss_multiclass(epitope_proj, receptor_proj, label_matrix, temperature=0.07)
191 |
192 | clip_loss = clip_loss.mean()
193 |
194 | loss = clip_loss * (1 - self.mse_weight) + mse_loss * self.mse_weight
195 | test_metrics = {
196 | 'loss': loss,
197 | }
198 | self.test_step_metrics.setdefault('loss', []).append(loss.detach().item())
199 | if self.ln_cfg.mse_weight > 0:
200 | self.test_step_metrics.setdefault('clip_loss', []).append(clip_loss.detach().item())
201 | self.test_step_metrics.setdefault('mse_loss', []).append(mse_loss.detach().item())
202 |
203 | return test_metrics
204 |
205 | def on_fit_start(self):
206 | # compute the weights for each epitope sequence
207 | if self.ln_cfg.weigh_epitope_count:
208 | print("Weighing the Epitopes by inverse sqrt of their counts!")
209 | self.epitope_weights = compute_weights(self.trainer.datamodule.train_dataloader().dataset.data['epitope'].tolist())
210 |
211 | def on_train_epoch_end(self):
212 | pass
213 |
214 | def on_validation_epoch_end(self):
215 | for metric, values in self.training_step_metrics.items():
216 | avg_metric = self.aggregate_metric(values)
217 | self.log(f'train_{metric}', avg_metric, prog_bar=False, sync_dist=True)
218 | print(f'Epoch train end: {metric}/train', avg_metric)
219 | self.training_step_metrics.clear()
220 |
221 | for metric, values in self.val_step_metrics.items():
222 | avg_metric = self.aggregate_metric(values)
223 | self.log(f'val_{metric}', avg_metric, prog_bar=False, sync_dist=True)
224 | print(f'Epoch validation end: {metric}/val', avg_metric)
225 | self.val_step_metrics.clear()
226 |
227 | def on_test_epoch_end(self):
228 | for metric, values in self.test_step_metrics.items():
229 | avg_metric = self.aggregate_metric(values)
230 | # self.log(f'test_{metric}', avg_metric, prog_bar=False, sync_dist=True)
231 | print(f'Epoch test end: {metric}/test', avg_metric)
232 | self.test_step_metrics.clear()
233 |
234 | # save the embeddings as numpy arrays:
235 | if self.ln_cfg.save_embed_path:
236 | if not os.path.isdir(self.ln_cfg.save_embed_path):
237 | os.makedirs(self.ln_cfg.save_embed_path)
238 |
239 | epitope_sequences = np.concatenate(self.epitope_sequences, axis=0)
240 | receptor_sequences = np.concatenate(self.receptor_sequences, axis=1)
241 | epitope_embeddings = torch.cat(self.epitope_embeddings, dim=0).detach().cpu().numpy()
242 | receptor_embeddings = torch.cat(self.receptor_embeddings, dim=0).detach().cpu().numpy()
243 |
244 | # actually save the embeds
245 | print("Saving sequences and embeddings to disk...")
246 | np.save(self.ln_cfg.save_embed_path + '/epitope_seqs.npy', epitope_sequences)
247 | np.save(self.ln_cfg.save_embed_path + '/receptor_seqs.npy', receptor_sequences)
248 | np.save(self.ln_cfg.save_embed_path + '/epitope_embeds.npy', epitope_embeddings)
249 | np.save(self.ln_cfg.save_embed_path + '/receptor_embeds.npy', receptor_embeddings)
250 |
251 |
252 | @staticmethod
253 | def aggregate_metric(step_outputs):
254 | return np.mean(step_outputs)
255 |
256 | def configure_optimizers(self):
257 | if self.ln_cfg.regular_ft:
258 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.ln_cfg.lr, weight_decay=self.ln_cfg.weight_decay)
259 | return {
260 | "optimizer": optimizer,
261 | }
262 |
263 | if self.ln_cfg.lr_scheduler == 'plateau':
264 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.ln_cfg.lr, weight_decay=self.ln_cfg.weight_decay)
265 | scheduler_lr = ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=2, min_lr=1e-6)
266 |
267 | elif self.ln_cfg.lr_scheduler == 'cos_anneal':
268 | optimizer = torch.optim.AdamW(self.parameters(), lr=1e-6, weight_decay=self.ln_cfg.weight_decay)
269 | scheduler_lr = CosineAnnealingWarmUpRestarts(optimizer, T_0=10, T_mult=1, eta_max=self.ln_cfg.lr, T_up=2, gamma=0.7)
270 |
271 | return {
272 | "optimizer": optimizer,
273 | "lr_scheduler": {
274 | "scheduler": scheduler_lr,
275 | "interval": "epoch",
276 | "monitor": "val_loss",
277 | "frequency": 1,
278 | },
279 | }
280 |
281 |
282 | def configure_model(self):
283 |
284 | self.epitope_encoder = EpitopeEncoderESM(self.epitope_input_dim, self.projection_dim, hidden_dim=self.hidden_dim,
285 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
286 |
287 | if self.model_config.receptor_model_name == 'ablang':
288 | self.receptor_encoder = AntibodyEncoderAbLang(self.receptor_input_dim, self.projection_dim, ln_cfg=self.ln_cfg, device=self.device)
289 | elif self.model_config.receptor_model_name == 'ablang2':
290 | self.receptor_encoder = AntibodyEncoderAbLang2(self.receptor_input_dim, self.projection_dim, ln_cfg=self.ln_cfg, device=self.device)
291 | elif self.model_config.receptor_model_name == 'antiberta2':
292 | self.receptor_encoder = AntibodyEncoderAntiberta2(self.receptor_input_dim, self.projection_dim, ln_cfg=self.ln_cfg, device=self.device)
293 | elif self.model_config.receptor_model_name == 'tcrbert':
294 | self.receptor_encoder = TCREncoderTCRBert(self.receptor_input_dim, self.projection_dim,
295 | hidden_dim=self.hidden_dim, ln_cfg=self.ln_cfg, device=self.device)
296 | elif self.model_config.receptor_model_name == 'tcrlang':
297 | self.receptor_encoder = TCREncoderTCRLang(self.receptor_input_dim, self.projection_dim,
298 | hidden_dim=self.hidden_dim, ln_cfg=self.ln_cfg, device=self.device)
299 | elif self.model_config.receptor_model_name in ['esm2', 'esm3']:
300 | if "catcr" in self.ln_cfg.dataset_path:
301 | self.receptor_encoder = TCREncoderESMBetaOnly(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim,
302 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
303 | else:
304 | #TODO: REPLACE THIS LINE!
305 | # self.receptor_encoder = TCREncoderESMBetaOnly(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim,
306 | # ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
307 | self.receptor_encoder = TCREncoderESM(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim,
308 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
309 | elif self.model_config.receptor_model_name == 'inhouse':
310 | self.receptor_encoder = TCREncoderInHouse(self.receptor_input_dim, self.projection_dim, hidden_dim=self.hidden_dim,
311 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
312 | elif self.model_config.receptor_model_name == 'onehot':
313 | self.epitope_encoder = EpitopeEncoderOneHot(self.epitope_input_dim, self.projection_dim,
314 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
315 | self.receptor_encoder = TCREncoderOneHot(self.receptor_input_dim, self.projection_dim,
316 | ln_cfg=self.ln_cfg, model_config=self.model_config, device=self.device)
317 |
318 | else:
319 |             raise NotImplementedError(f"Receptor model '{self.model_config.receptor_model_name}' is not implemented. Please choose one of the supported models.")
320 |
321 |
322 | # for inference later
323 | def put_submodules_to_device(self, device):
324 | self.epitope_encoder.device = device
325 | self.receptor_encoder.device = device
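326 |
327 | # Minimal sketch of the symmetric contrastive objective used in clip_loss_multiclass
328 | # above (illustrative toy tensors only; nothing here is executed by the module):
329 | #
330 | #   B, D = 4, 32
331 | #   ep = F.normalize(torch.randn(B, D), dim=-1)       # normalized epitope projections
332 | #   rc = F.normalize(torch.randn(B, D), dim=-1)       # normalized receptor projections
333 | #   labels = torch.eye(B)                             # 1.0 where epitope i pairs with receptor i
334 | #   logits = ep @ rc.t() / 0.07                       # temperature-scaled cosine similarities
335 | #   bce = nn.BCEWithLogitsLoss(reduction='none')
336 | #   loss = (bce(logits, labels) + bce(logits.t(), labels)) / 2.0   # B x B, unreduced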
--------------------------------------------------------------------------------
/src/data_module.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from torch.utils.data import DataLoader, Dataset, random_split, Sampler
3 | from sklearn.model_selection import train_test_split
4 | import pandas as pd
5 | import torch
6 | import numpy as np
7 | from collections import defaultdict, deque
8 | import random
9 | import os
10 |
11 | from .utils import (load_iedb_data, load_iedb_data_cdr3, load_vdjdb_data_cdr3,
12 | load_vdjdb_data_pmhc, load_pird_data_cdr3, load_mixtcrpred_data,
13 | load_mixtcrpred_data_pmhc, load_catcr_data)
14 |
15 | class EpitopeReceptorDataset(Dataset):
16 | '''Returns receptor paired with epitope-only data'''
17 | def __init__(self, data):
18 | self.data = data
19 |
20 | def __len__(self):
21 | return len(self.data)
22 |
23 | def __getitem__(self, idx):
24 | # print("idx", idx)
25 | epitope_seq = self.data.iloc[idx]['epitope']
26 | heavy_chain_seq = self.data.iloc[idx]['heavy_chain']
27 | light_chain_seq = self.data.iloc[idx]['light_chain']
28 | return epitope_seq, (heavy_chain_seq, light_chain_seq)
29 |
30 | class pMhcReceptorDataset(Dataset):
31 | '''Returns receptor paired with epitope+MHC data'''
32 | def __init__(self, data):
33 | self.data = data
34 |
35 | def __len__(self):
36 | return len(self.data)
37 |
38 | def __getitem__(self, idx):
39 | epitope_seq = self.data.iloc[idx]['epitope']
40 | mhc_a = self.data.iloc[idx]['mhc.a_seq']
41 | mhc_b = self.data.iloc[idx]['mhc.b_seq']
42 | heavy_chain_seq = self.data.iloc[idx]['heavy_chain']
43 | light_chain_seq = self.data.iloc[idx]['light_chain']
44 | return (epitope_seq, mhc_a, mhc_b), (heavy_chain_seq, light_chain_seq)
45 |
46 | class EpitopeReceptorDataModule(pl.LightningDataModule):
47 | def __init__(self, tsv_file, mhc_file=None, batch_size=32, include_mhc=False, ln_cfg = None,
48 | model_config = None, split_ratio=(0.7, 0.15, 0.15), random_seed=7):
49 | super().__init__()
50 | self.tsv_file = tsv_file
51 | self.batch_size = batch_size
52 | self.model_config = model_config
53 | self.ln_cfg = ln_cfg
54 | self.include_mhc = include_mhc
55 | self.mhc_file = mhc_file
56 | if self.include_mhc:
57 | assert self.mhc_file is not None, "Must provide a file with MHC data"
58 | self.split_ratio = split_ratio
59 |
60 | self.random_seed = random_seed
61 |
62 | def prepare_data_must(self):
63 | # Read the TSV file
64 | if 'IEDB' in self.tsv_file:
65 | if self.model_config.receptor_model_name == 'ablang':
66 | self.data = load_iedb_data(self.tsv_file, replace_X=True)
67 | elif self.model_config.receptor_model_name == 'tcrlang':
68 | self.data = load_iedb_data_cdr3(self.tsv_file, replace_hashtag=True)
69 | elif self.model_config.receptor_model_name == 'tcrbert':
70 | self.data = load_iedb_data_cdr3(self.tsv_file)
71 | else:
72 | self.data = load_iedb_data(self.tsv_file)
73 | elif 'vdjdb' in self.tsv_file:
74 | if self.include_mhc:
75 | self.data = load_vdjdb_data_pmhc(self.tsv_file, self.mhc_file)
76 | else:
77 | self.data = load_vdjdb_data_cdr3(self.tsv_file)
78 | elif 'mixtcrpred' in self.tsv_file:
79 | if self.include_mhc:
80 | self.data = load_mixtcrpred_data_pmhc(self.tsv_file, self.mhc_file)
81 | else:
82 | self.data = load_mixtcrpred_data(self.tsv_file)
83 | elif 'pird' in self.tsv_file:
84 | self.data = load_pird_data_cdr3(self.tsv_file)
85 | elif 'catcr' in self.tsv_file:
86 | self.data = load_catcr_data(self.tsv_file)
87 |
88 | # self.train_data, self.test_data = load_catcr_data(self.tsv_file)
89 | # return
90 | else:
91 | raise ValueError(f"Can't process this tsv file: {self.tsv_file}")
92 |
93 | # Ensure the data has the correct columns
94 | assert 'epitope' in self.data.columns
95 | assert 'heavy_chain' in self.data.columns
96 | assert 'light_chain' in self.data.columns
97 |
98 | def split_data_random(self):
99 | if self.ln_cfg.unique_epitopes:
100 | # ------------------------------------------------------
101 | # Splitting data via unique epitopes:
102 |
103 | np.random.seed(self.random_seed)
104 |
105 | # Get unique values in epitope column
106 | unique_epitopes = self.data['epitope'].unique()
107 |
108 | # Shuffle the unique values:
109 | np.random.shuffle(unique_epitopes)
110 |
111 | # Split the unique values into train, dev, and test sets
112 | train_size = int(self.split_ratio[0] * len(unique_epitopes))
113 | dev_size = int(self.split_ratio[1] * len(unique_epitopes))
114 | test_size = len(unique_epitopes) - train_size - dev_size
115 |
116 | train_values = unique_epitopes[:train_size]#[:100]
117 | dev_values = unique_epitopes[train_size:train_size + dev_size]
118 | test_values = unique_epitopes[train_size + dev_size:]
119 |
120 | # Create train, dev, and test dataframes
121 | # making sure that each set has a unique set of epitopes
122 | self.train_data = self.data[self.data['epitope'].isin(train_values)]
123 | self.dev_data = self.data[self.data['epitope'].isin(dev_values)]
124 | self.test_data = self.data[self.data['epitope'].isin(test_values)]
125 |
126 | elif self.ln_cfg.fewshot_ratio:
127 | self.train_data, self.dev_data = split_df_by_ratio(self.data, 0.85, random_seed=self.random_seed)
128 | if self.ln_cfg.fewshot_ratio < 1:
129 | self.train_data, _ = split_df_by_ratio(self.train_data, self.ln_cfg.fewshot_ratio, random_seed=self.random_seed)
130 | self.test_data = self.dev_data.copy()
131 |
132 | else:
133 | # ------------------------------------------------------
134 | # Split the data into train, dev, and test sets
135 | total_size = len(self.data)
136 | train_size = int(self.split_ratio[0] * total_size)
137 | dev_size = int(self.split_ratio[1] * total_size)
138 | test_size = total_size - train_size - dev_size
139 |
140 | self.train_data, self.temp = train_test_split(self.data, test_size=0.3, random_state=self.random_seed)
141 | self.dev_data, self.test_data = train_test_split(self.temp, test_size=0.5, random_state=self.random_seed)
142 |
143 | # # oversample here:
144 | # if self.ln_cfg.oversample:
145 | # self.train_data = upsample_epitopes(self.train_data, 'epitope')
146 |
147 | # Reset the index of the dataframes
148 | self.train_data = self.train_data.reset_index(drop=True)
149 | self.dev_data = self.dev_data.reset_index(drop=True)
150 | self.test_data = self.test_data.reset_index(drop=True)
151 |
152 |
153 | def setup(self, stage=None):
154 | self.prepare_data_must()
155 |
156 | if "catcr" in self.tsv_file:
157 | # # copy the test data into dev:
158 | # self.dev_data = self.test_data.copy()
159 | self.split_data_random()
160 | else:
161 | self.split_data_random()
162 |
163 | if self.ln_cfg.save_embed_path:
164 | self.save_datasplit(self.ln_cfg.save_embed_path)
165 |
166 | def train_dataloader(self):
167 | if self.include_mhc:
168 | train_dataset = pMhcReceptorDataset(self.train_data)
169 | else:
170 | train_dataset = EpitopeReceptorDataset(self.train_data)
171 |
172 | if self.ln_cfg.oversample:
173 | train_sampler = OversampleSampler(self.train_data, self.batch_size)
174 | return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=False, sampler=train_sampler,
175 | num_workers=4, persistent_workers=True)
176 | else:
177 | return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True,
178 | num_workers=4, pin_memory=True, persistent_workers=True)
179 |
180 | def val_dataloader(self):
181 | if self.include_mhc:
182 | dev_dataset = pMhcReceptorDataset(self.dev_data)
183 | else:
184 | dev_dataset = EpitopeReceptorDataset(self.dev_data)
185 | # return DataLoader(dev_dataset, batch_size=self.batch_size, shuffle=False,
186 | # num_workers=4, pin_memory=True, persistent_workers=True)
187 | return DataLoader(dev_dataset, batch_size=self.batch_size, shuffle=False,
188 | num_workers=4, persistent_workers=True)
189 |
190 |
191 | def test_dataloader(self):
192 | if self.include_mhc:
193 | test_dataset = pMhcReceptorDataset(self.test_data)
194 | else:
195 | test_dataset = EpitopeReceptorDataset(self.test_data)
196 | return DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False,
197 | num_workers=4, pin_memory=True, persistent_workers=True)
198 |
199 |
200 | def save_datasplit(self, savepath):
201 | '''
202 | Save the pandas dataframes of train/dev/test splits to a specified path
203 | '''
204 |
205 | if not os.path.isdir(savepath):
206 | os.makedirs(savepath)
207 |
208 | train_path = os.path.join(savepath, 'train.tsv')
209 | dev_path = os.path.join(savepath, 'dev.tsv')
210 | test_path = os.path.join(savepath, 'test.tsv')
211 |
212 | self.train_data.to_csv(train_path, sep='\t', index=False)
213 | self.dev_data.to_csv(dev_path, sep='\t', index=False)
214 | self.test_data.to_csv(test_path, sep='\t', index=False)
215 |
216 |
217 | # class OversampleSampler(Sampler):
218 | # def __init__(self, df, batch_size):
219 | # self.df = df
220 | # self.batch_size = batch_size
221 | # self.indices = self.generate_indices()
222 |
223 | # def generate_indices(self):
224 | # epitope_counts = self.df['epitope'].value_counts()
225 | # max_count = epitope_counts.max()
226 |
227 | # # Group indices by epitope and shuffle them
228 | # epitope_index_dict = {
229 | # epitope: np.random.permutation(self.df[self.df['epitope'] == epitope].index.tolist() * int( max_count // count )).tolist()
230 | # for epitope, count in epitope_counts.items()
231 | # }
232 |
233 | # # Generate batches with as distinct epitopes as possible
234 | # batched_indices = []
235 | # while any(epitope_index_dict.values()):
236 | # batch = []
237 | # available_epitopes = [epitope for epitope, indices in epitope_index_dict.items() if indices]
238 | # np.random.shuffle(available_epitopes)
239 | # selected_epitopes = available_epitopes[:self.batch_size]
240 |
241 | # for epitope in selected_epitopes:
242 | # if epitope_index_dict[epitope]:
243 | # batch.append(epitope_index_dict[epitope].pop())
244 |
245 | # # Fill the remaining batch size with other available indices if needed
246 | # if len(batch) < self.batch_size:
247 | # remaining_epitopes = [epitope for epitope, indices in epitope_index_dict.items() if indices]
248 | # np.random.shuffle(remaining_epitopes)
249 | # for epitope in remaining_epitopes:
250 | # if len(batch) >= self.batch_size:
251 | # break
252 | # if epitope_index_dict[epitope]:
253 | # batch.append(epitope_index_dict[epitope].pop())
254 |
255 | # np.random.shuffle(batch)
256 | # batched_indices.append(batch)
257 |
258 | # # shuffle the order of minibatches as well
259 | # np.random.shuffle(batched_indices)
260 | # batched_indices = sum(batched_indices, [])
261 |
262 | # return np.array(batched_indices)
263 |
264 | # def __iter__(self):
265 | # return iter(self.indices)
266 |
267 | # def __len__(self):
268 | # return len(self.indices)
269 |
270 |
271 | class OversampleSampler(Sampler):
272 | def __init__(self, df, batch_size):
273 | self.df = df
274 | self.indices = self.generate_indices()
275 |
276 | # print("DF size: ", self.df.shape)
277 | # print("Oversampled indices:", self.indices[:1000])
278 |
279 | def generate_indices(self):
280 | epitope_counts = self.df['epitope'].value_counts()
281 | max_count = epitope_counts.max()
282 |
283 | oversample_indices = []
284 | for epitope, count in epitope_counts.items():
285 | epitope_indices = self.df[self.df['epitope'] == epitope].index.tolist()
286 | oversample_ratio = max_count // count # int( np.sqrt(max_count // count) )
287 | oversample_indices.extend(epitope_indices * oversample_ratio)
288 |
289 | # return oversample_indices
290 |
291 | # Shuffle the oversampled indices to ensure randomness
292 | np.random.shuffle(oversample_indices)
293 | return np.array(oversample_indices)
294 |
295 | def __iter__(self):
296 | # # Shuffle the oversampled indices to ensure randomness
297 | # np.random.shuffle(self.indices)
298 | # return iter(np.array(self.indices))
299 | return iter(self.indices)
300 |
301 | def __len__(self):
302 | return len(self.indices)
303 |
304 |
305 | class UniqueValueSampler(Sampler):
306 | def __init__(self, dataframe, batch_size, seed=42):
307 | self.dataframe = dataframe
308 | self.batch_size = batch_size
309 | self.unique_values = list(dataframe['epitope'].unique())
310 | self.original_indices_by_value = defaultdict(list)
311 | self.indices_by_value = defaultdict(list)
312 |
313 | for idx, value in enumerate(dataframe['epitope']):
314 | self.original_indices_by_value[value].append(idx)
315 |
316 | self.seed = seed
317 | self.reset_indices()
318 |
319 | def reset_indices(self):
320 | # Reset the indices for each unique value from the original indices
321 | self.indices_by_value = {value: indices[:] for value, indices in self.original_indices_by_value.items()}
322 |
323 | def __iter__(self):
324 | random.seed(self.seed)
325 | shuffled_unique_values = self.unique_values[:]
326 | random.shuffle(shuffled_unique_values)
327 |
328 | batches = []
329 | current_batch = []
330 | used_values = set()
331 |
332 | for value in shuffled_unique_values:
333 | if value not in used_values and self.indices_by_value[value]:
334 | index = self.indices_by_value[value].pop(0)
335 | current_batch.append(index)
336 | used_values.add(value)
337 |
338 | if len(current_batch) == self.batch_size:
339 | batches.append(current_batch)
340 | current_batch = []
341 | used_values.clear()
342 |
343 | # Add the last batch if it contains any items
344 | if current_batch:
345 | batches.append(current_batch)
346 |
347 | # Ensure we cover all indices, even if they don't form a full batch
348 | remaining_indices = [idx for value in shuffled_unique_values for idx in self.indices_by_value[value]]
349 | for i in range(0, len(remaining_indices), self.batch_size):
350 | batches.append(remaining_indices[i:i+self.batch_size])
351 |
352 | # Flatten the list of batches to a list of indices
353 | flattened_batches = [idx for batch in batches for idx in batch]
354 |
355 | self.reset_indices() # Reset indices for the next epoch
356 | return iter(flattened_batches)
357 |
358 | def __len__(self):
359 | return len(self.dataframe)
360 |
361 |
362 | def compute_weights(epitope_seqs):
363 | '''
364 |     Given a list of (possibly repeated) epitope sequences, count how many times each
365 |     unique epitope appears, compute the inverse square-rooted count weight for each
366 |     epitope, and return the weights as a dictionary keyed by epitope sequence.
367 | '''
368 | epitope_weights = {}
369 | for seq in epitope_seqs:
370 | if seq in epitope_weights:
371 | epitope_weights[seq] += 1
372 | else:
373 | epitope_weights[seq] = 1
374 |
375 | # compute the inverse square-rooted count weights
376 | for seq in epitope_weights:
377 | epitope_weights[seq] = np.sqrt(1 / epitope_weights[seq])
378 |
379 | return epitope_weights
380 |
381 |
382 | def upsample_epitopes(df: pd.DataFrame, epitope_column: str) -> pd.DataFrame:
383 | """
384 | Upsample the DataFrame so that each unique epitope is repeated by the ratio
385 | of max count to its current count.
386 |
387 | Parameters:
388 | - df: pandas DataFrame containing the data.
389 | - epitope_column: name of the column containing the epitope identifiers.
390 |
391 | Returns:
392 | - upsampled_df: pandas DataFrame with epitopes upsampled by the calculated ratio.
393 | """
394 | # Step 1: Count the number of entries for each epitope
395 | epitope_counts = df[epitope_column].value_counts()
396 |
397 | # Step 2: Find the maximum count
398 | max_count = epitope_counts.max()
399 |
400 | # Step 3: Function to upsample each group by the ratio
401 | def upsample(group):
402 | # Calculate the number of repetitions needed for each epitope group
403 | num_repeats = max_count // len(group)
404 | # Repeat each group by the calculated number of repeats
405 | return group.loc[group.index.repeat(num_repeats)]
406 |
407 | # Step 4: Apply the upsample function to each epitope group
408 | upsampled_df = df.groupby(epitope_column, group_keys=False).apply(upsample)
409 |
410 | return upsampled_df
411 |
412 | def split_df_by_ratio(df, r, random_seed=14):
413 | # Create two empty lists to hold the dataframes for the two sets
414 | df_1_list = []
415 | df_2_list = []
416 |
417 | # Group the dataframe by the epitope
418 | grouped = df.groupby('epitope')
419 |
420 | # Iterate through each group
421 | for epitope, group in grouped:
422 | # Shuffle the group
423 | shuffled_group = group.sample(frac=1, random_state=random_seed)
424 |
425 | # Determine the split index
426 | split_idx = int(len(shuffled_group) * r)
427 |
428 | # Split the group into two parts based on the ratio
429 | df_1_list.append(shuffled_group.iloc[:split_idx])
430 | df_2_list.append(shuffled_group.iloc[split_idx:])
431 |
432 | # Concatenate all the individual dataframes to create the final dataframes
433 | df_1 = pd.concat(df_1_list).reset_index(drop=True)
434 | df_2 = pd.concat(df_2_list).reset_index(drop=True)
435 |
436 | return df_1, df_2
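437 |
438 | # Illustrative sketch of the helpers above (toy data only; not part of the module):
439 | #
440 | #   seqs = ['GILGFVFTL', 'GILGFVFTL', 'GILGFVFTL', 'NLVPMVATV']
441 | #   compute_weights(seqs)
442 | #   # -> {'GILGFVFTL': 1/sqrt(3) ~= 0.577, 'NLVPMVATV': 1.0}
443 | #
444 | #   df = pd.DataFrame({'epitope': seqs, 'heavy_chain': ['A'] * 4, 'light_chain': ['B'] * 4})
445 | #   sampler = OversampleSampler(df, batch_size=2)   # each epitope repeated max_count // count times
446 | #   list(sampler)                                    # shuffled row indices, with index 3 appearing 3x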
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import re
5 | import json
6 | import functools
7 | import os
8 |
9 | # adapted from the HuggingFace repo for AbLang
10 | def get_sequence_embeddings(encoded_input, model_output, is_sep=True, is_cls=True, epitope_mask=None):
11 | if isinstance(model_output, dict):
12 | output_last_h_state = model_output['last_hidden_state']
13 | else:
14 | output_last_h_state = model_output.last_hidden_state
15 |
16 | mask = encoded_input['attention_mask'].float()
17 | if is_sep:
18 |         d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # map each row to its last non-padding position (the [SEP] token)
19 | # make sep token invisible
20 | for i in d:
21 | mask[i, d[i]] = 0
22 | if is_cls:
23 | mask[:, 0] = 0.0 # make cls token invisible
24 | if epitope_mask is not None:
25 | mask = mask * epitope_mask # make non-epitope regions invisible
26 | mask = mask.unsqueeze(-1).expand(output_last_h_state.size())
27 | sum_embeddings = torch.sum(output_last_h_state * mask, 1)
28 | sum_mask = torch.clamp(mask.sum(1), min=1e-9)
29 | return sum_embeddings / sum_mask
30 |
31 | def get_attention_mask(encoded_input, is_sep=True, is_cls=True):
32 | mask = encoded_input['attention_mask'].float()
33 | if is_sep:
34 |         d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # map each row to its last non-padding position (the [SEP] token)
35 | # make sep token invisible
36 | for i in d:
37 | mask[i, d[i]] = 0
38 | if is_cls:
39 | mask[:, 0] = 0.0 # make cls token invisible
40 | return mask
41 |
42 | # load CATCR dataset (contains epitope/CDR3-B data from VDJdb, IEDB, McPAS-TCR):
43 | def load_catcr_data(tsv_path):
44 | df_catcr = pd.read_csv(tsv_path, delimiter=',')
45 |
46 | # rename the TCR column to light_chain:
47 | df_catcr = df_catcr.rename(columns={'CDR3_B':'light_chain', 'EPITOPE':'epitope'})
48 |
49 | # no alpha chain, so create an empty heavy chain column:
50 | df_catcr['heavy_chain'] = ""
51 |
52 | # reset index
53 | df_catcr = df_catcr.reset_index(drop=True)
54 |
55 | return df_catcr
56 |
57 | # load CATCR dataset (contains epitope/CDR3-B data from VDJdb, IEDB, McPAS-TCR):
58 | def load_catcr_data_presplit(tsv_path):
59 | df_catcr_train = pd.read_csv(os.path.join(tsv_path, "train.csv"), delimiter=',')
60 | df_catcr_test = pd.read_csv(os.path.join(tsv_path, "test.csv"), delimiter=',')
61 |
62 | # rename the TCR column to light_chain:
63 | df_catcr_train = df_catcr_train.rename(columns={'CDR3_B':'light_chain',
64 | 'EPITOPE':'epitope'})
65 | df_catcr_test = df_catcr_test.rename(columns={'CDR3_B':'light_chain',
66 | 'EPITOPE':'epitope'})
67 |
68 | # no alpha chain, so create an empty heavy chain column:
69 | df_catcr_train['heavy_chain'] = ""
70 | df_catcr_test['heavy_chain'] = ""
71 |
72 | # reset index
73 | df_catcr_train = df_catcr_train.reset_index(drop=True)
74 | df_catcr_test = df_catcr_test.reset_index(drop=True)
75 |
76 | return df_catcr_train, df_catcr_test
77 |
78 | # load MixTCRPred dataset (contains data from VDJdb, IEDB, McPAS, and 10x Genomics)
79 | def load_mixtcrpred_data(tsv_path):
80 | df_mix = pd.read_csv(tsv_path, delimiter=',')
81 |
82 | # drop rows where Epitope name is not a peptide sequence:
83 | df_mix = df_mix[df_mix['epitope'].str.isupper()]
84 | df_mix = df_mix[df_mix['epitope'].apply(is_alpha_only)]
85 |
86 | # drop rows whose TCR sequences are missing:
87 | df_mix = df_mix.dropna(subset=['cdr3_TRA', 'cdr3_TRB'])
88 |
89 | # rename the TCR columns to heavy_chain and light_chain:
90 | df_mix = df_mix.rename(columns={'cdr3_TRA':'heavy_chain', 'cdr3_TRB':'light_chain'})
91 |
92 | # drop duplicates of epitope-TRA-TRB triplets:
93 | df_mix = df_mix.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) # drops ~2700 entries
94 |
95 | # only restrict it to human data:
96 | df_mix = df_mix.loc[df_mix["species"] == "HomoSapiens"]
97 |
98 | # reset index
99 | df_mix = df_mix.reset_index(drop=True)
100 |
101 | return df_mix
102 |
103 | def load_mixtcrpred_data_pmhc(tsv_path, mhc_map_path):
104 | df_mhc_map = pd.read_csv(mhc_map_path, delimiter='\t', index_col=0) # format where index is two-part name like "A*01:01" and the column label is "max_sequence" containing the sequence with signal peptide removed
105 | # raise NotImplementedError("mhc_map needs to be updated to include all the mouse MHCs if we are including mouse data")
106 |
107 | df_mix = load_mixtcrpred_data(tsv_path)
108 |
109 | def clean_mhc_allele(allele):
110 | ''' mixtcrpred will sometimes have MHC names formatted like "HLA-DRB1:01".
111 |             In these cases, the first colon must be replaced with an asterisk (e.g. "HLA-DRB1*01")'''
112 | if allele != "B2M" and "-" not in allele: # from splitting HLA-__A/__B entries, some HLAs have lost their prefix. Add it back. H2-* alleles are fine
113 | allele = "HLA-" + allele
114 |
115 | pattern = re.compile(r'(HLA-[A-Za-z0-9]+):([:0-9]+)')
116 | return pattern.sub(r'\1*\2', allele)
117 |
118 | # Create new columns (will change their values below)
119 | df_mix["mhc.a"] = df_mix["MHC"]
120 | df_mix["mhc.b"] = df_mix["MHC"]
121 |
122 | # make sure mouse names are consistent with map_mhc_allele's canonical alleles
123 | for i, row in df_mix.iterrows():
124 | if row.MHC_class == "MHCI":
125 | df_mix.at[i, "mhc.a"] = row.MHC
126 | df_mix.at[i, "mhc.b"] = "B2M"
127 | else: # MHCII
128 | # Cases that need to be handled
129 | # Counter({'H2-IAb': 3674,
130 | # 'HLA-DPB1*04:01': 388,
131 | # 'HLA-DRB1:01': 71,
132 | # 'HLA-DQA1:02/DQB1*06:02': 46,
133 | # 'HLA-DRB1*04:05': 31,
134 | # 'H2-IEk': 25,
135 | # 'H2-Kb': 24,
136 | # 'HLA-DRA:01/DRB1:01': 21,
137 | # 'HLA-DRB1*07:01': 20,
138 | # 'HLA-DRB1*15:01': 15,
139 | # 'HLA-DQA1*05:01/DQB1*02:01': 14,
140 | # 'HLA-DRB1*04:01': 14,
141 | # 'HLA-DQA': 13,
142 | # 'H-2q': 12,
143 | # 'HLA-DRB1*11:01': 10,
144 | # 'HLA-DQ2': 10,
145 | # 'HLA-DRA:01': 10})
146 | if "/" in row.MHC:
147 | df_mix.at[i, "mhc.a"] = row.MHC.split("/")[0]
148 | df_mix.at[i, "mhc.b"] = row.MHC.split("/")[1]
149 | else:
150 | # inconsistencies in mixtcrpred: there are 24 examples of 'H2-Kb' and 12 examples of 'H-2q'
151 | # which are labeled as MHC II even though they are MHC I alleles. Switch them to MHC I
152 | if row.MHC == "H2-Kb":
153 | df_mix.at[i, "mhc.a"] = "H2-Kb"
154 | df_mix.at[i, "mhc.b"] = "B2M"
155 | elif row.MHC == "H-2q":
156 | # also switch formatting to H2- nomenclature
157 | df_mix.at[i, "mhc.a"] = "H2-Q"
158 | df_mix.at[i, "mhc.b"] = "B2M"
159 | else:
160 | if row.MHC == "H2-IAb": # switch this nomenclature to H2-A
161 | df_mix.at[i, "mhc.a"] = "H2-AA"
162 | df_mix.at[i, "mhc.b"] = "H2-AB"
163 | elif row.MHC == "H2-IEk": # add A and B chains for this allele
164 | df_mix.at[i, "mhc.a"] = "H2-IEkA"
165 | df_mix.at[i, "mhc.b"] = "H2-IEkB"
166 | elif row.MHC == "HLA-DQ2":
167 | df_mix.at[i, "mhc.a"] = "HLA-DQA1"
168 | df_mix.at[i, "mhc.b"] = "HLA-DQB1"
169 | elif row.MHC == "HLA-DQA":
170 | df_mix.at[i, "mhc.a"] = "HLA-DQA"
171 | df_mix.at[i, "mhc.b"] = "HLA-DQB"
172 | elif row.MHC == "HLA-DRA:01":
173 | df_mix.at[i, "mhc.a"] = "HLA-DRA:01"
174 | df_mix.at[i, "mhc.b"] = "HLA-DRB"
175 | else:
176 | # remainder are all 'HLA-D*B[:*]...' alleles
177 | # extract the allele name
178 | pattern = re.compile(r'HLA-([A-Za-z]+)[0-9]*[:*].*')
179 | allele_name = pattern.search(row.MHC).group(1) # e.g. extracts DRB from HLA-DRB1:01 or HLA-DRB1*11:01
180 | mhca_name = allele_name.replace("B", "A")
181 | df_mix.at[i, "mhc.a"] = f"HLA-{mhca_name}"
182 | df_mix.at[i, "mhc.b"] = row.MHC
183 |
184 |
185 | df_mix["mhc.a"] = df_mix["mhc.a"].apply(clean_mhc_allele)
186 | df_mix["mhc.b"] = df_mix["mhc.b"].apply(clean_mhc_allele)
187 |
188 | df_mix["mhc.a_seq"] = df_mix["mhc.a"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map))
189 | df_mix["mhc.b_seq"] = df_mix["mhc.b"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map))
190 |
191 |
192 |
193 | return df_mix
194 |
195 |
196 | # load the IEDB dataset:
197 | def load_iedb_data(tsv_path, replace_X=False, remove_X=False, use_anarci=False):
198 | df_iedb = pd.read_csv(tsv_path, delimiter='\t')
199 |
200 | # drop rows where Epitope name is not a peptide sequence:
201 | df_iedb = df_iedb[df_iedb['Epitope - Name'].str.isupper()]
202 | df_iedb = df_iedb[df_iedb['Epitope - Name'].apply(is_alpha_only)]
203 |
204 | # drop rows whose Ab HL sequences missing:
205 | df_iedb = df_iedb.dropna(subset=['Chain 1 - Protein Sequence', 'Chain 2 - Protein Sequence'])
206 |
207 | # drop rows whose CDRs are missing:
208 | cdr_columns = ['Chain 1 - CDR3 Calculated', 'Chain 1 - CDR2 Calculated', 'Chain 1 - CDR1 Calculated',
209 | 'Chain 2 - CDR3 Calculated', 'Chain 2 - CDR2 Calculated', 'Chain 2 - CDR1 Calculated']
210 | df_iedb = df_iedb.dropna(subset=cdr_columns)
211 |
212 | if use_anarci:
213 | from anarci import run_anarci
214 | # run ANARCI to get the Fv region of the sequences:
215 | print("running anarci on sequences...")
216 | for col_id in ['Chain 1 - Protein Sequence', 'Chain 2 - Protein Sequence']:
217 | seqs = df_iedb[col_id].str.upper()
218 | seqs_ = [(str(i), s) for i, s in enumerate(seqs)]
219 | anarci_results = run_anarci(seqs_)
220 | start_end_pairs = [(anarci_results[2][i][0]['query_start'], anarci_results[2][i][0]['query_end']) for i in range(len(seqs_))]
221 | seqs = [seq[a:b] for seq, (a,b) in zip(seqs, start_end_pairs)]
222 | df_iedb[col_id] = seqs
223 |
224 | df_iedb = df_iedb.reset_index(drop=True)
225 | # FOR FUTURE USERS: IF PERFORMING BCR CALCULATIONS, SAVE df_iedb TO A CSV FILE
226 | # e.g. df_iedb.to_csv("path/to/iedb_data_with_anarci.csv", index=False)
227 | print("done running anarci!")
228 |
229 | # change column names:
230 | df_iedb = df_iedb.rename(columns={'Epitope - Name': 'epitope',
231 | 'Chain 1 - Protein Sequence': 'heavy_chain',
232 | 'Chain 2 - Protein Sequence': 'light_chain',
233 | 'Chain 1 - CDR3 Calculated': 'heavy_chain_cdr3',
234 | 'Chain 1 - CDR2 Calculated': 'heavy_chain_cdr2',
235 | 'Chain 1 - CDR1 Calculated': 'heavy_chain_cdr1',
236 | 'Chain 2 - CDR3 Calculated': 'light_chain_cdr3',
237 | 'Chain 2 - CDR2 Calculated': 'light_chain_cdr2',
238 | 'Chain 2 - CDR1 Calculated': 'light_chain_cdr1'})
239 |
240 | if replace_X:
241 | # replace X's with [MASK] in the sequences for AbLang:
242 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.replace('X', '[MASK]')
243 | df_iedb['light_chain'] = df_iedb['light_chain'].str.replace('X', '[MASK]')
244 |
245 | if remove_X:
246 | # remove rows with 'X' in the sequences:
247 | df_iedb = df_iedb[~df_iedb['heavy_chain'].str.contains('X')]
248 | df_iedb = df_iedb[~df_iedb['light_chain'].str.contains('X')]
249 |
250 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair):
251 | df_iedb = df_iedb[df_iedb['heavy_chain'] != df_iedb['light_chain']]
252 |
253 | # reset index
254 | df_iedb = df_iedb.reset_index(drop=True)
255 |
256 | return df_iedb
257 |
258 | def load_iedb_data_cdr3(tsv_path, replace_hashtag=False):
259 | df_iedb = pd.read_csv(tsv_path, delimiter='\t')
260 |
261 | # drop rows where Epitope name is not a peptide sequence:
262 | df_iedb = df_iedb[df_iedb['Epitope - Name'].str.isupper()]
263 | df_iedb = df_iedb[df_iedb['Epitope - Name'].apply(is_alpha_only)]
264 |
265 | # drop rows whose Ab HL sequences missing:
266 | df_iedb = df_iedb.dropna(subset=['Chain 1 - CDR3 Curated', 'Chain 2 - CDR3 Curated'])
267 |
268 | # change column names:
269 | df_iedb = df_iedb.rename(columns={'Epitope - Name': 'epitope',
270 | 'Chain 1 - CDR3 Curated': 'heavy_chain',
271 | 'Chain 2 - CDR3 Curated': 'light_chain',})
272 |
273 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair):
274 | df_iedb = df_iedb[df_iedb['heavy_chain'] != df_iedb['light_chain']]
275 |
276 | # drop duplicates of epitope-TRA-TRB triplets:
277 | df_iedb = df_iedb.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain'])
278 |
279 | if replace_hashtag:
280 | # replace #'s with X in the sequences for TCRLang:
281 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.replace('#', 'X')
282 | df_iedb['light_chain'] = df_iedb['light_chain'].str.replace('#', 'X')
283 |
284 | # make the AA's upper case in the alpha and beta chains:
285 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.upper()
286 | df_iedb['light_chain'] = df_iedb['light_chain'].str.upper()
287 |
288 | # strip the peptides with whitespace:
289 | df_iedb['heavy_chain'] = df_iedb['heavy_chain'].str.strip()
290 | df_iedb['light_chain'] = df_iedb['light_chain'].str.strip()
291 |
292 | # reset index
293 | df_iedb = df_iedb.reset_index(drop=True)
294 |
295 | return df_iedb
296 |
297 |
298 | def load_vdjdb_data_cdr3(tsv_path):
299 | print("path name:", tsv_path)
300 | df_vdj = pd.read_csv(tsv_path, delimiter='\t')
301 |
302 |     # Get the subset of columns we are interested in
303 | df_vdj = df_vdj[["cdr3.alpha", "cdr3.beta", "species", "mhc.a", "mhc.b", "mhc.class", "antigen.epitope", "cdr3fix.alpha", "cdr3fix.beta"]]
304 |
305 | # subset to only paired data (both alpha and beta chain CDR3s are known)
306 | df_vdj = df_vdj.dropna(subset=["cdr3.alpha", "cdr3.beta"])
307 |
308 | # Extract the fixed CDR3 sequences and use those as the ground truth CDR3 sequence
309 | # https://github.com/antigenomics/vdjdb-db?tab=readme-ov-file#cdr3-sequence-fixing
310 | # There is always a fixed value for every existing CDR3, but sometimes the "fixed" value is the same as the empirical one
311 | df_vdj["heavy_chain"] = df_vdj["cdr3fix.alpha"].apply(json.loads).apply(vdjdb_extract_fixed_cdr3)
312 | df_vdj["light_chain"] = df_vdj["cdr3fix.beta"].apply(json.loads).apply(vdjdb_extract_fixed_cdr3)
313 | df_vdj = df_vdj.rename(columns={'antigen.epitope':'epitope'})
314 |
315 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair):
316 | df_vdj = df_vdj[df_vdj["heavy_chain"] != df_vdj["light_chain"]] # only removes 1 entry
317 |
318 | # drop duplicates of epitope-TRA-TRB triplets:
319 | df_vdj = df_vdj.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) # drops ~2700 entries
320 |
321 | # make the AA's upper case in the alpha and beta chains (shouldn't be necessary, but just to be safe)
322 | df_vdj['heavy_chain'] = df_vdj['heavy_chain'].str.upper()
323 | df_vdj['light_chain'] = df_vdj['light_chain'].str.upper()
324 |
325 | df_vdj = df_vdj.loc[df_vdj["species"] == "HomoSapiens"] # Before this point, species counts are {'HomoSapiens': 29556, 'MusMusculus': 2264}
326 |
327 | # reset index
328 | df_vdj = df_vdj.reset_index(drop=True)
329 |
330 | return df_vdj
331 |
332 | def load_vdjdb_data_pmhc(tsv_path, mhc_map_path):
333 | df_mhc_map = pd.read_csv(mhc_map_path, delimiter='\t', index_col=0) # format where index is two-part name like "A*01:01" and the column label is "max_sequence" containing the sequence with signal peptide removed
334 |
335 | df_vdj = load_vdjdb_data_cdr3(tsv_path)
336 |
337 | # Map the MHC allele names to their sequences
338 | df_vdj["mhc.a_seq"] = df_vdj["mhc.a"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map))
339 | df_vdj["mhc.b_seq"] = df_vdj["mhc.b"].apply(functools.partial(map_mhc_allele, df_mhc_map=df_mhc_map))
340 |
341 | return df_vdj
342 |
343 |
344 |
345 | # Helper function to extract fixed CDR3 from the VDJdb "cdr3fix.[alpha/beta]" column
346 | def vdjdb_extract_fixed_cdr3(obj):
347 | return obj["cdr3"]
348 |
349 | # Helper function to map a given MHC allele name to its sequence
350 | def map_mhc_allele(allele_name, df_mhc_map):
351 | '''
352 | allele_name should be in the original VDJ format, such as "HLA-A*03:01"
353 |
354 | For allele_names that only specify type and no subtype (e.g. "HLA-B*08"), will map to subtype 01 (e.g. "HLA-B*08:01")
355 |
356 | For allele names not found in the MHC map, will map to canonical allele as specified here: https://www.ebi.ac.uk/ipd/imgt/hla/alignment/help/references/
357 | '''
358 |
359 | canonical_alleles = { # only explicitly listed subset found in vdjdb human data
360 | "HLA-A": "HLA-A*01:01",
361 | "HLA-B": "HLA-B*07:02",
362 | "HLA-C": "HLA-C*01:02",
363 | "HLA-E": "HLA-E*01:01",
364 | "HLA-DRA": "HLA-DRA*01:01",
365 | "HLA-DRB": "HLA-DRB1*01:01", # not in imgt/hla, but found in mixtcrpred so mapping to DRB1
366 | "HLA-DRB1": "HLA-DRB1*01:01",
367 | "HLA-DRB3": "HLA-DRB3*01:01",
368 | "HLA-DRB5": "HLA-DRB5*01:01",
369 | "HLA-DQA": "HLA-DQA1*01:01", # not in imgt/hla, but found in vdjdb so mapping to DQA1
370 | "HLA-DQA1": "HLA-DQA1*01:01",
371 | "HLA-DQB": "HLA-DQB1*05:01", # not in imgt/hla but found in mixtcrpred so mapping to DQB1
372 | "HLA-DQB1": "HLA-DQB1*05:01",
373 | "HLA-DPA": "HLA-DPA1*01:03", # not in imgt/hla but found in vdjdb so mappin gto DPA1
374 | "HLA-DPA1": "HLA-DPA1*01:03",
375 | "HLA-DPB": "HLA-DPB1*01:01", # not in imgt/hla but found in vdjdb so mappin gto DPB1
376 | "HLA-DPB1": "HLA-DPB1*01:01",
377 | # "DQ2": "HLA-DQA2*01:01", # found in mixtcrpred, unsure which DQ allele it is referring to, but only DQA has a canonical allele for 2. DQB only has canonical allele for 1
378 | }
379 |
380 | if allele_name == "B2M": # Handles beta chain placeholder for Class I MHCs that do not have a beta chain
381 | return "B2M"
382 |
383 | hla_id = allele_name
384 | if hla_id.startswith("HLA-"):
385 | components = hla_id.split(":")
386 | num_parts = len(components)
387 | if num_parts > 2:
388 | hla_id = ":".join(components[:2])
389 | elif num_parts == 1: # either is of form HLA-A*01 or HLA-A
390 | pattern = re.compile(r'(HLA-[A-Za-z0-9]+)\*([0-9]+)')
391 | if pattern.match(hla_id):
392 | hla_id = ":".join([components[0], "01"])
393 | else:
394 | pass # leave it as is
395 |
396 | if hla_id not in df_mhc_map.index:
397 | hla_gene = hla_id.split("*")[0]
398 | if hla_gene in canonical_alleles:
399 | hla_id = canonical_alleles[hla_gene]
400 | else:
401 | raise Exception(f"Could not find MHC allele {allele_name} in the MHC map and no canonical allele for {hla_gene} specified.")
402 | else: # handles mouse H2 alleles
403 | hla_id = allele_name
404 |
405 |
406 | allele_sequence = df_mhc_map.loc[hla_id, "max_sequence"]
407 | return allele_sequence
408 |
409 |
410 |
411 | def load_pird_data_cdr3(csv_path):
412 | '''
413 | Load the PIRD dataset from a CSV file in latin-1 encoding (default from database download) and return a pandas dataframe.
414 | '''
415 | df_pird = pd.read_csv(csv_path, encoding='latin-1')
416 | df_pird = df_pird.replace('-', np.nan)
417 |
418 | # get subset of columns we are interested in:
419 | df_pird = df_pird[['Antigen.sequence', 'HLA', 'CDR3.alpha.aa', 'CDR3.beta.aa']]
420 |
421 | # subset to only paired data (both alpha and beta chain CDR3s are known)
422 | df_pird = df_pird.dropna(subset=["Antigen.sequence", "CDR3.alpha.aa", "CDR3.beta.aa"])
423 |
424 | # drop rows where Epitope name is not a peptide sequence:
425 | df_pird = df_pird[df_pird['Antigen.sequence'].str.isupper()]
426 | df_pird = df_pird[df_pird['Antigen.sequence'].apply(is_alpha_only)]
427 |
428 | # change column names:
429 | df_pird = df_pird.rename(columns={'Antigen.sequence': 'epitope',
430 | 'CDR3.alpha.aa': 'heavy_chain',
431 | 'CDR3.beta.aa': 'light_chain'})
432 |
433 | # remove all examples where the heavy_chain and light_chain values are the same (likely means invalid pair):
434 | df_pird = df_pird[df_pird["heavy_chain"] != df_pird["light_chain"]] # only removes 1 entry
435 |
436 | # drop duplicates of epitope-TRA-TRB triplets:
437 | df_pird = df_pird.drop_duplicates(subset=['epitope', 'heavy_chain', 'light_chain']) # drops 82 entries
438 |
439 | # make the AA's upper case in the alpha and beta chains (shouldn't be necessary, but just to be safe)
440 | df_pird['heavy_chain'] = df_pird['heavy_chain'].str.upper()
441 | df_pird['light_chain'] = df_pird['light_chain'].str.upper()
442 |
443 | # reset index
444 | df_pird = df_pird.reset_index(drop=True)
445 |
446 | return df_pird
447 |
448 | # Function to check if the string contains only alphabetic characters
449 | def is_alpha_only(s):
450 | return s.isalpha()
451 |
452 | def insert_spaces(sequence):
453 | # Regular expression to match single amino acids or special tokens like '[UNK]'
454 | pattern = re.compile(r'\[.*?\]|.')
455 |
456 | # Find all matches and join them with a space
457 | spaced_sequence = ' '.join(pattern.findall(sequence))
458 |
459 | return spaced_sequence
460 |
461 |
462 | def construct_label_matrices(epitope_seqs, receptor_seqs, include_mhc):
463 | if include_mhc:
464 |         epitope_seqs = epitope_seqs[0] # extract epitope seqs from the array [epitope_seqs, mhca_seqs, mhcb_seqs]
465 | bs = len(epitope_seqs)
466 |
467 | # Create a 2D tensor filled with zeros
468 | label_matrix = torch.zeros((bs, bs), dtype=torch.float32)
469 | # Construct the label matrix
470 | for i, correct_ep in enumerate(epitope_seqs):
471 | count = epitope_seqs.count(correct_ep)
472 | for j, ep in enumerate(epitope_seqs):
473 | if ep == correct_ep:
474 | label_matrix[i, j] = 1.0 / count
475 |
476 | return label_matrix
477 |
478 | def construct_label_matrices_ones(epitope_seqs, receptor_seqs, include_mhc):
479 | if include_mhc:
480 |         epitope_seqs = epitope_seqs[0] # extract epitope seqs from the array [epitope_seqs, mhca_seqs, mhcb_seqs]
481 | bs = len(epitope_seqs)
482 |
483 | # Create a 2D tensor filled with zeros
484 | label_matrix = torch.zeros((bs, bs), dtype=torch.float32)
485 | # Construct the label matrix
486 | for i, correct_ep in enumerate(epitope_seqs):
487 | for j, ep in enumerate(epitope_seqs):
488 | if ep == correct_ep:
489 | label_matrix[i, j] = 1.0
490 |
491 | return label_matrix
492 |
493 | def apply_masking_seq(sequences, mask_token='.', mask_regions=True, p=0.15):
494 | '''
495 | mask_regions: True or List[np.array(dtype=bool)]
496 | - if True, all amino acids will be considered
497 | - if List, only amino acids with True values in the list will be considered (i.e. mask for regions to mask)
498 |
499 | For each sequence string in the sequences list, apply masking by changing the
500 | amino acid with the mask_token with a certain probability.
501 | '''
502 |
503 |     if mask_regions is True: # convert True into an all-True mask array for every sequence
504 | mask_regions = [np.ones(len(seq), dtype=bool) for seq in sequences]
505 |
506 | masked_sequences = []
507 | mask_indices = []
508 | for n, seq in enumerate(sequences):
509 | masked_seq = ''
510 | # seq_mask_indices = []
511 | mask_count = 0
512 | for i, aa in enumerate(seq):
513 | if mask_regions[n][i] and torch.rand(1) < p and mask_count < sum(mask_regions[n]) - 1:
514 | masked_seq += mask_token
515 | mask_indices.append([n, i])
516 | mask_count += 1
517 | else:
518 | masked_seq += aa
519 | masked_sequences.append(masked_seq)
520 |
521 | return masked_sequences, mask_indices
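522 |
523 | # Illustrative sketch of the label-matrix construction above (toy batch; not part of the module):
524 | #
525 | #   eps = ['GILGFVFTL', 'NLVPMVATV', 'GILGFVFTL']
526 | #   construct_label_matrices_ones(eps, receptor_seqs=None, include_mhc=False)
527 | #   # -> tensor([[1., 0., 1.],
528 | #   #            [0., 1., 0.],
529 | #   #            [1., 0., 1.]])
530 | #
531 | #   # construct_label_matrices would instead give each row's positives a weight of 1/count,
532 | #   # so the first row above would become [0.5, 0., 0.5].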
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import re
5 |
6 | from .utils import get_sequence_embeddings, insert_spaces, get_attention_mask, apply_masking_seq
7 | from .swe_pooling import SWE_Pooling
8 |
9 | class EpitopeEncoderESM(nn.Module):
10 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config, hidden_dim=1024, device='cpu'):
11 | super().__init__()
12 |
13 | self.ln_config = ln_cfg
14 | self.model_config = model_config
15 | self.projection_dim = projection_dim
16 |
17 | if self.model_config.receptor_model_name == 'esm3':
18 | from .lora import setup_peft_esm3
19 | from .configs import peft_config_esm3
20 |
21 | # load the LoRA adapted ESM-3 Model here:
22 | self.esm_lora, self.esm_tokenizer = setup_peft_esm3(peft_config_esm3, ln_cfg.no_lora)
23 | else:
24 | from .lora import setup_peft_esm2
25 | from .configs import peft_config_esm2
26 |
27 | # load the LoRA adapted ESM-2 Model here:
28 | self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora, ln_cfg.regular_ft)
29 |
30 |             # For ESM-2, we need a glycine linker to represent multimers as a single sequence
31 | self.linker_size = 25
32 | self.gly_linker = 'G'*self.linker_size
33 |
34 | if self.projection_dim:
35 | if hidden_dim:
36 | print("Using multi-layer projection head")
37 | self.proj_head = nn.Sequential(
38 | nn.Linear(input_dim, hidden_dim),
39 | nn.LayerNorm(hidden_dim),
40 | nn.LeakyReLU(),
41 | nn.Dropout(p=0.3),
42 | nn.Linear(hidden_dim, projection_dim),
43 | )
44 | # Initialize the projection head weights
45 | nn.init.kaiming_uniform_(self.proj_head[0].weight)
46 | nn.init.kaiming_uniform_(self.proj_head[-1].weight)
47 | else:
48 | print("Using single-layer projection head")
49 | self.proj_head = nn.Sequential(
50 | nn.Linear(input_dim, projection_dim),
51 | nn.LayerNorm(projection_dim),
52 | )
53 | # Initialize the projection head weights
54 | nn.init.kaiming_uniform_(self.proj_head[0].weight)
55 | else:
56 | print("NOT using projection head")
57 |
58 | if self.ln_config.swe_pooling:
59 | self.swe_pooling = SWE_Pooling(d_in=input_dim, num_ref_points=512, num_slices=projection_dim)
60 |
61 | self.proj_head = nn.Sequential(
62 | nn.Linear(projection_dim, projection_dim // 2),
63 | nn.LayerNorm(projection_dim // 2),
64 | )
65 |
66 | self.device = device
67 |
68 |
69 | def forward(self, x, mask):
70 | seqs_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob)
71 |
72 | if self.model_config.receptor_model_name == 'esm3':
73 | outputs = self.esm_lora(sequence_tokens=seqs_tokens['input_ids']).embeddings
74 | else:
75 | outputs = self.esm_lora(**seqs_tokens)
76 |
77 | if self.ln_config.swe_pooling:
78 | assert self.ln_config.include_mhc == False, "SWE pooling not supported for MHC sequences yet" # TODO: implement for MHC sequences
79 | # for SWE pooling:
80 | attn_mask = get_attention_mask(seqs_tokens, is_sep=False, is_cls=False)
81 | # attn_mask = get_attention_mask(seqs_tokens)
82 | if isinstance(outputs, dict):
83 | outputs = outputs['last_hidden_state']
84 | elif self.model_config.receptor_model_name == 'esm3':
85 | pass
86 | else:
87 | outputs = outputs.last_hidden_state
88 | seq_embeds = self.swe_pooling(outputs, attn_mask)
89 | else:
90 | # for regular mean pooling
91 | epitope_mask = None
92 | if self.ln_config.include_mhc:
93 | epitope_seqs, mhca_seqs, mhcb_seqs = x
94 | epitope_mask = torch.zeros_like(seqs_tokens['attention_mask'])
95 | for i, (seq, mhcA, mhcB) in enumerate(zip(epitope_seqs, mhca_seqs, mhcb_seqs)):
96 | # assumes no special tokens
97 | epitope_mask[i, len(mhcA) + self.linker_size : len(mhcA) + self.linker_size + len(seq)] = 1
98 |
99 | if self.model_config.receptor_model_name == 'esm3':
100 | outputs = {'last_hidden_state': outputs}
101 | seq_embeds = get_sequence_embeddings(seqs_tokens, outputs, is_sep=False, is_cls=False, epitope_mask=epitope_mask)
102 |
103 | if self.projection_dim:
104 | return self.proj_head(seq_embeds)
105 | else:
106 | return seq_embeds
107 |
108 | def process_seqs(self, inputs, mask, mask_prob=0.15):
109 | '''
110 | input: list of epitope sequences or epitope-mhc array (3,N) where N is the number of samples
111 |
112 |         if self.ln_config.include_mhc is True, the input is expected to be an array of three
113 |         parallel sequence lists: [epitope_seqs, mhc.a_seqs, mhc.b_seqs]
114 |
115 |         if self.ln_config.include_mhc is False, the input is expected to be a list of epitope sequence strings
116 | '''
117 | if self.ln_config.include_mhc:
118 | epitope_seqs, mhca_seqs, mhcb_seqs = inputs
119 |
120 | if self.ln_config.mhc_groove_only:
121 |                 # keep only the A1+A2 domains (roughly AA 0-180) for class I MHCs, and the A1+B1 domains (roughly AA 0-90 each) for class II MHCs
122 | for i, (mhcA, mhcB) in enumerate(zip(mhca_seqs, mhcb_seqs)):
123 | if mhcB == "B2M":
124 | mhca_seqs[i] = mhcA[:180]
125 | else:
126 | mhca_seqs[i] = mhcA[:90]
127 | mhcb_seqs[i] = mhcB[:90]
128 |
129 | # Create the pMHC sequence in the order [mhcA ..G.. epitope ..G.. mhcB]
130 | seqs = [
131 | (
132 | f"{mhcA}{self.gly_linker}{seq}{self.gly_linker}{mhcB}"
133 | if mhcB != "B2M" else
134 | f"{mhcA}{self.gly_linker}{seq}"
135 | )
136 | for seq, mhcA, mhcB in zip(epitope_seqs, mhca_seqs, mhcb_seqs)
137 | ]
138 |
139 | # marking where the Glycine linker starts
140 | # linker between mhcA and epitope
141 | attn_starts = [(i, len(mhcA)) for i, mhcA in enumerate(mhca_seqs)]
142 | # linker between epitope and mhcB
143 | attn_starts.extend([
144 | (i, len(mhcA) + self.linker_size + len(seq))
145 | for i, (seq, mhcA, mhcB) in enumerate(zip(epitope_seqs, mhca_seqs, mhcb_seqs)) if mhcB != "B2M"
146 | ])
147 | else:
148 | seqs = inputs
149 |
150 | # removing special tokens since epitopes are protein fragments (peptides)
151 | seqs_tokens = self.esm_tokenizer(seqs, return_tensors="pt", add_special_tokens=False, padding=True)
152 |
153 | if mask:
154 | if self.ln_config.include_mhc:
155 | mask_regions = [np.zeros(len(seq), dtype=bool) for seq in seqs]
156 | for i, (seq, mhcA, mhcB) in enumerate(zip(epitope_seqs, mhca_seqs, mhcb_seqs)):
157 | # always include epitope sequence for random masking
158 | epitope_offset = len(mhcA) + self.linker_size
159 | mask_regions[i][epitope_offset : epitope_offset + len(seq)] = True
160 |                     # if class I MHC, only apply random masks to A1+A2 domains (roughly AAs 0-180)
161 | if mhcB == "B2M":
162 | mask_regions[i][0 : min(180, len(mhcA))] = True
163 |
164 |                     # if class II MHC, only apply random masks to A1+B1 domains (roughly AAs 0-90 for each)
165 | else:
166 | mask_regions[i][0 : min(90, len(mhcA))] = True
167 | beta_offset = len(mhcA) + self.linker_size + len(seq) + self.linker_size
168 | mask_regions[i][beta_offset : min(beta_offset+90, beta_offset+len(mhcB))] = True
169 | else:
170 | mask_regions = True # seqs is just the epitope, so all tokens can be masked
171 |
172 | # masking the sequences for training
173 | seqs, attn_mask_indices = apply_masking_seq(seqs, mask_token='', mask_regions=mask_regions, p=mask_prob)
174 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long)
175 | if len(indices_tensor) > 0:
176 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0.
177 |
178 | # if necessary, masking the linker region
179 | if self.ln_config.include_mhc:
180 | for i, start in attn_starts:
181 | seqs_tokens['attention_mask'][i, start:start+self.linker_size] = 0.
182 |
183 | return seqs_tokens.to(self.device)
184 |
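# Editor's sketch (hypothetical sequences) of the pMHC string that
# EpitopeEncoderESM.process_seqs assembles before tokenization: the epitope is spliced
# between the MHC chains with 25-residue glycine linkers, and class I molecules
# (beta chain recorded as "B2M") drop the second linker and the beta chain. The linker
# positions are later zeroed out of the attention mask.
_gly = "G" * 25
_epitope, _mhcA, _mhcB = "GILGFVFTL", "A" * 180, "B2M"  # hypothetical class I example
_pmhc = (f"{_mhcA}{_gly}{_epitope}{_gly}{_mhcB}"
         if _mhcB != "B2M" else f"{_mhcA}{_gly}{_epitope}")
assert len(_pmhc) == 180 + 25 + 9
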
185 | class EpitopeEncoderOneHot(nn.Module):
186 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config, device='cpu'):
187 | super().__init__()
188 |
189 | self.ln_config = ln_cfg
190 | self.projection_dim = projection_dim
191 |
192 | if self.projection_dim:
193 | print("Using single-layer projection head")
194 | self.proj_head = nn.Sequential(
195 | nn.Linear(input_dim, projection_dim),
196 | nn.LayerNorm(projection_dim),
197 | )
198 | else:
199 | assert False, "Projection head must be used with one-hot encoding!"
200 |
201 | # Define the amino acid to index mapping
202 | self.amino_acid_to_index = {
203 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
204 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
205 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
206 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,
207 | 'X': 20 # Unknown amino acid
208 | }
209 |
210 | self.device = device
211 |
212 | def forward(self, x, mask):
213 | seqs = x
214 | seqs_onehot = self.create_padded_one_hot_tensor(seqs, len(self.amino_acid_to_index))
215 |
216 | proj_output = self.proj_head(seqs_onehot)
217 |
218 | # average the projected embeddings by seq length:
219 | seq_lens = torch.sum(seqs_onehot, dim=(1, 2))
220 | # Create a mask with shape (batch_size, max_seq_length)
221 | seq_mask = torch.arange(proj_output.size(1)).unsqueeze(0).to(self.device) < seq_lens.unsqueeze(-1)
222 | seq_mask = seq_mask.unsqueeze(2) # Shape (batch_size, max_seq_length, 1)
223 | # Sum the embeddings across the sequence length dimension using the mask
224 | masked_embeddings = proj_output * seq_mask
225 | sum_embeddings = masked_embeddings.sum(dim=1)
226 |
227 | # Divide by the true sequence lengths to get the average
228 |         avg_embeddings = sum_embeddings / seq_lens.unsqueeze(1)
229 |
230 | return avg_embeddings
231 |
232 |
233 | # @staticmethod
234 | def encode_amino_acid_sequence(self, sequence):
235 | """ Convert an amino acid sequence to a list of indices. """
236 | return [self.amino_acid_to_index[aa] for aa in sequence]
237 |
238 | # @staticmethod
239 | def one_hot_encode_sequence(self, sequence, vocab_size):
240 | """ One-hot encode a single sequence. """
241 | encoding = np.zeros((len(sequence), vocab_size), dtype=int)
242 | for idx, char in enumerate(sequence):
243 | encoding[idx, char] = 1
244 | return encoding
245 |
246 | # @staticmethod
247 | def pad_sequences(self, encoded_sequences, max_length):
248 | """ Pad the encoded sequences to the maximum length. """
249 | padded_sequences = []
250 | for seq in encoded_sequences:
251 | padded_seq = np.pad(seq, ((0, max_length - len(seq)), (0, 0)), mode='constant', constant_values=0)
252 | padded_sequences.append(padded_seq)
253 | return np.array(padded_sequences)
254 |
255 | # @staticmethod
256 | def create_padded_one_hot_tensor(self, sequences, vocab_size):
257 | """ Convert a batch of sequences to a padded one-hot encoding tensor. """
258 | # Encode and one-hot encode each sequence
259 | encoded_sequences = [self.one_hot_encode_sequence(self.encode_amino_acid_sequence(seq), vocab_size) for seq in sequences]
260 |
261 | # Determine the maximum sequence length
262 | max_length = max(len(seq) for seq in sequences)
263 |
264 | # Pad the sequences
265 | padded_sequences = self.pad_sequences(encoded_sequences, max_length)
266 |
267 | # Convert to a PyTorch tensor
268 | padded_tensor = torch.tensor(padded_sequences, dtype=torch.float32)
269 |
270 | return padded_tensor.to(self.device)
271 |
272 |
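# Editor's sketch of the one-hot pipeline used by the *OneHot encoders in this module,
# written without the class for clarity (hypothetical sequences; same 21-symbol
# vocabulary as self.amino_acid_to_index).
_aa_to_idx = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWYX")}
_seqs = ["ACD", "GILGF"]
_enc = [np.eye(len(_aa_to_idx), dtype=int)[[_aa_to_idx[a] for a in s]] for s in _seqs]
_max_len = max(len(s) for s in _seqs)
_padded = np.stack([np.pad(e, ((0, _max_len - len(e)), (0, 0))) for e in _enc])
_onehot = torch.tensor(_padded, dtype=torch.float32)
assert _onehot.shape == (2, 5, 21)
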
273 | class TCREncoderTCRBert(nn.Module):
274 | def __init__(self, input_dim, projection_dim, ln_cfg, hidden_dim=1024, device='cpu'):
275 | super().__init__()
276 | from .lora import setup_peft_tcrbert
277 | from .configs import peft_config_tcrbert
278 |
279 | self.tcrbert_tra_lora, self.tcrbert_tra_tokenizer = setup_peft_tcrbert(peft_config_tcrbert, no_lora=ln_cfg.no_lora, regular_ft=ln_cfg.regular_ft)
280 | self.tcrbert_trb_lora, self.tcrbert_trb_tokenizer = setup_peft_tcrbert(peft_config_tcrbert, no_lora=ln_cfg.no_lora, regular_ft=ln_cfg.regular_ft)
281 |
282 | self.ln_config = ln_cfg
283 |
284 | if hidden_dim:
285 | print("Using multi-layer projection head")
286 | self.proj_head = nn.Sequential(
287 | nn.Linear(input_dim, hidden_dim),
288 | nn.LayerNorm(hidden_dim),
289 | nn.LeakyReLU(),
290 | nn.Dropout(p=0.3),
291 | nn.Linear(hidden_dim, projection_dim),
292 | )
293 | # Initialize the projection head weights
294 | nn.init.kaiming_uniform_(self.proj_head[0].weight)
295 | nn.init.kaiming_uniform_(self.proj_head[-1].weight)
296 | else:
297 | print("Using single-layer projection head")
298 | self.proj_head = nn.Sequential(
299 | nn.Linear(input_dim, projection_dim),
300 | nn.LayerNorm(projection_dim),
301 | )
302 | # Initialize the projection head weights
303 | nn.init.kaiming_uniform_(self.proj_head[0].weight)
304 |
305 | if self.ln_config.swe_pooling:
306 | self.swe_pooling_a = SWE_Pooling(d_in=input_dim // 2, num_ref_points=256, num_slices=projection_dim // 2)
307 | self.swe_pooling_b = SWE_Pooling(d_in=input_dim // 2, num_ref_points=256, num_slices=projection_dim // 2)
308 |
309 | self.proj_head = nn.Sequential(
310 | nn.Linear(projection_dim, projection_dim // 2),
311 | nn.LayerNorm(projection_dim // 2),
312 | )
313 |
314 | self.device = device
315 |
316 | def forward(self, x, mask):
317 | tra_tokens, trb_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob)
318 |
319 | # feed to TCRBERT
320 | rescoding_tra = self.tcrbert_tra_lora(**tra_tokens)
321 | rescoding_trb = self.tcrbert_trb_lora(**trb_tokens)
322 |
323 | if self.ln_config.swe_pooling:
324 | # for SWE pooling:
325 | attn_mask_a = get_attention_mask(tra_tokens)
326 | if isinstance(rescoding_tra, dict):
327 | rescoding_tra = rescoding_tra['last_hidden_state']
328 | else:
329 | rescoding_tra = rescoding_tra.last_hidden_state
330 | tra_outputs = self.swe_pooling_a(rescoding_tra, attn_mask_a)
331 |
332 | attn_mask_b = get_attention_mask(trb_tokens)
333 | if isinstance(rescoding_trb, dict):
334 | rescoding_trb = rescoding_trb['last_hidden_state']
335 | else:
336 | rescoding_trb = rescoding_trb.last_hidden_state
337 | trb_outputs = self.swe_pooling_b(rescoding_trb, attn_mask_b)
338 | else:
339 | # for regular mean pooling
340 | tra_outputs = get_sequence_embeddings(tra_tokens, rescoding_tra)
341 | trb_outputs = get_sequence_embeddings(trb_tokens, rescoding_trb)
342 |
343 | tcr_embeds = torch.cat((tra_outputs, trb_outputs), dim=-1)
344 |
345 | return self.proj_head(tcr_embeds)
346 |
347 |
348 | def process_seqs(self, seqs, mask, mask_prob=0.15):
349 | tra_seqs_, trb_seqs_ = seqs
350 |
351 | # insert spaces between residues for correct formatting:
352 | tra_seqs = [insert_spaces(seq) for seq in tra_seqs_]
353 | trb_seqs = [insert_spaces(seq) for seq in trb_seqs_]
354 |
355 | tra_tokens = self.tcrbert_tra_tokenizer(tra_seqs, return_tensors="pt", padding=True)
356 | trb_tokens = self.tcrbert_trb_tokenizer(trb_seqs, return_tensors="pt", padding=True)
357 |
358 | if mask:
359 | # masking the sequences for training
360 | tra_seqs_, attn_mask_indices = apply_masking_seq(tra_seqs_, p=mask_prob)
361 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long)
362 | if len(indices_tensor) > 0:
363 | tra_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1] + 1] = 0. # +1 to account for the CLS token
364 |
365 | trb_seqs_, attn_mask_indices = apply_masking_seq(trb_seqs_, p=mask_prob)
366 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long)
367 | if len(indices_tensor) > 0:
368 | trb_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1] + 1] = 0. # +1 to account for the CLS token
369 |
370 | # print("TCR Seqs:", tra_seqs_)
371 | # print("seqs_tokens:", tra_tokens['attention_mask'])
372 |
373 | return tra_tokens.to(self.device), trb_tokens.to(self.device)
374 |
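# Editor's note: TCRBert expects space-separated residues, so the insert_spaces call
# above turns a chain like "CASSIRSSYEQYF" (hypothetical CDR3) into
# "C A S S I R S S Y E Q Y F"; tokenization then prepends a CLS token, which is why
# masked-residue indices are shifted by +1 before zeroing the attention mask.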
375 |
376 | class TCREncoderTCRLang(nn.Module):
377 | def __init__(self, input_dim, projection_dim, ln_cfg, hidden_dim=1024, device='cpu'):
378 | super().__init__()
379 | from .lora import setup_peft_ablang2
380 | from .configs import peft_config_ablang2
381 |
382 | self.ablang2_lora, self.ablang2_tokenizer = setup_peft_ablang2(peft_config_ablang2, receptor_type='TCR', device=device, no_lora=ln_cfg.no_lora)
383 | self.padding_idx = 21
384 | self.mask_token = 23
385 | self.sep_token_id = 25
386 |
387 | self.ln_config = ln_cfg
388 |
389 | if hidden_dim:
390 | print("Using multi-layer projection head")
391 | self.proj_head = nn.Sequential(
392 | nn.Linear(input_dim, hidden_dim),
393 | nn.LayerNorm(hidden_dim),
394 | nn.LeakyReLU(),
395 | nn.Dropout(p=0.5),
396 | nn.Linear(hidden_dim, projection_dim),
397 | )
398 | else:
399 | print("Using single-layer projection head")
400 | self.proj_head = nn.Sequential(
401 | nn.Linear(input_dim, projection_dim),
402 | nn.LayerNorm(projection_dim),
403 | )
404 |
405 | if self.ln_config.swe_pooling:
406 | self.swe_pooling = SWE_Pooling(d_in=input_dim, num_ref_points=512, num_slices=projection_dim)
407 |
408 | self.proj_head = nn.Sequential(
409 | nn.Linear(projection_dim, projection_dim // 2),
410 | nn.LayerNorm(projection_dim // 2),
411 | )
412 |
413 | self.device = device
414 |
415 | def forward(self, x, mask):
416 | seq_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob)
417 |
418 | # print("seq tokens:", seq_tokens)
419 |
420 | # feed to TCRLang
421 | rescoding = self.ablang2_lora(seq_tokens)
422 |
423 | # process TCRLang outputs
424 | seq_inputs = {'attention_mask': ~((seq_tokens == self.padding_idx) | (seq_tokens == self.mask_token))}
425 | model_output = {'last_hidden_state': rescoding.last_hidden_states}
426 |
427 | if self.ln_config.swe_pooling:
428 | # for SWE pooling:
429 | attn_mask = seq_inputs['attention_mask']
430 | model_embed = model_output['last_hidden_state']
431 |
432 | seq_outputs = self.swe_pooling(model_embed, attn_mask)
433 |
434 | else:
435 | # for regular mean pooling
436 | seq_outputs = get_sequence_embeddings(seq_inputs, model_output, is_sep=False, is_cls=False)
437 |
438 | return self.proj_head(seq_outputs)
439 |
440 | def process_seqs(self, seqs, mask, mask_prob=0.15):
441 | H_seqs, L_seqs = seqs
442 |
443 |         # format the seq strings according to TCRLang (the B chain comes first, so we swap the H and L order):
444 | ab_seqs = [f"{L_seqs[i]}|{H_seqs[i]}" for i in range(len(H_seqs))]
445 |
446 | seqs_tokens = self.ablang2_tokenizer(ab_seqs, pad=True, w_extra_tkns=False, device=self.device)
447 |
448 | if mask:
449 | # masking the sequences for training
450 | ab_seqs, attn_mask_indices = apply_masking_seq(ab_seqs, mask_token='*', p=mask_prob)
451 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long)
452 | if len(indices_tensor) > 0:
453 | seqs_tokens[indices_tensor[:, 0], indices_tensor[:, 1]] = self.mask_token
454 |
455 | # leave the SEP tokens ('|') unmasked!!
456 | for i, l_seq in enumerate(L_seqs):
457 | seqs_tokens[i, len(l_seq)] = self.sep_token_id
458 |
459 | return seqs_tokens
460 |
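# Editor's sketch (hypothetical chains): the paired string fed to the TCRLang tokenizer
# above joins the two chains with '|'; the separator sits at index len(first_chain),
# which is why process_seqs resets that position to sep_token_id after masking.
_first, _second = "CASSIRSSYEQYF", "CAVRDSNYQLIW"
_paired = f"{_first}|{_second}"
assert _paired.index('|') == len(_first)
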
461 | class TCREncoderESM(nn.Module):
462 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config=None, hidden_dim=1024, device='cpu'):
463 | super().__init__()
464 | from .lora import setup_peft_esm2
465 | from .configs import peft_config_esm2
466 |
467 | self.ln_config = ln_cfg
468 | self.model_config = model_config
469 | self.projection_dim = projection_dim
470 |
471 | if self.model_config.receptor_model_name == 'esm3':
472 | from .lora import setup_peft_esm3
473 | from .configs import peft_config_esm3
474 |
475 | # load the LoRA adapted ESM-3 Model here:
476 | self.esm_lora, self.esm_tokenizer = setup_peft_esm3(peft_config_esm3, ln_cfg.no_lora)
477 | else:
478 | from .lora import setup_peft_esm2
479 | from .configs import peft_config_esm2
480 |
481 | # load the LoRA adapted ESM-2 Model here:
482 | self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora, ln_cfg.regular_ft)
483 |
484 | if self.projection_dim:
485 | if hidden_dim:
486 | print("Using multi-layer projection head")
487 | self.proj_head = nn.Sequential(
488 | nn.Linear(input_dim, hidden_dim),
489 | nn.LayerNorm(hidden_dim),
490 | nn.LeakyReLU(),
491 | nn.Dropout(p=0.5),
492 | nn.Linear(hidden_dim, projection_dim),
493 | )
494 | else:
495 | print("Using single-layer projection head")
496 | self.proj_head = nn.Sequential(
497 | nn.Linear(input_dim, projection_dim),
498 | nn.LayerNorm(projection_dim),
499 | )
500 | else:
501 | print("NOT using projection head")
502 |
503 |         # for ESM-2, we need a glycine linker to represent multimers
504 | self.linker_size = 25
505 | self.gly_linker = 'G'*self.linker_size
506 | self.gly_idx = 6 # according to: https://huggingface.co/facebook/esm2_t33_650M_UR50D/blob/main/vocab.txt
507 |
508 | self.device = device
509 |
510 | def forward(self, x, mask):
511 | seqs_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob)
512 |
513 | if self.model_config.receptor_model_name == 'esm3':
514 | outputs = self.esm_lora(sequence_tokens=seqs_tokens['input_ids']).embeddings
515 | outputs = {'last_hidden_state': outputs}
516 | else:
517 | outputs = self.esm_lora(**seqs_tokens)
518 |
519 | seq_embeds = get_sequence_embeddings(seqs_tokens, outputs, is_sep=False, is_cls=False)
520 |
521 | if self.projection_dim:
522 | return self.proj_head(seq_embeds)
523 | else:
524 | return seq_embeds
525 |
526 | def process_seqs(self, seqs, mask, mask_prob=0.15):
527 | '''
528 |         seqs: tuple of (tra_seqs, trb_seqs), each a list of TCR chain sequence strings
529 | '''
530 | tra_seqs, trb_seqs = seqs
531 | seqs = [f"{tra_seqs[i]}{self.gly_linker}{trb_seqs[i]}" for i in range(len(tra_seqs))]
532 | mask_regions = [[True]*len(seqa)+[False]*self.linker_size+[True]*len(seqb) for seqa, seqb in zip(tra_seqs, trb_seqs)]
533 |
534 | # marking where the Glycine linker starts
535 | attn_starts = [len(alpha_chain) for alpha_chain in tra_seqs]
536 |
537 |         # omit special tokens; the linked alpha-beta chains are treated as a single peptide
538 | seqs_tokens = self.esm_tokenizer(seqs, return_tensors="pt", add_special_tokens=False, padding=True)
539 |
540 | if mask:
541 | # masking the sequences for training
542 | seqs, attn_mask_indices = apply_masking_seq(seqs, mask_token='', mask_regions=mask_regions, p=mask_prob)
543 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long)
544 | if len(indices_tensor) > 0:
545 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0.
546 |
547 | # remove mask tokens on linker region:
548 | for i in range(len(attn_starts)):
549 | seqs_tokens['input_ids'][i, attn_starts[i]:attn_starts[i]+self.linker_size] = self.gly_idx
550 |
551 |
552 | # attention masking the linker region
553 | for i in range(len(attn_starts)):
554 | seqs_tokens['attention_mask'][i, attn_starts[i]:attn_starts[i]+self.linker_size] = 0.
555 |
556 | # print("TCR Seqs:", seqs)
557 | # print("seqs_tokens:", seqs_tokens['attention_mask'])
558 |
559 | return seqs_tokens.to(self.device)
560 |
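# Editor's sketch (hypothetical chains) of how TCREncoderESM.process_seqs pairs the
# alpha and beta chains with a 25-glycine linker and excludes the linker from random
# masking via mask_regions.
_tra, _trb = "CAVRDSNYQLIW", "CASSIRSSYEQYF"
_paired = f"{_tra}{'G' * 25}{_trb}"
_mask_region = [True] * len(_tra) + [False] * 25 + [True] * len(_trb)
assert len(_paired) == len(_mask_region)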
561 |
562 | class TCREncoderESMBetaOnly(nn.Module):
563 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config=None, hidden_dim=1024, device='cpu'):
564 | super().__init__()
565 | from .lora import setup_peft_esm2
566 | from .configs import peft_config_esm2
567 |
568 | # # load the LoRA adapted ESM Model here:
569 | # self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora)
570 |
571 | self.ln_config = ln_cfg
572 | self.model_config = model_config
573 | self.projection_dim = projection_dim
574 |
575 | if self.model_config.receptor_model_name == 'esm3':
576 | from .lora import setup_peft_esm3
577 | from .configs import peft_config_esm3
578 |
579 | # load the LoRA adapted ESM-3 Model here:
580 | self.esm_lora, self.esm_tokenizer = setup_peft_esm3(peft_config_esm3, ln_cfg.no_lora)
581 | else:
582 | from .lora import setup_peft_esm2
583 | from .configs import peft_config_esm2
584 |
585 | # load the LoRA adapted ESM-2 Model here:
586 | self.esm_lora, self.esm_tokenizer = setup_peft_esm2(peft_config_esm2, ln_cfg.no_lora)
587 |
588 | if self.projection_dim:
589 | if hidden_dim:
590 | print("Using multi-layer projection head")
591 | self.proj_head = nn.Sequential(
592 | nn.Linear(input_dim, hidden_dim),
593 | nn.LayerNorm(hidden_dim),
594 | nn.LeakyReLU(),
595 | nn.Dropout(p=0.5),
596 | nn.Linear(hidden_dim, projection_dim),
597 | )
598 | else:
599 | print("Using single-layer projection head")
600 | self.proj_head = nn.Sequential(
601 | nn.Linear(input_dim, projection_dim),
602 | nn.LayerNorm(projection_dim),
603 | )
604 | else:
605 | print("NOT using projection head")
606 |
607 | self.device = device
608 |
609 | def forward(self, x, mask):
610 | seqs_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob)
611 |
612 | if self.model_config.receptor_model_name == 'esm3':
613 | outputs = self.esm_lora(sequence_tokens=seqs_tokens['input_ids']).embeddings
614 | outputs = {'last_hidden_state': outputs}
615 | else:
616 | outputs = self.esm_lora(**seqs_tokens)
617 |
618 | seq_embeds = get_sequence_embeddings(seqs_tokens, outputs, is_sep=False, is_cls=False)
619 |
620 | if self.projection_dim:
621 | return self.proj_head(seq_embeds)
622 | else:
623 | return seq_embeds
624 |
625 | def process_seqs(self, seqs, mask, mask_prob=0.15):
626 | '''
627 |         seqs: tuple of (tra_seqs, trb_seqs); only the beta chain sequences are used
628 | '''
629 | tra_seqs, trb_seqs = seqs
630 | seqs = trb_seqs
631 |
632 |         # omit special tokens; only the beta chain sequence is tokenized
633 | seqs_tokens = self.esm_tokenizer(seqs, return_tensors="pt", add_special_tokens=False, padding=True)
634 |
635 | if mask:
636 | # masking the sequences for training
637 | seqs, attn_mask_indices = apply_masking_seq(seqs, mask_token='', p=mask_prob)
638 | indices_tensor = torch.tensor(attn_mask_indices, dtype=torch.long)
639 | if len(indices_tensor) > 0:
640 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0.
641 |
642 | return seqs_tokens.to(self.device)
643 |
644 |
645 | class TCREncoderInHouse(nn.Module):
646 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config=None, hidden_dim=1024, device='cpu'):
647 | super().__init__()
648 | from .lora import setup_peft_inhouse
649 | from .configs import peft_config_inhouse
650 | import os
651 |
652 | model_ckpt_path = os.getenv('INHOUSE_MODEL_CKPT_PATH')
653 |
654 | self.inhouse_lora, self.inhouse_tokenizer = setup_peft_inhouse(peft_config_inhouse, ln_cfg.no_lora, model_ckpt_path=model_ckpt_path)
655 |
656 | self.ln_config = ln_cfg
657 | self.model_config = model_config
658 | self.projection_dim = projection_dim
659 |
660 | if self.projection_dim:
661 | if hidden_dim:
662 | print("Using multi-layer projection head")
663 | self.proj_head = nn.Sequential(
664 | nn.Linear(input_dim, hidden_dim),
665 | nn.LayerNorm(hidden_dim),
666 | nn.LeakyReLU(),
667 | nn.Dropout(p=0.5),
668 | nn.Linear(hidden_dim, projection_dim),
669 | )
670 | else:
671 | print("Using single-layer projection head")
672 | self.proj_head = nn.Sequential(
673 | nn.Linear(input_dim, projection_dim),
674 | nn.LayerNorm(projection_dim),
675 | )
676 | else:
677 | print("NOT using projection head")
678 |
679 | self.device = device
680 |
681 | def forward(self, x, mask):
682 | seq_tokens = self.process_seqs(x, mask, mask_prob=self.ln_config.mask_prob)
683 |
684 | # feed to InHouse Model
685 | # print("seq tokens input_ids:", seq_tokens['input_ids'])
686 | # print("seq tokens attention_mask:", seq_tokens['attention_mask'])
687 | seq_outputs, _ = self.inhouse_lora(seq_tokens["input_ids"], seq_tokens["attention_mask"])
688 |
689 | # print("seq outputs:", seq_outputs)
690 |
691 | return self.proj_head(seq_outputs)
692 |
693 | def process_seqs(self, seqs, mask, mask_prob=0.15):
694 | tra_seqs, trb_seqs = seqs
695 |
696 | if mask:
697 | tra_seqs_, tra_masks = apply_masking_seq(tra_seqs, mask_token='', p=mask_prob)
698 | trb_seqs_, trb_masks = apply_masking_seq(trb_seqs, mask_token='', p=mask_prob)
699 |
700 | # adjust the tra_masks and trb_masks to the correct indices:
701 | tra_masks = [(n, 1+i) for (n, i) in tra_masks]
702 | trb_masks = [(n, 1+len(tra_seqs[n])+2+i) for (n, i) in trb_masks]
703 |
704 | indices_tensor = torch.tensor(tra_masks + trb_masks, dtype=torch.long)
705 |
706 | tra_seqs, trb_seqs = tra_seqs_, trb_seqs_
707 |
708 |         # format the seq strings according to the InHouse model's expected format:
709 | ab_seqs = [self.apply_special_token_formatting(tra_seqs[i], trb_seqs[i]) for i in range(len(tra_seqs))]
710 |
711 | seqs_tokens = self.inhouse_tokenizer(ab_seqs, return_tensors="pt", add_special_tokens=False, padding=True)
712 |
713 | # adjust the attention mask
714 | if mask and len(indices_tensor) > 0:
715 | seqs_tokens['attention_mask'][indices_tensor[:, 0], indices_tensor[:, 1]] = 0.
716 |
717 | return seqs_tokens.to(self.device)
718 |
719 | def apply_special_token_formatting(self, alpha, beta):
720 | '''
721 |         Apply RoBERTa-style formatting to the input pair:
722 |         {cls_token}seq1{eos_token}{eos_token}seq2{eos_token}
723 | '''
724 | return f"{self.inhouse_tokenizer.cls_token}{alpha}{self.inhouse_tokenizer.eos_token}{self.inhouse_tokenizer.eos_token}{beta}{self.inhouse_tokenizer.eos_token}"
725 |
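# Editor's sketch of the RoBERTa-style pairing built by apply_special_token_formatting
# above (hypothetical chains; "<s>"/"</s>" are assumed RoBERTa defaults for
# cls_token/eos_token). The single CLS before the alpha chain and the double EOS
# between chains match the index offsets (+1 for alpha, +len(alpha)+3 for beta)
# used in process_seqs.
_alpha, _beta = "CAVRDSNYQLIW", "CASSIRSSYEQYF"
_formatted = f"<s>{_alpha}</s></s>{_beta}</s>"
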
726 | class TCREncoderOneHot(nn.Module):
727 | def __init__(self, input_dim, projection_dim, ln_cfg, model_config, device='cpu'):
728 | super().__init__()
729 |
730 | self.ln_config = ln_cfg
731 | self.projection_dim = projection_dim
732 |
733 | if self.projection_dim:
734 | print("Using single-layer projection head")
735 | self.proj_head = nn.Sequential(
736 | nn.Linear(input_dim, projection_dim),
737 | nn.LayerNorm(projection_dim),
738 | )
739 | else:
740 | assert False, "Projection head must be used with one-hot encoding!"
741 |
742 | # Define the amino acid to index mapping
743 | self.amino_acid_to_index = {
744 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
745 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
746 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
747 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,
748 | 'X': 20 # Unknown amino acid
749 | }
750 |
751 | self.device = device
752 |
753 | def forward(self, x, mask):
754 | seqs = [seqa + seqb for seqa, seqb in zip(x[0], x[1])]
755 | seqs_onehot = self.create_padded_one_hot_tensor(seqs, len(self.amino_acid_to_index))
756 |
757 | proj_output = self.proj_head(seqs_onehot)
758 |
759 | # average the projected embeddings by seq length:
760 | seq_lens = torch.sum(seqs_onehot, dim=(1, 2))
761 | # Create a mask with shape (batch_size, max_seq_length)
762 | seq_mask = torch.arange(proj_output.size(1)).unsqueeze(0).to(self.device) < seq_lens.unsqueeze(-1)
763 | seq_mask = seq_mask.unsqueeze(2) # Shape (batch_size, max_seq_length, 1)
764 | # Sum the embeddings across the sequence length dimension using the mask
765 | masked_embeddings = proj_output * seq_mask
766 | sum_embeddings = masked_embeddings.sum(dim=1)
767 |
768 | # Divide by the true sequence lengths to get the average
769 |         avg_embeddings = sum_embeddings / seq_lens.unsqueeze(1)
770 |
771 | return avg_embeddings
772 |
773 |
774 | # @staticmethod
775 | def encode_amino_acid_sequence(self, sequence):
776 | """ Convert an amino acid sequence to a list of indices. """
777 | return [self.amino_acid_to_index[aa] for aa in sequence]
778 |
779 | # @staticmethod
780 | def one_hot_encode_sequence(self, sequence, vocab_size):
781 | """ One-hot encode a single sequence. """
782 | encoding = np.zeros((len(sequence), vocab_size), dtype=int)
783 | for idx, char in enumerate(sequence):
784 | encoding[idx, char] = 1
785 | return encoding
786 |
787 | # @staticmethod
788 | def pad_sequences(self, encoded_sequences, max_length):
789 | """ Pad the encoded sequences to the maximum length. """
790 | padded_sequences = []
791 | for seq in encoded_sequences:
792 | padded_seq = np.pad(seq, ((0, max_length - len(seq)), (0, 0)), mode='constant', constant_values=0)
793 | padded_sequences.append(padded_seq)
794 | return np.array(padded_sequences)
795 |
796 | # @staticmethod
797 | def create_padded_one_hot_tensor(self, sequences, vocab_size):
798 | """ Convert a batch of sequences to a padded one-hot encoding tensor. """
799 | # Encode and one-hot encode each sequence
800 | encoded_sequences = [self.one_hot_encode_sequence(self.encode_amino_acid_sequence(seq), vocab_size) for seq in sequences]
801 |
802 | # Determine the maximum sequence length
803 | max_length = max(len(seq) for seq in sequences)
804 |
805 | # Pad the sequences
806 | padded_sequences = self.pad_sequences(encoded_sequences, max_length)
807 |
808 | # Convert to a PyTorch tensor
809 | padded_tensor = torch.tensor(padded_sequences, dtype=torch.float32)
810 |
811 | return padded_tensor.to(self.device)
--------------------------------------------------------------------------------