├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── antiberty_pytorch ├── __init__.py ├── antiberty_pytorch.py ├── data.py └── train.py ├── data ├── download.smk └── manifest_230324.csv ├── img ├── antiberty_num_params.png ├── banner.png └── training.png ├── note └── oas_data_example.ipynb ├── setup.py └── tokenizer └── ProteinTokenizer ├── added_tokens.json ├── special_tokens_map.json ├── tokenizer_config.json └── vocab.txt /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/sequences 2 | data/*.sh 3 | 4 | **/checkpoints 5 | lightning_logs/ 6 | note/.ipynb_checkpoints/ 7 | *.egg-info/ 8 | **/__pycache__/ 9 | wandb/ 10 | **/.snakemake 11 | tuning/results 12 | **/*.pt 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 이도훈 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # antiberty-pytorch 2 | [Lightning](https://github.com/Lightning-AI/lightning) 3 | 4 | 5 | ![banner](img/banner.png) 6 | 7 | ## Installation 8 | ```bash 9 | $ pip install antiberty-pytorch 10 | ``` 11 | 12 | ## Reproduction status 13 | 14 | ### Number of parameters 15 | 16 | ![antiberty_num_params](img/antiberty_num_params.png) 17 | 18 | This implementation of AntiBERTy has 25,759,769 parameters in total, which matches the approximately 26M parameters reported in the paper (see above). 19 | 20 | ### Training with 1% of the entire OAS data 21 | 22 | I reproduced AntiBERTy training with a tiny subset (~1%) of the entire OAS data (`batch_size=16`, `mask_prob=0.15`) and observed a reasonable decrease in training loss, although validation loss was not tracked in this run. 23 | The training log can be found [here](https://api.wandb.ai/links/dohlee/qqzxgo1v). 24 | 25 | ![training](img/training.png) 26 | 27 | ## Observed Antibody Space (OAS) dataset preparation pipeline 28 | 29 | I wrote a `snakemake` pipeline in the `data` directory to automate the dataset preparation process. Given a metadata manifest from [OAS](https://opig.stats.ox.ac.uk/webapps/oas/oas), it downloads the corresponding sequence files and extracts plain lists of sequences. The pipeline can be run as follows: 30 | 31 | ```bash 32 | $ cd data 33 | $ snakemake -s download.smk -j1 34 | ``` 35 | 36 | *NOTE: Only 3% of the entire OAS sequences (83M sequences, 31GB) have been downloaded so far due to storage and computational cost.* 37 | 38 | ## Citation 39 | ```bibtex 40 | @article{ruffolo2021deciphering, 41 | title = {Deciphering antibody affinity maturation with language models and weakly supervised learning}, 42 | author = {Ruffolo, Jeffrey A and Gray, Jeffrey J and Sulam, Jeremias}, 43 | journal = {arXiv}, 44 | year = {2021} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /antiberty_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from antiberty_pytorch.antiberty_pytorch import AntiBERTy 2 | from antiberty_pytorch.data import OASDataset -------------------------------------------------------------------------------- /antiberty_pytorch/antiberty_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | import pytorch_lightning as pl 4 | import transformers 5 | 6 | 7 | class AntiBERTy(pl.LightningModule): 8 | def __init__(self): 9 | super().__init__() 10 | config = transformers.BertConfig( 11 | vocab_size=25, 12 | hidden_size=512, 13 | num_hidden_layers=8, 14 | num_attention_heads=8, 15 | intermediate_size=2048, 16 | max_position_embeddings=512, 17 | ) 18 | self.bert = transformers.BertForMaskedLM(config) 19 | 20 | def forward(self, input_ids, labels=None): 21 | return self.bert(input_ids, labels=labels) 22 | 23 | def training_step(self, batch, batch_idx): 24 | input_ids, labels = batch["input_ids"], batch["labels"] 25 | out = self(input_ids=input_ids, labels=labels) 26 | 27 | self.log_dict({"loss": out.loss}, prog_bar=True, on_step=True, on_epoch=True) 28 | return out.loss 29 | 30 | def validation_step(self, batch, batch_idx): 31 | input_ids, labels = batch["input_ids"], batch["labels"] 32 | out = self(input_ids=input_ids, labels=labels) 33 | 34 | self.log_dict({"val/loss": out.loss}, prog_bar=True, on_step=True, on_epoch=True) 35 | return out.loss 36 | 37 | def configure_optimizers(self): 38 | return optim.AdamW(self.parameters(), lr=1e-5) 39 | 40 | 41 | if 
__name__ == "__main__": 42 | from transformers import DataCollatorForLanguageModeling 43 | from transformers import BertTokenizer 44 | from torch.utils.data import DataLoader 45 | 46 | from .data import OASDataset 47 | 48 | tokenizer = BertTokenizer.from_pretrained("tokenizer/ProteinTokenizer") 49 | collator = DataCollatorForLanguageModeling( 50 | tokenizer=tokenizer, 51 | mlm=True, 52 | mlm_probability=0.5, 53 | ) 54 | data = ["ACGACGACGACGAGC", "CGGCGAGCGAAG", "CGACGACGACAGCGACGACGAGCAGCAG"] 55 | 56 | ds = OASDataset(data, tokenizer, max_len=512) 57 | loader = DataLoader(ds, batch_size=2, collate_fn=collator) 58 | 59 | model = AntiBERTy() 60 | 61 | for batch in loader: 62 | print(batch["input_ids"]) 63 | print(batch["labels"]) 64 | 65 | out = model(batch["input_ids"], labels=batch["labels"]) 66 | print(out.loss) 67 | 68 | print("# Parameters:", sum(p.numel() for p in model.parameters())) 69 | -------------------------------------------------------------------------------- /antiberty_pytorch/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.utils.data import Dataset 5 | 6 | class OASDataset(Dataset): 7 | def __init__(self, data, tokenizer, max_len): 8 | self.data = data 9 | self.tokenizer = tokenizer 10 | self.max_len = max_len 11 | 12 | def __len__(self): 13 | return len(self.data) 14 | 15 | def __getitem__(self, index): 16 | text = self.data[index] 17 | encoding = self.tokenizer(text, truncation=True, max_length=self.max_len) 18 | return encoding['input_ids'] 19 | 20 | if __name__ == '__main__': 21 | from transformers import DataCollatorForLanguageModeling 22 | from transformers import BertTokenizer 23 | from torch.utils.data import DataLoader 24 | 25 | tokenizer = BertTokenizer.from_pretrained('tokenizer/ProteinTokenizer') 26 | collator = DataCollatorForLanguageModeling( 27 | tokenizer=tokenizer, 28 | mlm=True, 29 | mlm_probability=0.5, 30 | ) 31 | data = ['ACGACGACGACGAGC', 'CGGCGAGCGAAG', 'CGACGACGACAGCGACGACGAGCAGCAG'] 32 | 33 | ds = OASDataset(data, tokenizer, max_len=512) 34 | loader = DataLoader(ds, batch_size=2, collate_fn=collator) 35 | 36 | for batch in loader: 37 | print(batch['input_ids']) 38 | print(batch['labels']) 39 | 40 | print(batch['input_ids'].shape) 41 | print(batch['labels'].shape) 42 | break -------------------------------------------------------------------------------- /antiberty_pytorch/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import pytorch_lightning as pl 5 | 6 | from torch.utils.data import DataLoader 7 | from antiberty_pytorch import AntiBERTy, OASDataset 8 | from transformers import ( 9 | DataCollatorForLanguageModeling, 10 | BertTokenizer, 11 | ) 12 | 13 | from pytorch_lightning.loggers import WandbLogger 14 | 15 | 16 | def parse_argument(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("-i", "--input", required=True) 19 | parser.add_argument("-o", "--output", required=True) 20 | parser.add_argument("-a", "--accelerator", default="gpu") 21 | parser.add_argument("-b", "--batch-size", type=int, default=16) 22 | parser.add_argument("-p", "--mask-prob", type=float, default=0.15) 23 | return parser.parse_args() 24 | 25 | 26 | def main(): 27 | torch.set_float32_matmul_precision("high") # Trade-off precision for speed. 
28 | 29 | args = parse_argument() 30 | wandb_logger = WandbLogger(project="antiberty-pytorch", entity="dohlee") 31 | 32 | tokenizer = BertTokenizer.from_pretrained("tokenizer/ProteinTokenizer") 33 | collator = DataCollatorForLanguageModeling( 34 | tokenizer=tokenizer, 35 | mlm=True, 36 | mlm_probability=args.mask_prob, 37 | ) 38 | 39 | seqs = [] 40 | for fp in os.listdir(args.input): 41 | with open(os.path.join(args.input, fp)) as f: 42 | seqs += f.read().splitlines() 43 | 44 | train_seqs, val_seqs = seqs[: int(len(seqs) * 0.99)], seqs[int(len(seqs) * 0.99) :] 45 | train_ds = OASDataset(train_seqs, tokenizer, max_len=512) 46 | val_ds = OASDataset(val_seqs, tokenizer, max_len=512) 47 | 48 | train_loader = DataLoader( 49 | train_ds, 50 | batch_size=args.batch_size, 51 | collate_fn=collator, 52 | num_workers=4, 53 | shuffle=True, 54 | ) 55 | val_loader = DataLoader( 56 | val_ds, 57 | batch_size=args.batch_size, 58 | collate_fn=collator, 59 | num_workers=4, 60 | shuffle=False, 61 | ) 62 | 63 | model = AntiBERTy() 64 | trainer = pl.Trainer( 65 | logger=wandb_logger, 66 | accelerator=args.accelerator, 67 | devices=1, 68 | max_epochs=-1, 69 | ) 70 | trainer.fit(model, train_loader, val_loader) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /data/download.smk: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | manifest = pd.read_csv('manifest_230324.csv') 4 | 5 | # Randomly sample 3% of the dataset. 6 | manifest = manifest.sample(frac=0.03, random_state=42) 7 | 8 | f2type = {r.filename:r.seq_type for r in manifest.to_records()} 9 | f2study = {r.filename:r.study for r in manifest.to_records()} 10 | 11 | filenames = manifest.filename.values 12 | ALL = expand('sequences/{filename}.list', filename=filenames) 13 | 14 | rule all: 15 | input: ALL 16 | 17 | rule download: 18 | output: 19 | 'sequences/{filename}.list' 20 | params: 21 | type = lambda wc: f2type[wc.filename], 22 | study = lambda wc: f2study[wc.filename], 23 | shell: 24 | 'wget -qO- ' 25 | 'http://opig.stats.ox.ac.uk/webapps/ngsdb/{params.type}/{params.study}/csv/{wildcards.filename}.csv.gz | ' 26 | 'gunzip -c | ' 27 | 'tail -n+3 | ' 28 | 'cut -d, -f1 > {output}' 29 | -------------------------------------------------------------------------------- /img/antiberty_num_params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dohlee/antiberty-pytorch/a6bc5f84e97454068aed13b2ac4be40144cc9e2a/img/antiberty_num_params.png -------------------------------------------------------------------------------- /img/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dohlee/antiberty-pytorch/a6bc5f84e97454068aed13b2ac4be40144cc9e2a/img/banner.png -------------------------------------------------------------------------------- /img/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dohlee/antiberty-pytorch/a6bc5f84e97454068aed13b2ac4be40144cc9e2a/img/training.png -------------------------------------------------------------------------------- /note/oas_data_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | 
"import pandas as pd" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 4, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "meta = pd.read_csv('/data/project/dohoon/antiberty-pytorch/ERR2843421_Heavy_IGHA.csv', skiprows=1)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 9, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "data": { 28 | "text/plain": [ 29 | "25363" 30 | ] 31 | }, 32 | "execution_count": 9, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "len(meta)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 14, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "with open('../ERR2843421.list', 'w') as outFile:\n", 48 | " print('\\n'.join(meta.sequence.values), file=outFile)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 12, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/html": [ 59 | "
\n", 77 | " | sequence | \n", 78 | "locus | \n", 79 | "stop_codon | \n", 80 | "vj_in_frame | \n", 81 | "v_frameshift | \n", 82 | "productive | \n", 83 | "rev_comp | \n", 84 | "complete_vdj | \n", 85 | "v_call | \n", 86 | "d_call | \n", 87 | "... | \n", 88 | "cdr3_start | \n", 89 | "cdr3_end | \n", 90 | "np1 | \n", 91 | "np1_length | \n", 92 | "np2 | \n", 93 | "np2_length | \n", 94 | "c_region | \n", 95 | "Redundancy | \n", 96 | "ANARCI_numbering | \n", 97 | "ANARCI_status | \n", 98 | "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", 103 | "AGCTCTGGGAGAGGAGCCCCAGCCCTGAAATTCCCAAGTGTTTCCA... | \n", 104 | "H | \n", 105 | "F | \n", 106 | "T | \n", 107 | "F | \n", 108 | "T | \n", 109 | "T | \n", 110 | "T | \n", 111 | "IGHV3-43*02 | \n", 112 | "IGHD3-16*01 | \n", 113 | "... | \n", 114 | "423.0 | \n", 115 | "452.0 | \n", 116 | "T | \n", 117 | "1.0 | \n", 118 | "TTAA | \n", 119 | "4.0 | \n", 120 | "CATCCCCGACCAGCCCCAAGGTCTTCCCG | \n", 121 | "4 | \n", 122 | "{'fwh1': {'1 ': 'E', '2 ': 'V', '3 ': 'Q', '4 ... | \n", 123 | "|Deletions: 10, 73|||| | \n", 124 | "
1 | \n", 127 | "AGCTCTGGGAGAGGAGCCCCAGCCCTGAGATTCCCAGGTGTTTCCA... | \n", 128 | "H | \n", 129 | "F | \n", 130 | "T | \n", 131 | "F | \n", 132 | "T | \n", 133 | "T | \n", 134 | "F | \n", 135 | "IGHV3-9*01 | \n", 136 | "IGHD5/OR15-5a*01 | \n", 137 | "... | \n", 138 | "426.0 | \n", 139 | "470.0 | \n", 140 | "CCAGAGGGA | \n", 141 | "9.0 | \n", 142 | "CTGGG | \n", 143 | "5.0 | \n", 144 | "TGCATCCCCGACCAGCCCCAAGGTCTTCCCG | \n", 145 | "1 | \n", 146 | "{'fwh1': {'1 ': 'E', '2 ': 'V', '3 ': 'Q', '4 ... | \n", 147 | "Unusual residue: X|Deletions: 10, 73|||| | \n", 148 | "
2 | \n", 151 | "GGCTTTCTGAGAGTCATGGATCTCATGTGCAAGAAAATGAAGCACC... | \n", 152 | "H | \n", 153 | "F | \n", 154 | "T | \n", 155 | "F | \n", 156 | "T | \n", 157 | "T | \n", 158 | "T | \n", 159 | "IGHV4-39*01 | \n", 160 | "NaN | \n", 161 | "... | \n", 162 | "382.0 | \n", 163 | "399.0 | \n", 164 | "GGCCCCG | \n", 165 | "7.0 | \n", 166 | "NaN | \n", 167 | "NaN | \n", 168 | "CATCCCCGACCAGCCCCAAGGTCTTCCCG | \n", 169 | "2 | \n", 170 | "{'fwh1': {'1 ': 'Q', '2 ': 'L', '3 ': 'Q', '4 ... | \n", 171 | "|Deletions: 10, 55, 73|||| | \n", 172 | "
3 rows × 97 columns
\n", 176 | "\n", 280 | " | seq_type | \n", 281 | "study | \n", 282 | "filename | \n", 283 | "
---|---|---|---|
0 | \n", 288 | "unpaired | \n", 289 | "Eliyahu_2018 | \n", 290 | "ERR2843400_Heavy_IGHE | \n", 291 | "
1 | \n", 294 | "unpaired | \n", 295 | "Eliyahu_2018 | \n", 296 | "ERR2843418_Heavy_IGHA | \n", 297 | "
2 | \n", 300 | "unpaired | \n", 301 | "Eliyahu_2018 | \n", 302 | "ERR2843418_Heavy_Bulk | \n", 303 | "
\n", 378 | " | seq_type | \n", 379 | "study | \n", 380 | "filename | \n", 381 | "
---|---|---|---|
12081 | \n", 386 | "unpaired | \n", 387 | "Galson_2015a | \n", 388 | "SRR3099401_Heavy_IGHM | \n", 389 | "
291 | \n", 392 | "unpaired | \n", 393 | "Schultheiss_2020 | \n", 394 | "ERR4337035_Heavy_Bulk | \n", 395 | "
10814 | \n", 398 | "unpaired | \n", 399 | "Briney_2019 | \n", 400 | "SRR8283768_Heavy_IGHD | \n", 401 | "
3647 | \n", 404 | "unpaired | \n", 405 | "Soto_2019 | \n", 406 | "SRR8365361_1_Heavy_IGHM | \n", 407 | "
1372 | \n", 410 | "unpaired | \n", 411 | "Galson_2015 | \n", 412 | "SRR3990897_Heavy_Bulk | \n", 413 | "
... | \n", 416 | "... | \n", 417 | "... | \n", 418 | "... | \n", 419 | "
4936 | \n", 422 | "unpaired | \n", 423 | "Kim_2020 | \n", 424 | "SRR12326744_1_Heavy_IGHA | \n", 425 | "
6817 | \n", 428 | "unpaired | \n", 429 | "Ellebedy_2016 | \n", 430 | "SRR3620118_Heavy_IGHM | \n", 431 | "
15245 | \n", 434 | "unpaired | \n", 435 | "Waltari_2018 | \n", 436 | "SRR5811779_1_Heavy_IGHE | \n", 437 | "
8768 | \n", 440 | "unpaired | \n", 441 | "Chen_2020 | \n", 442 | "SRR11937625_1_Light_Bulk | \n", 443 | "
1624 | \n", 446 | "unpaired | \n", 447 | "Kuri-Cervantes_2020 | \n", 448 | "SRR12081538_Heavy_Bulk | \n", 449 | "
1579 rows × 3 columns
\n", 453 | "