├── LICENSE
├── README.md
├── antibody_scripts
│   ├── antibody_finetune.py
│   ├── antibody_finetune.sh
│   ├── antibody_run.sh
│   ├── antibody_train.py
│   ├── humab25_eval.py
│   ├── patent_eval.py
│   ├── sample.py
│   └── sample_for_anti_cdr.py
├── configs
│   ├── antibody_finetune.yml
│   ├── antibody_test.yml
│   ├── antibody_train.yml
│   ├── heavy_test.yml
│   ├── heavy_train.yml
│   └── training_nano_framework.yml
├── data
│   ├── antibody_eval_data
│   │   ├── HuAb348_data
│   │   │   ├── humanization_pair_data_filter.csv
│   │   │   ├── only_h_lab.fasta
│   │   │   ├── only_l_lab.fasta
│   │   │   ├── only_lk_lab.fasta
│   │   │   ├── only_ll_lab.fasta
│   │   │   ├── sample_t20_mouse_score.csv
│   │   │   └── sample_t20_score.csv
│   │   ├── Humab25_data
│   │   │   ├── parental_mouse.csv
│   │   │   ├── sample_experimental_t20_score.csv
│   │   │   └── sample_mouse_t20_score.csv
│   │   └── putative_data
│   │       └── humanization_pair152.csv
│   ├── fasta_file
│   │   ├── 7k9i.fasta
│   │   └── 7x2l.fasta
│   ├── nanobody_eval_data
│   │   ├── abnativ_select_vhh.csv
│   │   └── nanobert_exp.csv
│   └── train_sh_file
│       ├── camel_data
│       │   └── nano_download.sh
│       ├── heavy_data
│       │   └── unpair_download_selected.sh
│       └── pair_data
│           └── bulk_download.sh
├── dataset
│   ├── abnativ_alignment
│   │   ├── __init__.py
│   │   ├── aho_consensus.py
│   │   ├── align_and_clean.py
│   │   ├── blossum.py
│   │   ├── csv_dict.py
│   │   ├── misc.py
│   │   ├── mybio.py
│   │   ├── parse_pdb.py
│   │   ├── plotter.py
│   │   └── structs.py
│   ├── oas_dataset.py
│   ├── oas_pair_dataset_new.py
│   ├── oas_unpair_dataset_new.py
│   └── preprocess.py
├── doc
│   └── process.svg
├── environment.yaml
├── evaluation
│   ├── ABLSTM_eval.py
│   ├── Biophi_eval.py
│   ├── T20_eval.py
│   ├── Zscore_eval.py
│   └── humab_eval.py
├── model
│   ├── encoder
│   │   ├── __init__.py
│   │   ├── cross_attention.py
│   │   ├── diffusion.py
│   │   ├── model.py
│   │   ├── modules.py
│   │   ├── multihead_attention.py
│   │   └── rotary_embedding.py
│   └── nanoencoder
│       ├── abnativ_model.py
│       ├── abnativ_onehot.py
│       ├── abnativ_scoring.py
│       ├── abnativ_utils.py
│       ├── abnativ_vq.py
│       ├── attention.py
│       └── model.py
├── nanobody_scripts
│   ├── nano_eval.py
│   ├── nanofinetune.py
│   ├── nanofinetune_run.sh
│   ├── nanosample.py
│   ├── nanotrain.py
│   ├── nanotrain_run.sh
│   └── sample_for_nano_cdr.py
├── start_docker.sh
└── utils
    ├── anti_numbering.py
    ├── evaluation.py
    ├── loss.py
    ├── misc.py
    ├── tokenizer.py
    ├── train_utils.py
    └── warmup.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
# PolyForm Noncommercial License 1.0.0



## Acceptance

In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.

## Copyright License

The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).

## Distribution License

The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).
29 | 30 | ## Notices 31 | 32 | You must ensure that anyone who gets a copy of any part of 33 | the software from you also gets a copy of these terms or the 34 | URL for them above, as well as copies of any plain-text lines 35 | beginning with `Required Notice:` that the licensor provided 36 | with the software. For example: 37 | 38 | > Required Notice: Copyright Yoyodyne, Inc. (http://example.com) 39 | 40 | ## Changes and New Works License 41 | 42 | The licensor grants you an additional copyright license to 43 | make changes and new works based on the software for any 44 | permitted purpose. 45 | 46 | ## Patent License 47 | 48 | The licensor grants you a patent license for the software that 49 | covers patent claims the licensor can license, or becomes able 50 | to license, that you would infringe by using the software. 51 | 52 | ## Noncommercial Purposes 53 | 54 | Any noncommercial purpose is a permitted purpose. 55 | 56 | ## Personal Uses 57 | 58 | Personal use for research, experiment, and testing for 59 | the benefit of public knowledge, personal study, private 60 | entertainment, hobby projects, amateur pursuits, or religious 61 | observance, without any anticipated commercial application, 62 | is use for a permitted purpose. 63 | 64 | ## Noncommercial Organizations 65 | 66 | Use by any charitable organization, educational institution, 67 | public research organization, public safety or health 68 | organization, environmental protection organization, 69 | or government institution is use for a permitted purpose 70 | regardless of the source of funding or obligations resulting 71 | from the funding. 72 | 73 | ## Fair Use 74 | 75 | You may have "fair use" rights for the software under the 76 | law. These terms do not limit them. 77 | 78 | ## No Other Rights 79 | 80 | These terms do not allow you to sublicense or transfer any of 81 | your licenses to anyone else, or prevent the licensor from 82 | granting licenses to anyone else. These terms do not imply 83 | any other licenses. 84 | 85 | ## Patent Defense 86 | 87 | If you make any written claim that the software infringes or 88 | contributes to infringement of any patent, your patent license 89 | for the software granted under these terms ends immediately. If 90 | your company makes such a claim, your patent license ends 91 | immediately for work on behalf of your company. 92 | 93 | ## Violations 94 | 95 | The first time you are notified in writing that you have 96 | violated any of these terms, or done anything with the software 97 | not covered by your licenses, your licenses can nonetheless 98 | continue if you come into full compliance with these terms, 99 | and take practical steps to correct past violations, within 100 | 32 days of receiving notice. Otherwise, all your licenses 101 | end immediately. 102 | 103 | ## No Liability 104 | 105 | ***As far as the law allows, the software comes as is, without 106 | any warranty or condition, and the licensor will not be liable 107 | to you for any damages arising out of these terms or the use 108 | or nature of the software, under any kind of legal claim.*** 109 | 110 | ## Definitions 111 | 112 | The **licensor** is the individual or entity offering these 113 | terms, and the **software** is the software the licensor makes 114 | available under these terms. 115 | 116 | **You** refers to the individual or entity agreeing to these 117 | terms. 
**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
organizations that have control over, are under the control of,
or are under common control with that organization. **Control**
means ownership of substantially all the assets of an entity,
or the power to direct its management and policies by vote,
contract, or otherwise. Control can be direct or indirect.

**Your licenses** are all the licenses granted to you for the
software under these terms.

**Use** means anything you do with the software requiring one
of your licenses.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HuDiff

## Humanization Process
![pipeline](doc/process.svg)


This package offers an implementation of the training details for HuDiff-Ab and HuDiff-Nb, as well as the inference pipeline for both models. Additionally, we provide the following:

1. A fine-tuned antibody humanization diffusion model for humanizing mouse antibodies, and a fine-tuned nanobody humanization diffusion model for humanizing nanobodies.
2. The preprocessed training set and the test set constructed in our paper.

Any publication that presents findings derived from using this source code or the model parameters should cite the HuDiff paper.

For any inquiries, please contact the HuDiff team at fm924507@gmail.com.

## Conda Environment

You can use the environment.yaml file to construct the environment.
```
conda env create -f environment.yaml
```

## HuDiff-Ab
### Data Preparation
We have uploaded the LMDB files of the paired datasets (see [Hugging Face](https://huggingface.co/cloud77/HuDiff/tree/main) for HuDiff), and you can use these datasets directly to train the HuDiff-Ab model. Alternatively, you can rebuild them from the download list we provide.
### Pre-training
```
antibody_scripts/antibody_run.sh
```
Ensure that you modify the environment variables `config_path`, `data_path`, and `log_path` before executing the provided bash script. The `data_path` should be set to the directory containing the provided LMDB files.

### Fine-tuning
```
antibody_scripts/antibody_finetune.sh
```
Ensure that you modify the environment variables `config_path`, `data_path`, `log_path`, and `ckpt_path`. The `ckpt_path` should specify the checkpoint to be fine-tuned. Note that we do not provide the pre-trained checkpoint in our released data, so you will need to train it from scratch. However, we do provide the checkpoints after fine-tuning. The `data_path` remains the same as in the pre-training stage, utilizing the provided LMDB files.

### Evaluation
Follow these steps to evaluate the antibody models:
1. Install the required humanness scorer, OASis, by cloning its GitHub repository: [BioPhi](https://github.com/Merck/BioPhi)
2. For the patent dataset collected by our lab, run the following command to create samples (a quick sanity check of the output is sketched after the command):
```
python antibody_scripts/sample.py \
    --ckpt checkpoints/antibody/hudiffab.pt \
    --data_fpath ./data/antibody_eval_data/HuAb348_data/humanization_pair_data_filter.csv
```
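The humanization sampler in this repo writes a `sample_humanization_result.csv` with columns `Specific,name,hseq,lseq` under a timestamped log directory; assuming `sample.py` produces an analogous CSV (check its log output for the exact path), a minimal sketch of the sanity check is:
```
import pandas as pd

# Hypothetical path: point this at the CSV written by your own sampling run.
df = pd.read_csv('antibody_sample_log/<run_dir>/sample_humanization_result.csv')

# 'mouse' rows are parental sequences; 'humanization' rows are generated samples.
print(df['Specific'].value_counts())
print(df.head())
```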
3. Then, evaluate the samples using the following command:
```
python antibody_scripts/patent_eval.py
```
4. For the public dataset of 25 antibody pairs, execute the following commands:
```
python antibody_scripts/sample.py \
    --ckpt checkpoints/antibody/hudiffab.pt \
    --data_fpath ./data/antibody_eval_data/Humab25_data/parental_mouse.csv
```
5. Finally, evaluate the public dataset using the following command:
```
python antibody_scripts/humab25_eval.py
```
For the putative humanization dataset, we do not provide evaluation scripts, but you can follow the same steps outlined above on your own.

### Humanization
There are two ways to humanize an antibody: either with a complex fasta file containing the antigen-antibody sequences, or by providing individual heavy and light chain sequences.
```
python antibody_scripts/sample_for_anti_cdr.py \
    --ckpt checkpoints/antibody/hudiffab.pt \
    --anti_complex_fasta data/fasta_file/7k9i.fasta
```
or by providing individual heavy and light chain sequences:
```
python antibody_scripts/sample_for_anti_cdr.py \
    --ckpt checkpoints/antibody/hudiffab.pt \
    --heavy_seq HEAVY_SEQUENCE \
    --light_seq LIGHT_SEQUENCE
```
Keep in mind that you need to replace `7k9i.fasta` with your own fasta file, and `HEAVY_SEQUENCE` and `LIGHT_SEQUENCE` with your own chain sequences.


## HuDiff-Nb
### Data Preparation
We have uploaded all LMDB files from the data preprocessing to [Hugging Face](https://huggingface.co/cloud77/HuDiff/tree/main) for HuDiff. If you wish to process the training dataset from scratch, download the heavy chain files (a download list is provided).
### Pre-training
The provided command can be executed directly for training; however, the environment variables must be specified beforehand.
```
nanobody_scripts/nanotrain_run.sh
```
Please note that you need to modify the paths for both the data file (`unpair_data_path`) and the configuration file (`config_path`).
### Fine-tuning
Once a pre-trained model is available, select it for fine-tuning. Modify the YAML configuration file to specify the path to the pre-trained checkpoint. Before fine-tuning, install the [AbNatiV](https://gitlab.developers.cam.ac.uk/ch/sormanni/abnativ) models, and ensure that the AbNatiV model paths are specified in the configuration YAML file.
```
nanobody_scripts/nanofinetune_run.sh
```
### Evaluation
This script provides different sampling methods, which can be customized, for example by replacing the checkpoint with your own trained model.
```
python nanobody_scripts/nanosample.py \
    --ckpt checkpoints/nanobody/hudiffnb.pt \
    --data_fpath data/nanobody_eval_data/abnativ_select_vhh.csv \
    --model pretrain \
    --inpaint_sample False
```
```
python nanobody_scripts/nanosample.py \
    --ckpt checkpoints/nanobody/hudiffnb.pt \
    --data_fpath data/nanobody_eval_data/abnativ_select_vhh.csv \
    --model finetune_vh \
    --inpaint_sample True
```
After sampling, use the following script to evaluate the sampling results; you need to specify the path of the samples inside the script.
```
python nanobody_scripts/nano_eval.py
```
### Humanization
Given the fasta file of a nanobody, our model can humanize it. If you require a higher degree of humanization, consider increasing the batch size or the number of samplings; a possible variant is sketched below, followed by the standard invocation.
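The antibody humanization sampler (`antibody_scripts/sample_for_anti_cdr.py`) exposes `--batch_size` and `--sample_number` arguments for exactly this; assuming `sample_for_nano_cdr.py` mirrors these arguments (check its argparse definitions before relying on them), a larger run might look like:
```
python nanobody_scripts/sample_for_nano_cdr.py \
    --ckpt checkpoints/nanobody/hudiffnb.pt \
    --nano_complex_fasta data/fasta_file/7x2l.fasta \
    --model finetune_vh \
    --inpaint_sample True \
    --batch_size 50 \
    --sample_number 100
```
The standard invocation is: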
```
python nanobody_scripts/sample_for_nano_cdr.py \
    --ckpt checkpoints/nanobody/hudiffnb.pt \
    --nano_complex_fasta data/fasta_file/7x2l.fasta \
    --model finetune_vh \
    --inpaint_sample True
```

# Citing HuDiff
If you use HuDiff in your research, please cite our paper:
```bibtex
@article{ma2024adaptive,
  title={An adaptive autoregressive diffusion approach to design active humanized antibody and nanobody},
  author={Ma, Jian and Wu, Fandi and Xu, Tingyang and Xu, Shaoyong and Liu, Wei and Yan, Divin and Bai, Qifeng and Yao, Jianhua},
  journal={bioRxiv},
  year={2024},
  publisher={Cold Spring Harbor Laboratory}
}
```
--------------------------------------------------------------------------------
/antibody_scripts/antibody_finetune.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=3
echo "Using GPU: $CUDA_VISIBLE_DEVICES"
PAIR_DATA_PATH='data/release_data_dir/oas_pair_mouse_data'
CONFIG_PATH='configs/antibody_finetune.yml'
LOG_PATH='tmp/antibody_finetune_log/'
CKPT_PATH='antibody/pretrain_antibody.pt'
DATA_VERSION='filter'
export CONSIDER_MOUSE=True
python antibody_scripts/antibody_finetune.py \
    --pair_mouse_data_path $PAIR_DATA_PATH \
    --config_path $CONFIG_PATH \
    --log_path $LOG_PATH \
    --ckpt_path $CKPT_PATH \
    --data_version $DATA_VERSION \
    --consider_mouse $CONSIDER_MOUSE
--------------------------------------------------------------------------------
/antibody_scripts/antibody_run.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=3
echo "Using GPU: $CUDA_VISIBLE_DEVICES"
PAIR_DATA_PATH='release_data_dir/oas_pair_human_data'
CONFIG_PATH='configs/antibody_train.yml'
LOG_PATH='tmp/antibody_pretrain_log'
DATA_VERSION='filter'
python antibody_scripts/antibody_train.py \
    --pair_data_path $PAIR_DATA_PATH \
    --config_path $CONFIG_PATH \
    --log_path $LOG_PATH \
    --data_version $DATA_VERSION
--------------------------------------------------------------------------------
/antibody_scripts/sample_for_anti_cdr.py:
--------------------------------------------------------------------------------
import os.path
import sys

import numpy as np
import torch
from tqdm import tqdm
import argparse
import pandas as pd
from abnumber import Chain
from anarci import anarci, number
from copy import deepcopy
import re
from Bio import SeqIO

# Make repo-root imports work when the script is run directly.
current_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.append(current_dir)


from dataset.preprocess import (HEAVY_POSITIONS_dict, LIGHT_POSITIONS_dict,
                                HEAVY_CDR_INDEX, LIGHT_CDR_INDEX)
from dataset.oas_pair_dataset_new import light_pad_cdr, HEAVY_REGION_INDEX, LIGHT_REGION_INDEX
from sample import (get_pad_seq, get_input_element, batch_input_element,
                    batch_equal_input_element)
from utils.tokenizer import Tokenizer
from utils.train_utils import model_selected
from utils.misc import get_new_log_dir, get_logger, seed_all


# Maximum allowed length of each of the seven aligned antibody regions.
REGION_LENGTH = (26, 12, 17, 10, 38, 30, 11)

def compare_length(length_list):
    """Return True only if every region length fits its REGION_LENGTH slot."""
    for i, lg in enumerate(length_list):
        if lg > REGION_LENGTH[i]:
            return False
    return True
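
# For example (hypothetical lengths): a chain split into regions of lengths
# [25, 8, 17, 8, 38, 12, 11] fits the (26, 12, 17, 10, 38, 30, 11) template,
# so compare_length([25, 8, 17, 8, 38, 12, 11]) returns True; any region
# longer than its slot in REGION_LENGTH makes it return False.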
def get_diff_region_aa_seq(raw_seq, length_list):
    """Split a raw sequence into per-region chunks according to length_list."""
    split_aa_seq_list = []
    start_lg = 0
    for lg in length_list:
        end_lg = start_lg + lg
        aa_seq = raw_seq[start_lg:end_lg]
        split_aa_seq_list.append(aa_seq)
        start_lg = end_lg
    assert ''.join(split_aa_seq_list) == raw_seq, 'Region lengths do not sum to the full sequence.'
    return split_aa_seq_list


def get_h_l_seq_from_fasta(fpath):
    """
    Split the heavy and light chains from the raw fasta file.
    The record descriptions must contain 'heavy chain' and 'light chain'.
    :param fpath: the raw fasta file path.
    :return: heavy sequence, light sequence.
    """
    heavy_chain = None
    light_chain = None
    sequences = SeqIO.parse(fpath, 'fasta')
    for seq in sequences:
        if 'heavy chain' in seq.description:
            heavy_chain = str(seq.seq)
        elif 'light chain' in seq.description:
            light_chain = str(seq.seq)
        else:
            continue
    assert heavy_chain is not None and light_chain is not None, 'Failed to read heavy/light chains from the fasta file.'
    return heavy_chain, light_chain


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="This program is designed to humanize non-human antibodies.")
    parser.add_argument('--ckpt', type=str,
                        default='checkpoints/antibody/hudiffab.pt',
                        help='The checkpoint path.'
                        )
    parser.add_argument('--anti_complex_fasta', type=str,
                        default='fasta_file/7k9i.fasta',
                        help='fasta file of the antibody.'
                        )
    parser.add_argument('--heavy_seq', type=str,
                        help='heavy chain sequence of the antibody.'
                        )
    parser.add_argument('--light_seq', type=str,
                        help='light chain sequence of the antibody.'
                        )
    parser.add_argument('--log_dirpath', type=str,
                        default='antibody_sample_log/'
                        )
    parser.add_argument('--batch_size', type=int,
                        default=10,
                        help='the batch size of sampling.'
                        )
    parser.add_argument('--seed', type=int,
                        default=42
                        )
    parser.add_argument('--sample_number', type=int,
                        default=10,
                        help='the total number of samples.'
                        )
    parser.add_argument('--sample_order', type=str,
                        default='shuffle')
    parser.add_argument('--sample_type', type=str,
                        default='pair')
    parser.add_argument('--finetune', type=str,
                        default=True)
    args = parser.parse_args()

    batch_size = args.batch_size
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # seed_all(args.seed)

    # Read the fasta file or the raw sequences.
    if args.anti_complex_fasta is not None:
        mouse_heavy, mouse_light = get_h_l_seq_from_fasta(args.anti_complex_fasta)
        pdb_name = os.path.basename(args.anti_complex_fasta).split('.')[0]
    else:
        mouse_heavy = args.heavy_seq
        mouse_light = args.light_seq
        pdb_name = 'Unknown'


    sample_tag = f'{pdb_name}_{args.sample_order}_{args.sample_type}'
    # Log directory.
    if args.log_dirpath is not None:
        log_path = args.log_dirpath
    else:
        log_path = os.path.dirname(args.anti_complex_fasta)
    log_dir = get_new_log_dir(
        root=log_path,
        prefix=sample_tag
    )
    logger = get_logger('test', log_dir)

    # Load the model checkpoint.
    ckpt = torch.load(args.ckpt, map_location='cpu')
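
    # Note: the released checkpoints bundle everything needed at inference
    # time -- the model weights under 'model' and the pre-training
    # configuration under 'pretrain_config' -- so no separate config YAML
    # has to be passed on the command line.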
    # Initialize the model from the pre-training config and load the weights.
    pretrain_config = ckpt['pretrain_config']
    model = model_selected(pretrain_config).to(device)
    model.load_state_dict(ckpt['model'])
    finetune = args.finetune
    model.eval()

    logger.info(args.ckpt)
    logger.info(args.seed)

    # Fixed parameter.
    pad_region = 0

    # Save path.
    save_fpath = os.path.join(log_dir, 'sample_humanization_result.csv')
    with open(save_fpath, 'a', encoding='UTF-8') as f:
        f.write('Specific,name,hseq,lseq\n')

    wrong_idx_list = []
    length_not_equal_list = []
    sample_number = args.sample_number
    mouse_aa_h = Chain(mouse_heavy, scheme='imgt').seq
    mouse_aa_l = Chain(mouse_light, scheme='imgt').seq

    # Record the parental (mouse) sequences first.
    origin = 'mouse'
    name = pdb_name
    with open(save_fpath, 'a', encoding='UTF-8') as f:
        f.write(f'{origin},{name},{mouse_aa_h},{mouse_aa_l}\n')

    try:
        (h_l_pad_seq_sample, h_l_pad_seq_region,
         chain_type, h_l_ms_batch, h_l_loc, ms_tokenizer) = batch_input_element(
            mouse_aa_h, mouse_aa_l, batch_size, pad_region, finetune=finetune)
    except Exception:
        logger.info('Encoding this antibody may have a problem, please check!')
        raise

    if args.sample_order == 'shuffle':
        np.random.shuffle(h_l_loc)

    # Skip duplicated samples.
    duplicated_set = set()
    while sample_number > 0:
        all_token = ms_tokenizer.toks
        with torch.no_grad():
            for i in tqdm(h_l_loc, total=len(h_l_loc), desc='Antibody Humanization Process'):
                h_l_prediction = model(
                    h_l_pad_seq_sample.to(device),
                    h_l_pad_seq_region.to(device),
                    chain_type.to(device),
                    # h_l_ms_batch.to(device),
                )

                # Sample one residue at position i from the softmax over amino acid tokens.
                h_l_pred = h_l_prediction[:, i, :len(all_token)-1]
                h_l_soft = torch.nn.functional.softmax(h_l_pred, dim=1)
                h_l_sample = torch.multinomial(h_l_soft, num_samples=1)
                h_l_pad_seq_sample[:, i] = h_l_sample.squeeze()

            # The first 152 positions hold the heavy chain, the rest the light chain.
            h_pad_seq_sample = h_l_pad_seq_sample[:, :152]
            l_pad_seq_sample = h_l_pad_seq_sample[:, 152:]
            h_untokenized = [ms_tokenizer.idx2seq(s) for s in h_pad_seq_sample]
            l_untokenized = [ms_tokenizer.idx2seq(s) for s in l_pad_seq_sample]

            for _, (g_h, g_l) in enumerate(zip(h_untokenized, l_untokenized)):

                if sample_number == 0:
                    break

                with open(save_fpath, 'a', encoding='UTF-8') as f:
                    if (g_h, g_l) not in duplicated_set:
                        duplicated_set.add((g_h, g_l))
                        sample_origin = 'humanization'
                        sample_name = str(name) + 'human_sample'
                        f.write(f'{sample_origin},{sample_name},{g_h},{g_l}\n')

                sample_number -= 1
                logger.info('Sampled so far: {}'.format(args.sample_number - sample_number))
                logger.info('Sample Heavy Chain Seq: {}'.format(g_h))
                logger.info('Sample Light Chain Seq: {}'.format(g_l))

    logger.info('Length-mismatch list: {}'.format(length_not_equal_list))
    logger.info('Wrong indices: {}'.format(wrong_idx_list))
--------------------------------------------------------------------------------
/configs/antibody_finetune.yml:
--------------------------------------------------------------------------------
name: antibody_finetune
model:
  type: structure
  all_seq: False
  loss_type: 'smooth_loss'
  human_threshold: 1.0
  mouse_resi_h_ratio: 0.00
  mouse_resi_l_ratio: 0.00
  mutation: False

finetune:
  batch_acc: 1
  batch_size: 32
  num_workers: 4
  max_iter: 30000
  valid_step: 50
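  # Note: batch_acc presumably counts gradient-accumulation steps, so the
  # effective batch size is batch_acc * batch_size = 1 * 32 = 32 here.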
  optimizer:
    type: Adam
    lr: 4.e-5
    weight_decay: 0.
    beta1: 0.95
    beta2: 0.999
  scheduler:
    type: plateau
    factor: 0.6
    patience: 10
    min_lr: 1.e-6
    warmup_steps: 10

preckpt:
  ab_vh_ckpt: abnativ/vh_model.ckpt
  ab_vlk_ckpt: abnativ/vkappa_model.ckpt
  ab_vll_ckpt: abnativ/vlambda_model.ckpt
--------------------------------------------------------------------------------
/configs/antibody_test.yml:
--------------------------------------------------------------------------------
name: trans_oadm
model:
  n_tokens: 23
  d_embedding: 64
  d_model: 64
  n_encoder_layers: 1
  aa_kernel_size: 13
  r: 128
  n_side: 3
  s_embedding: 4
  s_model: 64
  n_region: 7
  r_embedding: 4
  r_model: 64
  n_pos_model: 64
  max_len: 291
  sum_d_model: 192 # d_model + s_model + r_model
  dual_layers: 2
  att_model: 512
  dim_feedforward: 512
  nhead: 8
  cs_layers: 1 # cross attention layers
  dropout: 0.2
  activation: gelu

train:
  seed: 2023
  max_iter: 1000000
  batch_acc: 2
  valid_step: 2
  batch_size: 16
  num_workers: 1
  clip_norm: 10
  loss_type: 'merge'
  l_loss_weight: 2
  optimizer:
    type: Adam
    lr: 1.e-4
    weight_decay: 0.
    beta1: 0.95
    beta2: 0.999
  scheduler:
    type: plateau
    factor: 0.6
    patience: 10
    min_lr: 1.e-6
    warmup_steps: 10
--------------------------------------------------------------------------------
/configs/antibody_train.yml:
--------------------------------------------------------------------------------
name: trans_oadm
model:
  n_tokens: 23
  d_embedding: 256
  d_model: 256
  n_encoder_layers: 6
  aa_kernel_size: 7
  r: 128
  n_side: 3
  s_embedding: 4
  s_model: 256
  n_region: 7
  r_embedding: 4
  r_model: 256
  n_pos_model: 256
  max_len: 291
  sum_d_model: 768 # d_model + s_model + r_model
  dual_layers: 6
  att_model: 512
  dim_feedforward: 256
  nhead: 8
  cs_layers: 5 # cross attention layers
  dropout: 0.2
  activation: gelu

train:
  seed: 2023
  max_iter: 1000000
  batch_acc: 300
  valid_step: 3
  batch_size: 128 # 256
  num_workers: 4
  clip_norm: 10
  loss_type: merge
  l_loss_weight: 3
  optimizer:
    type: Adam
    lr: 1.e-4
    weight_decay: 1.e-4
    beta1: 0.95
    beta2: 0.999
  scheduler:
    type: plateau
    factor: 0.6
    patience: 10
    min_lr: 1.e-6
    multiplier: 10
    total_epoch: 10
--------------------------------------------------------------------------------
/configs/heavy_test.yml:
--------------------------------------------------------------------------------
name: nano
model:
  n_tokens: 23
  d_embedding: 64
  d_model: 64
  n_encoder_layers: 1
  aa_kernel_size: 13
  r: 128
  n_region: 7
  r_embedding: 4
  r_model: 64
  n_pos_model: 64
  max_len: 152
  sum_d_model: 128 # d_model + r_model
  dual_layers: 2
  att_model: 512
  dim_feedforward: 256
  nhead: 8
  cs_layers: 1 # cross attention layers
  dropout: 0.0
  activation: gelu

train:
  seed: 2023
  max_iter: 1000000
  batch_acc: 2
  valid_step: 2
  batch_size: 16
  num_workers: 0
  clip_norm: 10
  optimizer:
    type: Adam
    lr: 1.e-4
    weight_decay: 0.
35 | beta1: 0.95 36 | beta2: 0.999 37 | scheduler: 38 | type: plateau 39 | factor: 0.6 40 | patience: 10 41 | min_lr: 1.e-6 42 | multiplier: 10 43 | total_epoch: 20 44 | -------------------------------------------------------------------------------- /configs/heavy_train.yml: -------------------------------------------------------------------------------- 1 | name: nano 2 | model: 3 | n_tokens: 23 4 | d_embedding: 256 5 | d_model: 256 6 | n_encoder_layers: 6 7 | aa_kernel_size: 7 8 | r: 128 9 | n_region: 7 10 | r_embedding: 4 11 | r_model: 256 12 | n_pos_model: 256 13 | max_len: 152 14 | sum_d_model: 512 # d_model + r_model 15 | dual_layers: 6 16 | att_model: 512 17 | dim_feedforward: 256 18 | nhead: 8 19 | cs_layers: 5 # cross attention layers 20 | dropout: 0.5 21 | activation: gelu 22 | 23 | train: 24 | seed: 2023 25 | max_iter: 1000000 26 | batch_acc: 300 27 | valid_step: 3 28 | batch_size: 512 29 | num_workers: 4 30 | clip_norm: 10 31 | optimizer: 32 | type: Adam 33 | lr: 1.e-4 34 | weight_decay: 0. 35 | beta1: 0.95 36 | beta2: 0.999 37 | scheduler: 38 | type: plateau 39 | factor: 0.6 40 | patience: 10 41 | min_lr: 1.e-5 42 | multiplier: 10 43 | total_epoch: 20 44 | 45 | -------------------------------------------------------------------------------- /configs/training_nano_framework.yml: -------------------------------------------------------------------------------- 1 | name: infilling 2 | model: 3 | loss_type: 'smooth_loss' 4 | vhh_nativeness: True 5 | temperature: 1 6 | human_threshold: 1.0 # if only human 1 7 | human_all_seq: False 8 | vhh_all_seq: False 9 | equal_weight: False 10 | part_reconstruct_vhh: False 11 | 12 | 13 | finetune: 14 | seed: 2023 15 | max_iter: 1000000 16 | batch_acc: 1 17 | valid_step: 20 18 | batch_size: 512 19 | num_workers: 4 20 | clip_norm: 10 21 | reconstruct_loss_weight: 1.e-3 22 | cross_interval: 5 23 | optimizer: 24 | type: Adam 25 | lr: 1.e-5 26 | weight_decay: 0. 
27 | beta1: 0.95 28 | beta2: 0.999 29 | scheduler: 30 | type: plateau 31 | factor: 0.6 32 | patience: 10 33 | min_lr: 1.e-6 34 | multiplier: 10 35 | total_epoch: 20 36 | model: 37 | abnativ_humanness_ckpt_fpath: checkpoints/abnativ/vh_model.ckpt 38 | abnativ_vhh_ckpt_fpath: checkpoints/abnativ/vhh_model.ckpt 39 | infilling_ckpt_fpath: checkpoints/nanobody/pretrained.pt -------------------------------------------------------------------------------- /data/antibody_eval_data/HuAb348_data/only_ll_lab.fasta: -------------------------------------------------------------------------------- 1 | >0_VL VL 2 | QAVVTQEPSLTVSPGGTVTLTCGSSTGAVTTSNFANWVQEKPGQAFRSLIGGTNNRASWVPARFSGSLLGGKAALTISGAQPEDEAEYFCALWYSNHWVFGGGTKLTVL 3 | >1_VL VL 4 | SAEVTQPPSVSVSPGQTARITCSSNTGAVTTSNYANWVQQKPGQAPVGLIGGTNERPSGIPERFSGSSSGNTATLTISGAQAEDEADYYCALWYSNHWVFGGGTKLTVL 5 | >2_VL VL 6 | QPVLTQSPSASASLGASVKLTCTLSSGHSSYTIAWHQQQPGKGPRYLMKLNSDGSHSKGDGIPDRFSGSSSGADRYLTISNLQSEDEADYYCGTWGTGIVVFGGGTKLTVL 7 | >3_VL VL 8 | QTVVTQEPSFSVSPGGTVTLTCRSSTGAVSTSNYANWVQQTPGQAPRGLIGGANSRAPGIPDRFSGSILGNKAALTITGAQADDESDYYCALWFSNHWVFGGGTKLTVL 9 | >4_VL VL 10 | QAVVTQEPSLTVSPGGTVTLTCRSSTGAVTTSNYANWVQQKPGQAPRGLIGGTNKRAPGTPARFSGSLLGGKAALTLSGAQPEDEAEYYCALWYSNLWVFGGGTKLTVL 11 | >5_VL VL 12 | QAVVTQEPSLTVSPGGTVTLTCRSSTGAVTTSNYANWVQEKPGQAPRGLIGGTNKRAPWTPARFSGSLLGGKAALTITGAQAEDEAEYYCVLWYSNLWVFGGGTKLTVL 13 | >6_VL VL 14 | QFQLTQPSSVSASVGDRVTITCERSSGDIGDSYVSWYQQKPGQPPKNVIYADDQRPSGVPDRFSGSIDGSGNSASLTISSLQAEDAADYFCQSYDSNIDFNPVFGGGTKLEVK 15 | -------------------------------------------------------------------------------- /data/antibody_eval_data/Humab25_data/parental_mouse.csv: -------------------------------------------------------------------------------- 1 | ,type,name,h_seq,l_seq 2 | 0,mouse,AntiCD28,EVKLQQSGPGLVTPSQSLSITCTVSGFSLSDYGVHWVRQSPGQGLEWLGVIWAGGGTNYNSALMSRKSISKDNSKSQVFLKMNSLQADDTAVYYCARDKGYSYYYSMDYWGQGTSVTVSS,DIETLQSPASLAVSLGQRATISCRASESVEYYVTSLMQWYQQKPGQPPKLLIFAASNVESGVPARFSGSGSGTNFSLNIHPVDEDDVAMYFCQQSRKYVPYTFGGGTKLEIK 3 | 1,mouse,Campath,EVKLLESGGGLVQPGGSMRLSCAGSGFTFTDFYMNWIRQPAGKAPEWLGFIRDKAKGYTTEYNPSVKGRFTISRDNTQNMLYLQMNTLRAEDTATYYCAREGHTAAPFDYWGQGVMVTVSS,DIKMTQSPSFLSASVGDRVTLNCKASQNIDKYLNWYQQKLGESPKLLIYNTNNLQTGIPSRFSGSGSGTDFTLTISSLQPEDVATYFCLQHISRPRTFGTGTKLELK 4 | 2,mouse,Bevacizumab,EIQLVQSGPELKQPGETVRISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYAADFKRRFTFSLETSASTAYLQISNLKNDDTATYFCAKYPHYYGSSHWYFDVWGAGTTVTVSS,DIQMTQTTSSLSASLGDRVIISCSASQDISNYLNWYQQKPDGTVKVLIYFTSSLHSGVPSRFSGSGSGTDYSLTISNLEPEDIATYYCQQYSTVPWTFGGGTKLEIK 5 | 3,mouse,Herceptin,QVQLQQSGPELVKPGASLKLSCTASGFNIKDTYIHWVKQRPEQGLEWIGRIYPTNGYTRYDPKFQDKATITADTSSNTAYLQVSRLTSEDTAVYYCSRWGGDGFYAMDYWGQGASVTVSS,DIVMTQSHKFMSTSVGDRVSITCKASQDVNTAVAWYQQKPGHSPKLLIYSASFRYTGVPDRFTGNRSGTDFTFTISSVQAEDLAVYYCQQHYTTPPTFGGGTKVEIK 6 | 4,mouse,Omalizumab,DVQLQESGPGLVKPSQSLSLACSVTGYSITSGYSWNWIRQFPGNKLEWMGSITYDGSSNYNPSLKNRISVTRDTSQNQFFLKLNSATAEDTATYYCARGSHYFGHWHFAVWGAGTTVTVSS,DIQLTQSPASLAVSLGQRATISCKASQSVDYDGDSYMNWYQQKPGQPPILLIYAASYLGSEIPARFSGSGSGTDFTLNIHPVEEEDAATFYCQQSHEDPYTFGAGTKLEIK 7 | 5,mouse,Eculizumab,QVQLQQSGAELMKPGASVKMSCKATGYIFSNYWIQWIKQRPGHGLEWIGEILPGSGSTEYTENFKDKAAFTADTSSNTAYMQLSSLTSEDSAVYYCARYFFGSSPNWYFDVWGAGTTVTVSS,DIQMTQSPASLSASVGETVTITCGASENIYGALNWYQRKQGKSPQLLIYGATNLADGMSSRFSGSGSGRQYYLKISSLHPDDVATYYCQNVLNTPLTFGAGTKLELK 8 | 6,mouse,Tocilizumab,DVQLQESGPVLVKPSQSLSLTCTVTGYSITSDHAWSWIRQFPGNKLEWMGYISYSGITTYNPSLKSRISITRDTSKNQFFLQLNSVTTGDTSTYYCARSLARTTAMDYWGQGTSVTVSS,DIQMTQTTSSLSASLGDRVTISCRASQDISSYLNWYQQKPDGTIKLLIYYTSRLHSGVPSRFSGSGSGTDYSLTINNLEQEDIATYFCQQGNTLPYTFGGGTKLEIN 9 | 
7,mouse,Pembrolizumab,QVQLQQPGAELVKPGTSVKLSCKASGYTFTNYYMYWVKQRPGQGLEWIGGINPSNGGTNFNEKFKNKATLTVDSSSSTTYMQLSSLTSEDSAVYYCTRRDYRFDMGFDYWGQGTTLTVSS,DIVLTQSPASLAVSLGQRAAISCRASKGVSTSGYSYLHWYQQKPGQSPKLLIYLASYLESGVPARFSGSGSGTDFTLNIHPVEEEDAATYYCQHSRDLPLTFGTGTKLELK 10 | 8,mouse,Pertuzumab,EVQLQQSGPELVKPGTSVKISCKASGFTFTDYTMDWVKQSHGKSLEWIGDVNPNSGGSIYNQRFKGKASLTVDRSSRIVYMELRSLTFEDTAVYYCARNLGPSFYFDYWGQGTTLTVSS,DTVMTQSHKIMSTSVGDRVSITCKASQDVSIGVAWYQQRPGQSPKLLIYSASYRYTGVPDRFTGSGSGTDFTFTISSVQAEDLAVYYCQQYYIYPYTFGGGTKLEIK 11 | 9,mouse,Ixekizumab,QVQLQQSRPELVKPGASVKISCKASGYSFTDYNMNWVKQSNGKSLEWIGVINPNYGTTDYNQRFKGKATLTVDQSSRTAYMQLNSLTSEDSAVYYCVIYDYATGTGGYWGQGSPLTVSS,DVVLTQTPLSLPVSLGDQASISCRSSQSLVHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQSTHVPFTFGSGTKLEIK 12 | 10,mouse,Palivizumab,QVELQESGPGILQPSQTLSLTCSFSGFSLSTSGMSVGWIRQPSGEGLEWLADIWWDDKKDYNPSLKSRLTISKDTSSNQVFLKITGVDTADTATYYCARSMITNWYFDVWGAGTTVTVSS,DIQLTQSPAIMSASPGEKVTMTCSASSSVGYMHWYQQKSSTSPKLWIYDTSKLASGVPGRFSGSGSGNSYSLTISSIQAEDVATYYCFQGSGYPFTFGQGTKLEIK 13 | 11,mouse,Certolizumab,QIQLVQSGPELKKPGETVKISCKASGYVFTDYGMNWVKQAPGKAFKWMGWINTYIGEPIYVDDFKGRFAFSLETSASTAFLQINNLKNEDTATYFCARGYRSYAMDYWGQGTSVTVSS,DIVMTQSQKFMSTSVGDRVSVTCKASQNVGTNVAWYQQKPGQSPKALIYSASFLYSGVPYRFTGSGSGTDFTLTISTVQSEDLAEYFCQQYNIYPLTFGAGTKLELK 14 | 12,mouse,Idarucizumab,QVQLEQSGPGLVAPSQRLSITCTVSGFSLTSYIVDWVRQSPGKGLEWLGVIWAGGSTGYNSALRSRLSITKSNSKSQVFLQMNSLQTDDTAIYYCASAAYYSYYNYDGFAYWGQGTLVTVSA,DVVMTQTPLTLSVTIGQPASISCKSSQSLLYTNGKTYLYWLLQRPGQSPKRLIYLVSKLDSGVPDRFSGSGSGTDFTLKISRVEAEDVGIYYCLQSTHFPHTFGGGTKLEIK 15 | 13,mouse,Reslizumab,EVKLLESGGGLVQPSQTLSLTCTVSGLSLTSNSVNWIRQPPGKGLEWMGLIWSNGDTDYNSAIKSRLSISRDTSKSQVFLKMNSLQSEDTAMYFCAREYYGYFDYWGQGVMVTVSS,DIQMTQSPASLSASLGETISIECLASEGISSYLAWYQQKPGKSPQLLIYGANSLQTGVPSRFSGSGSATQYSLKISSMQPEDEGDYFCQQSYKFPNTFGAGTKLELK 16 | 14,mouse,Solanezumab,EVKLVESGGGLVQPGGSLKLSCAVSGFTFSRYSMSWVRQTPEKRLELVAQINSVGNSTYYPDTVKGRFTISRDNAEYTLSLQMSGLRSDDTATYYCASGDYWGQGTTLTVSS,DVVMTQTPLSLPVSLGDQASISCRSSQSLIYSDGNAYLHWFLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVETEDLGVYFCSQSTHVPWTFGGGTKLEIK 17 | 15,mouse,Lorvotuzumab,DVQLVESGGGLVQPGGSRKLSCAASGFTFSSFGMHWVRQAPEKGLEWVAYISSGSFTIYHADTVKGRFTISRDNPKNTLFLQMTSLRAEDTAHYYCARMRKGYAMDYWGQGTTVTVSS,DVLMTQTPLSLPVSLGDQASISCRSSQIIIHSDGNTYLEWFLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLMISRVEAEDLGVYYCFQGSHVPHTFGGGTKLEIK 18 | 16,mouse,Pinatuzumab,QVQLQQSGPELVKPGASVKISCKASGYEFSRSWMNWVKQRPGQGREWIGRIYPGDGDTNYSGKFKGKATLTADKSSSTAYMQLSSLTSVDSAVYFCARDGSSWDWYFDVWGAGTTVTVSS,DILMTQTPLSLPVSLGDQASISCRSSQSIVHSNGNTFLEWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYYCFQGSQFPYTFGGGTKVEIK 19 | 17,mouse,Etaracizumab,EVQLEESGGGLVKPGGSLKLSCAASGFAFSSYDMSWVRQIPEKRLEWVAKVSSGGGSTYYLDTVQGRFTISRDNAKNTLYLQMSSLNSEDTAMYYCARHNYGSFAYWGQGTLVTVSA,ELVMTQTPATLSVTPGDSVSLSCRASQSISNHLHWYQQKSHESPRLLIKYASQSISGIPSRFSGSGSGTDFTLSINSVETEDFGMYFCQQSNSWPHTFGGGTKLEIK 20 | 18,mouse,Talacotuzumab,EVQLQQSGPELVKPGASVKMSCKASGYTFTDYYMKWVKQSHGKSLEWIGDIIPSNGATFYNQKFKGKATLTVDRSSSTAYMHLNSLTSEDSAVYYCTRSHLLRASWFAYWGQGTLVTVSA,DFVMTQSPSSLTVTAGEKVTMSCKSSQSLLNSGNQKNYLTWYLQKPGQPPKLLIYWASTRESGVPDRFTGSGSGTDFTLTISSVQAEDLAVYYCQNDYSYPYTFGGGTKLEIK 21 | 19,mouse,Rovalpituzumab,QIQLVQSGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMAWINTYTGEPTYADDFKGRFAFSLETSASTASLQIINLKNEDTATYFCARIGDSSPSDYWGQGTTLTVSS,SIVMTQTPKFLLVSAGDRVTITCKASQSVSNDVVWYQQKPGQSPKLLIYYASNRYTGVPDRFAGSGYGTDFSFTISTVQAEDLAVYFCQQDYTSPWTFGGGTKLEIR 22 | 
20,mouse,Clazakizumab,QSLEESGGRLVTPGTPLTLTCTASGFSLSNYYVTWVRQAPGKGLEWIGIIYGSDETAYATWAIGRFTISKTSTTVDLKMTSLTAADTATYFCARDDSSDWDAKFNLWGQGTLVTVSS,AYDMTQTPASVSAAVGGTVTIKCQASQSINNELSWYQQKPGQRPKLLIYRASTLASGVSSRFKGSGSGTEFTLTISDLECADAATYYCQQGYSLRNIDNAFGGGTEVVVK 23 | 21,mouse,Ligelizumab,QVQLQQSGAELMKPGASVKISCKTTGYTFSMYWLEWVKQRPGHGLEWVGEISPGTFTTNYNEKFKAKATFTADTSSNTAYLQLSGLTSEDSAVYFCARFSHFSGSNYDYFDYWGQGTSLTVSS,DILLTQSPAILSVSPGERVSFSCRASQSIGTNIHWYQQRTDGSPRLLIKYASESISGIPSRFSGSGSGTEFTLNINSVESEDIADYYCQQSDSWPTTFGGGTKLEIK 24 | 22,mouse,Crizanlizumab,QVQLQQSGPELVKPGALVKISCKASGYTFTSYDINWVKQRPGQGLEWIGWIYPGDGSIKYNEKFKGKATLTVDKSSSTAYMQVSSLTSENSAVYFCARRGEYGNYEGAMDYWGQGTTVTVSS,DIVLTQSPASLAVSLGQRATISCKASQSVDYDGHSYMNWYQQKPGQPPKLLIYAASNLESGIPARFSGSGSGTDFTLNIHPVEEEDAATYYCQQSDENPLTFGTGTKLELK 25 | 23,mouse,Mogamulizumab,EVQLVESGGDLMKPGGSLKISCAASGFIFSNYGMSWVRQTPDMRLEWVATISSASTYSYYPDSVKGRFTISRDNAENSLYLQMNSLRSEDTGIYYCGRHSDGNFAFGYWGRGTLVTVSA,DVLMTQTPLSLPVSLGDQASISCRSSRNIVHINGDTYLEWYLQRPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYYCFQGSLLPWTFGGGTRLEIR 26 | 24,mouse,Refanezumab,EIQLVQSGPELKKPGETNKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFTGRFAFSLETSASTAYLQISNLKNEDTATYFCARNPINYYGINYEGYVMDYWGQGTLVTVSS,NIMMTQSPSSLAVSAGEKVTMSCKSSHSVLYSSNQKNYLAWYQQKPGQSPKLLIYWASTRESGVPDRFTGSGSGTDFTLTIINVHTEDLAVYYCHQYLSSLTFGTGTKLEIK 27 | -------------------------------------------------------------------------------- /data/antibody_eval_data/Humab25_data/sample_experimental_t20_score.csv: -------------------------------------------------------------------------------- 1 | Raw_name,h_score,h_gene,l_score,l_gene,h_seq,l_seq 2 | AntiCD28,80.8756,vh,80.58047,vk,EVQLVQSGGGLVQPGGSLRLSCAGSGFTFSDYGVHWVRQAPGKGLEWVSAIWAGGGTNYASSVMGRFTISRDNAKNSLYLQMNSLRAEDMAVYYCARDKGYSYYYSMDYWGQGTLVTVSS,DIVMTQSPDSLAVSLGERATINCRASESVEYYVTSLMAWYQQKPGQPPKLLIYAASNVESGVPDRFSGSGSGTNFSLTISSLQAEDVAVYYCQQSRKYVPYTFGQGTKLEIK 3 | Campath,72.3146,vh,85.60757,vk,QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLEWIGFIRDKAKGYTTEYNPSVKGRVTMLVDTSKNQFSLRLSSVTAADTAVYYCAREGHTAAPFDYWGQGSLVTVSS,DIQMTQSPSSLSASVGDRVTITCKASQNIDKYLNWYQQKPGKAPKLLIYNTNNLQTGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCLQHISRPRTFGQGTKVEIK 4 | Bevacizumab,73.65857,vh,89.25237,vk,EVQLVESGGGLVQPGGSLRLSCAASGYTFTNYGMNWVRQAPGKGLEWVGWINTYTGEPTYAADFKRRFTFSLDTSKSTAYLQMNSLRAEDTAVYYCAKYPHYYGSSHWYFDVWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCSASQDISNYLNWYQQKPGKAPKVLIYFTSSLHSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQYSTVPWTFGQGTKVEIK 5 | Herceptin,86.255,vh,91.07487,vk,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCRASQDVNTAVAWYQQKPGKAPKLLIYSASFLESGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQHYTTPTTFGQGTKVEIK 6 | Omalizumab,76.36367,vh,87.74777,vk,EVQLVESGGGLVQPGGSLRLSCAVSGYSITSGYSWNWIRQAPGKGLEWVASITYDGSTNYADSVKGRFTISRDDSKNTFYLQMNSLRAEDTAVYYCARGSHYFGHWHFAVWGQGTLVTVSS,DIQLTQSPSSLSASVGDRVTITCRASQSVDYDGDSYMNWYQQKPGKAPKLLIYAASYLESGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSHEDPYTFGQGTKVEIK 7 | Eculizumab,78.77057,vh,84.81317,vk,QVQLVQSGAEVKKPGASVKVSCKASGYIFSNYWIQWVRQAPGQGLEWMGEILPGSGSTEYTENFKDRVTMTRDTSTSTVYMELSSLRSEDTAVYYCARYFFGSSPNWYFDVWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCGASENIYGALNWYQQKPGKAPKLLIYGATNLADGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQNVLNTPLTFGQGTKVEIK 8 | Tocilizumab,80.63037,vh,88.7856,vk,QVQLQESGPGLVRPSQTLSLTCTVSGYSITSDHAWSWVRQPPGRGLEWIGYISYSGITTYNPSLKSRVTMLRDTSKNQFSLRLSSVTAADTAVYYCARSLARTTAMDYWGQGSLVTVSS,DIQMTQSPSSLSASVGDRVTITCRASQDISSYLNWYQQKPGKAPKLLIYYTSRLHSGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCQQGNTLPYTFGQGTKVEIK 9 | 
Pembrolizumab,75.70837,vh,82.74777,vk,QVQLVQSGVEVKKPGASVKVSCKASGYTFTNYYMYWVRQAPGQGLEWMGGINPSNGGTNFNEKFKNRVTLTTDSSTTTAYMELKSLQFDDTAVYYCARRDYRFDMGFDYWGQGTTVTVSS,EIVLTQSPATLSLSPGERATLSCRASKGVSTSGYSYLHWYQQKPGQAPRLLIYLASYLESGVPARFSGSGSGTDFTLTISSLEPEDFAVYYCQHSRDLPLTFGGGTKVEIK 10 | Pertuzumab,76.47067,vh,89.81317,vk,EVQLVESGGGLVQPGGSLRLSCAASGFTFTDYTMDWVRQAPGKGLEWVADVNPNSGGSIYNQRFKGRFTLSVDRSKNTLYLQMNSLRAEDTAVYYCARNLGPSFYFDYWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCKASQDVSIGVAWYQQKPGKAPKLLIYSASYRYTGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQYYIYPYTFGQGTKVEIK 11 | Ixekizumab,76.47067,vh,85.58047,vk,QVQLVQSGAEVKKPGSSVKVSCKASGYSFTDYHIHWVRQAPGQGLEWMGVINPMYGTTDYNQRFKGRVTITADESTSTAYMELSSLRSEDTAVYYCARYDYFTGTGVYWGQGTLVTVSS,DIVMTQTPLSLSVTPGQPASISCRSSRSLVHSRGNTYLHWYLQKPGQSPQLLIYKVSNRFIGVPDRFSGSGSGTDFTLKISRVEAEDVGVYYCSQSTHLPFTFGQGTKLEIK 12 | Palivizumab,80.755,vh,79.15897,vk,QVTLRESGPALVKPTQTLTLTCTFSGFSLSTSGMSVGWIRQPPGKALEWLADIWWDDKKDYNPSLKSRLTISKDTSKNQVVLKVTNMDPADTATYYCARSMITNWYFDVWGAGTTVTVSS,DIQMTQSPSTLSASVGDRVTITCKCQLSVGYMHWYQQKPGKAPKLLIYDTSKLASGVPSRFSGSGSGTEFTLTISSLQPDDFATYYCFQGSGYPFTFGGGTKLEIK 13 | Certolizumab,80.755,vh,88.59817,vk,EVQLVESGGGLVQPGGSLRLSCAASGYVFTDYGMNWVRQAPGKGLEWMGWINTYIGEPIYADSVKGRFTFSLDTSKSTAYLQMNSLRAEDTAVYYCARGYRSYAMDYWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCKASQNVGTNVAWYQQKPGKAPKALIYSASFLYSGVPYRFSGSGSGTDFTLTISSLQPEDFATYYCQQYNIYPLTFGQGTKVEIK 14 | Idarucizumab,77.33617,vh,86.56257,vk,QVQLQESGPGLVKPSETLSLTCTVSGFSLTSYIVDWIRQPPGKGLEWIGVIWAGGSTGYNSALRSRVSITKDTSKNQFSLKLSSVTAADTAVYYCASAAYYSYYNYDGFAYWGQGTLVTVSS,DVVMTQSPLSLPVTLGQPASISCKSSQSLLYTDGKTYLYWFLQRPGQSPRRLIYLVSKLDSGVPDRFSGSGSGTDFTLKISRVEAEDVGVYYCLQSTHFPHTFGGGTKVEIK 15 | Reslizumab,75.25867,vh,87.05617,vk,EVQLVESGGGLVQPGGSLRLSCAVSGLSLTSNSVNWIRQAPGKGLEWVGLIWSNGDTDYNSAIKSRFTISRDTSKSTVYLQMNSLRAEDTAVYYCAREYYGYFDYWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCLASEGISSYLAWYQQKPGKAPKLLIYGANSLQTGVPSRFSGSGSATDYTLTISSLQPEDFATYYCQQSYKFPNTFGQGTKVEVK 16 | Solanezumab,83.93167,vh,88.43757,vk,EVQLVESGGGLVQPGGSLRLSCAASGFTFSRYSMSWVRQAPGKGLELVAQINSVGNSTYYPDTVKGRFTISRDNAKNTLYLQMNSLRAEDTAVYYCASGDYWGQGTLVTVSS,DVVMTQSPLSLPVTLGQPASISCRSSQSLIYSDGNAYLHWFLQKPGQSPRLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDVGVYYCSQSTHVPWTFGQGTKVEIK 17 | Lorvotuzumab,87.28817,vh,89.68757,vk,QVQLVESGGGVVQPGRSLRLSCAASGFTFSSFGMHWVRQAPGKGLEWVAYISSGSFTIYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARMRKGYAMDYWGQGTLVTVSS,DVVMTQSPLSLPVTLGQPASISCRSSQIIIHSDGNTYLEWFQQRPGQSPRRLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDVGVYYCFQGSHVPHTFGQGTKVEIK 18 | Pinatuzumab,77.70837,vh,81.51797,vk,EVQLVESGGGLVQPGGSLRLSCAASGYEFSRSWMNWVRQAPGKGLEWVGRIYPGDGDTNYSGKFKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARDGSSWDWYFDVWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCRSSQSIVHSVGNTFLEWYQQKPGKAPKLLIYKVSNRFSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCFQGSQFPYTFGQGTKVEIK 19 | Etaracizumab,82.99157,vh,84.11217,vk,QVQLVESGGGVVQPGRSLRLSCAASGFTFSSYDMSWVRQAPGKGLEWVAKVSSGGGSTYYLDTVQGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARHLHGSFASWGQGTTVTVSS,EIVLTQSPATLSLSPGERATLSCQASQSISNFLHWYQQRPGQAPRLLIRYRSQSISGIPARFSGSGSGTDFTLTISSLEPEDFAVYYCQQSGSWPLTFGGGTKVEIK 20 | Talacotuzumab,77.79177,vh,89.60187,vk,EVQLVQSGAEVKKPGESLKISCKGSGYSFTDYYMKWARQMPGKGLEWMGDIIPSNGATFYNQKFKGQVTISADKSISTTYLQWSSLKASDTAMYYCARSHLLRASWFAYWGQGTMVTVSS,DIVMTQSPDSLAVSLGERATINCESSQSLLNSGNQKNYLTWYQQKPGQPPKPLIYWASTRESGVPDRFSGSGSGTDFTLTISSLQAEDVAVYYCQNDYSYPYTFGQGTKLEIK 21 | Rovalpituzumab,81.99157,vh,85.98137,vk,QVQLVQSGAEVKKPGASVKVSCKASGYTFTNYGMNWVRQAPGQGLEWMGWINTYTGEPTYADDFKGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARIGDSSPSDYWGQGTLVTVSS,EIVMTQSPATLSVSPGERATLSCKASQSVSNDVVWYQQKPGQAPRLLIYYASNRYTGIPARFSGSGSGTEFTLTISSLQSEDFAVYYCQQDYTSPWTFGQGTKLEIK 22 | 
Clazakizumab,78.16677,vh,83.27277,vk,EVQLVESGGGLVQPGGSLRLSCAASGFSLSNYYVTWVRQAPGKGLEWVGIIYGSDETAYATSAIGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARDDSSDWDAKFNLWGQGTLVTVSS,AIQMTQSPSSLSASVGDRVTITCQASQSINNELSWYQQKPGKAPKLLIYRASTLASGVPSRFSGSGSGTDFTLTISSLQPDDFATYYCQQGYSLRNIDNAFGGGTKVEIK 23 | Ligelizumab,74.8786,vh,84.76647,vk,QVQLVQSGAEVMKPGSSVKVSCKASGYTFSWYWLEWVRQAPGHGLEWMGEIDPGTFTTNYNEKFKARVTFTADTSTSTAYMELSSLRSEDTAVYYCARFSHFSGSNYDYFDYWGQGTLVTVSS,EIVMTQSPATLSVSPGERATLSCRASQSIGTNIHWYQQKPGQAPRLLIYYASESISGIPARFSGSGSGTEFTLTISSLQSEDFAVYYCQQSWSWPTTFGGGTKVEIK 24 | Crizanlizumab,77.6236,vh,87.38747,vk,QVQLVQSGAEVKKPGASVKVSCKVSGYTFTSYDINWVRQAPGKGLEWMGWIYPGDGSIKYNEKFKGRVTMTVDKSTDTAYMELSSLRSEDTAVYYCARRGEYGNYEGAMDYWGQGTLVTVSS,DIQMTQSPSSLSASVGDRVTITCKASQSVDYDGHSYMNWYQQKPGKAPKLLIYAASNLESGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSDENPLTFGGGTKVEIK 25 | Mogamulizumab,78.27737,vh,83.52687,vk,EVQLVESGGDLVQPGRSLRLSCAASGFIFSNYGMSWVRQAPGKGLEWVATISSASTYSYYPDSVKGRFTISRDNAKNSLYLQMNSLRVEDTALYYCGRHSDGNFAFGYWGQGTLVTVSS,DVLMTQSPLSLPVTPGEPASISCRSSRNIVHINGDTYLEWYLQKPGQSPQLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDVGVYYCFQGSLLPWTFGQGTKVEIK 26 | Refanezumab,83.80957,vh,93.09737,vk,QVQLVQSGSELKKPGASVKVSCKASGYTFTNYGMNWVRQAPGQGLEWMGWINTYTGEPTYADDFTGRFVFSLDTSVSTAYLQISSLKAEDTAVYYCARNPINYYGINYEGYVMDYWGQGTLVTVSS,DIVMTQSPDSLAVSLGERATINCKSSHSVLYSSNQKNYLAWYQQKPGQPPKLLIYWASTRESGVPDRFSGSGSGTDFTLTISSLQAEDVAVYYCHQYLSSLTFGQGTKLEIK 27 | -------------------------------------------------------------------------------- /data/antibody_eval_data/Humab25_data/sample_mouse_t20_score.csv: -------------------------------------------------------------------------------- 1 | Raw_name,h_score,h_gene,l_score,l_gene,h_seq,l_seq 2 | AntiCD28,62.54,vh,62.76797,vk,EVKLQQSGPGLVTPSQSLSITCTVSGFSLSDYGVHWVRQSPGQGLEWLGVIWAGGGTNYNSALMSRKSISKDNSKSQVFLKMNSLQADDTAVYYCARDKGYSYYYSMDYWGQGTSVTVSS,DIETLQSPASLAVSLGQRATISCRASESVEYYVTSLMQWYQQKPGQPPKLLIFAASNVESGVPARFSGSGSGTNFSLNIHPVDEDDVAMYFCQQSRKYVPYTFGGGTKLEIK 3 | Campath,68.38847,vh,67.61687,vk,EVKLLESGGGLVQPGGSMRLSCAGSGFTFTDFYMNWIRQPAGKAPEWLGFIRDKAKGYTTEYNPSVKGRFTISRDNTQNMLYLQMNTLRAEDTATYYCAREGHTAAPFDYWGQGVMVTVSS,DIKMTQSPSFLSASVGDRVTLNCKASQNIDKYLNWYQQKLGESPKLLIYNTNNLQTGIPSRFSGSGSGTDFTLTISSLQPEDVATYFCLQHISRPRTFGTGTKLELK 4 | Bevacizumab,67.92687,vh,75.70097,vk,EIQLVQSGPELKQPGETVRISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYAADFKRRFTFSLETSASTAYLQISNLKNDDTATYFCAKYPHYYGSSHWYFDVWGAGTTVTVSS,DIQMTQTTSSLSASLGDRVIISCSASQDISNYLNWYQQKPDGTVKVLIYFTSSLHSGVPSRFSGSGSGTDYSLTISNLEPEDIATYYCQQYSTVPWTFGGGTKLEIK 5 | Herceptin,61.70837,vh,73.41127,vk,QVQLQQSGPELVKPGASLKLSCTASGFNIKDTYIHWVKQRPEQGLEWIGRIYPTNGYTRYDPKFQDKATITADTSSNTAYLQVSRLTSEDTAVYYCSRWGGDGFYAMDYWGQGASVTVSS,DIVMTQSHKFMSTSVGDRVSITCKASQDVNTAVAWYQQKPGHSPKLLIYSASFRYTGVPDRFTGNRSGTDFTFTISSVQAEDLAVYYCQQHYTTPPTFGGGTKVEIK 6 | Omalizumab,66.1576,vh,67.74777,vk,DVQLQESGPGLVKPSQSLSLACSVTGYSITSGYSWNWIRQFPGNKLEWMGSITYDGSSNYNPSLKNRISVTRDTSQNQFFLKLNSATAEDTATYYCARGSHYFGHWHFAVWGAGTTVTVSS,DIQLTQSPASLAVSLGQRATISCKASQSVDYDGDSYMNWYQQKPGQPPILLIYAASYLGSEIPARFSGSGSGTDFTLNIHPVEEEDAATFYCQQSHEDPYTFGAGTKLEIK 7 | Eculizumab,60.77877,vh,67.61687,vk,QVQLQQSGAELMKPGASVKMSCKATGYIFSNYWIQWIKQRPGHGLEWIGEILPGSGSTEYTENFKDKAAFTADTSSNTAYMQLSSLTSEDSAVYYCARYFFGSSPNWYFDVWGAGTTVTVSS,DIQMTQSPASLSASVGETVTITCGASENIYGALNWYQRKQGKSPQLLIYGATNLADGMSSRFSGSGSGRQYYLKISSLHPDDVATYYCQNVLNTPLTFGAGTKLELK 8 | Tocilizumab,70.63037,vh,74.29917,vk,DVQLQESGPVLVKPSQSLSLTCTVTGYSITSDHAWSWIRQFPGNKLEWMGYISYSGITTYNPSLKSRISITRDTSKNQFFLQLNSVTTGDTSTYYCARSLARTTAMDYWGQGTSVTVSS,DIQMTQTTSSLSASLGDRVTISCRASQDISSYLNWYQQKPDGTIKLLIYYTSRLHSGVPSRFSGSGSGTDYSLTINNLEQEDIATYFCQQGNTLPYTFGGGTKLEIN 9 | 
Pembrolizumab,66.33337,vh,66.26137,vk,QVQLQQPGAELVKPGTSVKLSCKASGYTFTNYYMYWVKQRPGQGLEWIGGINPSNGGTNFNEKFKNKATLTVDSSSSTTYMQLSSLTSEDSAVYYCTRRDYRFDMGFDYWGQGTTLTVSS,DIVLTQSPASLAVSLGQRAAISCRASKGVSTSGYSYLHWYQQKPGQSPKLLIYLASYLESGVPARFSGSGSGTDFTLNIHPVEEEDAATYYCQHSRDLPLTFGTGTKLELK 10 | Pertuzumab,70.63037,vh,71.35517,vk,EVQLQQSGPELVKPGTSVKISCKASGFTFTDYTMDWVKQSHGKSLEWIGDVNPNSGGSIYNQRFKGKASLTVDRSSRIVYMELRSLTFEDTAVYYCARNLGPSFYFDYWGQGTTLTVSS,DTVMTQSHKIMSTSVGDRVSITCKASQDVSIGVAWYQQRPGQSPKLLIYSASYRYTGVPDRFTGSGSGTDFTFTISSVQAEDLAVYYCQQYYIYPYTFGGGTKLEIK 11 | Ixekizumab,60.16817,vh,80.22327,vk,QVQLQQSRPELVKPGASVKISCKASGYSFTDYNMNWVKQSNGKSLEWIGVINPNYGTTDYNQRFKGKATLTVDQSSRTAYMQLNSLTSEDSAVYYCVIYDYATGTGGYWGQGSPLTVSS,DVVLTQTPLSLPVSLGDQASISCRSSQSLVHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQSTHVPFTFGSGTKLEIK 12 | Palivizumab,67.755,vh,62.57017,vk,QVELQESGPGILQPSQTLSLTCSFSGFSLSTSGMSVGWIRQPSGEGLEWLADIWWDDKKDYNPSLKSRLTISKDTSSNQVFLKITGVDTADTATYYCARSMITNWYFDVWGAGTTVTVSS,DIQLTQSPAIMSASPGEKVTMTCSASSSVGYMHWYQQKSSTSPKLWIYDTSKLASGVPGRFSGSGSGNSYSLTISSIQAEDVATYYCFQGSGYPFTFGQGTKLEIK 13 | Certolizumab,68.94077,vh,70.18697,vk,QIQLVQSGPELKKPGETVKISCKASGYVFTDYGMNWVKQAPGKAFKWMGWINTYIGEPIYVDDFKGRFAFSLETSASTAFLQINNLKNEDTATYFCARGYRSYAMDYWGQGTSVTVSS,DIVMTQSQKFMSTSVGDRVSVTCKASQNVGTNVAWYQQKPGQSPKALIYSASFLYSGVPYRFTGSGSGTDFTLTISTVQSEDLAEYFCQQYNIYPLTFGAGTKLELK 14 | Idarucizumab,58.11487,vh,79.10717,vk,QVQLEQSGPGLVAPSQRLSITCTVSGFSLTSYIVDWVRQSPGKGLEWLGVIWAGGSTGYNSALRSRLSITKSNSKSQVFLQMNSLQTDDTAIYYCASAAYYSYYNYDGFAYWGQGTLVTVSA,DVVMTQTPLTLSVTIGQPASISCKSSQSLLYTNGKTYLYWLLQRPGQSPKRLIYLVSKLDSGVPDRFSGSGSGTDFTLKISRVEAEDVGIYYCLQSTHFPHTFGGGTKLEIK 15 | Reslizumab,62.88797,vh,69.81317,vk,EVKLLESGGGLVQPSQTLSLTCTVSGLSLTSNSVNWIRQPPGKGLEWMGLIWSNGDTDYNSAIKSRLSISRDTSKSQVFLKMNSLQSEDTAMYFCAREYYGYFDYWGQGVMVTVSS,DIQMTQSPASLSASLGETISIECLASEGISSYLAWYQQKPGKSPQLLIYGANSLQTGVPSRFSGSGSATQYSLKISSMQPEDEGDYFCQQSYKFPNTFGAGTKLELK 16 | Solanezumab,70.38467,vh,80.35717,vk,EVKLVESGGGLVQPGGSLKLSCAVSGFTFSRYSMSWVRQTPEKRLELVAQINSVGNSTYYPDTVKGRFTISRDNAEYTLSLQMSGLRSDDTATYYCASGDYWGQGTTLTVSS,DVVMTQTPLSLPVSLGDQASISCRSSQSLIYSDGNAYLHWFLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVETEDLGVYFCSQSTHVPWTFGGGTKLEIK 17 | Lorvotuzumab,78.34757,vh,79.28577,vk,DVQLVESGGGLVQPGGSRKLSCAASGFTFSSFGMHWVRQAPEKGLEWVAYISSGSFTIYHADTVKGRFTISRDNPKNTLFLQMTSLRAEDTAHYYCARMRKGYAMDYWGQGTTVTVSS,DVLMTQTPLSLPVSLGDQASISCRSSQIIIHSDGNTYLEWFLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLMISRVEAEDLGVYYCFQGSHVPHTFGGGTKLEIK 18 | Pinatuzumab,63.66677,vh,80.17867,vk,QVQLQQSGPELVKPGASVKISCKASGYEFSRSWMNWVKQRPGQGREWIGRIYPGDGDTNYSGKFKGKATLTADKSSSTAYMQLSSLTSVDSAVYFCARDGSSWDWYFDVWGAGTTVTVSS,DILMTQTPLSLPVSLGDQASISCRSSQSIVHSNGNTFLEWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYYCFQGSQFPYTFGGGTKVEIK 19 | Etaracizumab,76.32487,vh,66.68227,vk,EVQLEESGGGLVKPGGSLKLSCAASGFAFSSYDMSWVRQIPEKRLEWVAKVSSGGGSTYYLDTVQGRFTISRDNAKNTLYLQMSSLNSEDTAMYYCARHNYGSFAYWGQGTLVTVSA,ELVMTQTPATLSVTPGDSVSLSCRASQSISNHLHWYQQKSHESPRLLIKYASQSISGIPSRFSGSGSGTDFTLSINSVETEDFGMYFCQQSNSWPHTFGGGTKLEIK 20 | Talacotuzumab,64.755,vh,79.24787,vk,EVQLQQSGPELVKPGASVKMSCKASGYTFTDYYMKWVKQSHGKSLEWIGDIIPSNGATFYNQKFKGKATLTVDRSSSTAYMHLNSLTSEDSAVYYCTRSHLLRASWFAYWGQGTLVTVSA,DFVMTQSPSSLTVTAGEKVTMSCKSSQSLLNSGNQKNYLTWYLQKPGQPPKLLIYWASTRESGVPDRFTGSGSGTDFTLTISSVQAEDLAVYYCQNDYSYPYTFGGGTKLEIK 21 | Rovalpituzumab,71.05937,vh,67.38327,vk,QIQLVQSGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMAWINTYTGEPTYADDFKGRFAFSLETSASTASLQIINLKNEDTATYFCARIGDSSPSDYWGQGTTLTVSS,SIVMTQTPKFLLVSAGDRVTITCKASQSVSNDVVWYQQKPGQSPKLLIYYASNRYTGVPDRFAGSGYGTDFSFTISTVQAEDLAVYFCQQDYTSPWTFGGGTKLEIR 22 | 
Clazakizumab,56.54,vh,64.18187,vk,QSLEESGGRLVTPGTPLTLTCTASGFSLSNYYVTWVRQAPGKGLEWIGIIYGSDETAYATWAIGRFTISKTSTTVDLKMTSLTAADTATYFCARDDSSDWDAKFNLWGQGTLVTVSS,AYDMTQTPASVSAAVGGTVTIKCQASQSINNELSWYQQKPGQRPKLLIYRASTLASGVSSRFKGSGSGTEFTLTISDLECADAATYYCQQGYSLRNIDNAFGGGTEVVVK 23 | Ligelizumab,59.91877,vh,66.30847,vk,QVQLQQSGAELMKPGASVKISCKTTGYTFSMYWLEWVKQRPGHGLEWVGEISPGTFTTNYNEKFKAKATFTADTSSNTAYLQLSGLTSEDSAVYFCARFSHFSGSNYDYFDYWGQGTSLTVSS,DILLTQSPAILSVSPGERVSFSCRASQSIGTNIHWYQQRTDGSPRLLIKYASESISGIPSRFSGSGSGTEFTLNINSVESEDIADYYCQQSDSWPTTFGGGTKLEIK 24 | Crizanlizumab,65.40987,vh,69.9556,vk,QVQLQQSGPELVKPGALVKISCKASGYTFTSYDINWVKQRPGQGLEWIGWIYPGDGSIKYNEKFKGKATLTVDKSSSTAYMQVSSLTSENSAVYFCARRGEYGNYEGAMDYWGQGTTVTVSS,DIVLTQSPASLAVSLGQRATISCKASQSVDYDGHSYMNWYQQKPGQPPKLLIYAASNLESGIPARFSGSGSGTDFTLNIHPVEEEDAATYYCQQSDENPLTFGTGTKLELK 25 | Mogamulizumab,70.58827,vh,75.6256,vk,EVQLVESGGDLMKPGGSLKISCAASGFIFSNYGMSWVRQTPDMRLEWVATISSASTYSYYPDSVKGRFTISRDNAENSLYLQMNSLRSEDTGIYYCGRHSDGNFAFGYWGRGTLVTVSA,DVLMTQTPLSLPVSLGDQASISCRSSRNIVHINGDTYLEWYLQRPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYYCFQGSLLPWTFGGGTRLEIR 26 | Refanezumab,70.83337,vh,78.23017,vk,EIQLVQSGPELKKPGETNKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFTGRFAFSLETSASTAYLQISNLKNEDTATYFCARNPINYYGINYEGYVMDYWGQGTLVTVSS,NIMMTQSPSSLAVSAGEKVTMSCKSSHSVLYSSNQKNYLAWYQQKPGQSPKLLIYWASTRESGVPDRFTGSGSGTDFTLTIINVHTEDLAVYYCHQYLSSLTFGTGTKLEIK 27 | -------------------------------------------------------------------------------- /data/fasta_file/7k9i.fasta: -------------------------------------------------------------------------------- 1 | >7K9I_1|Chain A|Spike protein S1|Severe acute respiratory syndrome coronavirus 2 (2697049) 2 | TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGP 3 | >7K9I_2|Chain B[auth H]|2B04 heavy chain|Mus musculus (10090) 4 | QVQLKQSGPGLVAPSQSLSITCTVSGFSLINYAISWVRQPPGKGLEWLGVIWTGGGTNYNSALKSRLSISKDNSKSQVFLKMNSLQTDDTARYYCARKDYYGRYYGMDYWGQGTSVTVS 5 | >7K9I_3|Chain C[auth L]|2B04 light chain|Mus musculus (10090) 6 | QAVVTQESALTTSPGETVTLTCRSSTGAVTTSNYANWVQEKPDHLFTGLIGGTNNRAPGVPARFSGSLIGDKAALTITGAQTEDEAIYFCALWYNNHWVFGGGTKLTVL 7 | -------------------------------------------------------------------------------- /data/fasta_file/7x2l.fasta: -------------------------------------------------------------------------------- 1 | >7X2L_1|Chain A[auth E]|Spike protein S1|Severe acute respiratory syndrome coronavirus 2 (2697049) 2 | TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPK 3 | >7X2L_2|Chain B|Nanobody 3-2A2-4|Vicugna pacos (30538) 4 | QVQLQESGGGLVQPGESLRLSCAASGSISTLNVMGWYRQAPGKQRELVAQITLDGSPEYADSVKGRFTITKDGAQSTLYLQMNNLKPEDTAVYFCKLENGGFFYYWGQGTQVTVSTHHHHHH 5 | -------------------------------------------------------------------------------- /data/nanobody_eval_data/nanobert_exp.csv: -------------------------------------------------------------------------------- 1 | ,name,vhhseq 2 | 0,tarperprumig,QVQLVESGGGLVKPGGSLRLSCAASGRPVSNYAAAWFRQAPGKEREFVSAINWQKTATYADSVKGRFTISRDNAKNSLYLQMNSLRAEDTAVYYCAAVFRVVAPKTQYDYDYWGQGTLVTVSS 3 | 1,gefurulimab,EVQLVESGGGLVKPGGSLRLSCAASGRPVSNYAAAWFRQAPGKEREFVSAINWQKTATYADSVKGRFTISRDNAKNSLYLQMNSLRAEDTAVYYCAAVFRVVAPKTQYDYDYWGQGTLVTVSS 4 | 
2,gontivimab,DVQLVESGGGLVQAGGSLSISCAASGGSLSNYVLGWFRQAPGKEREFVAAINWRGDITIGPPNVEGRFTISRDNAKNTGYLQMNSLAPDDTAVYYCGAGTPLNPGAYIYDWSYDYWGRGTQVTVSS 5 | 3,sonelokimab1,DVQLVESGGGLVQPGGSLRLSCAASGRTFSSYVVGWFRQAPGKEREFIGAISGSGESIYYAVSEKGRFTISRDNSKNTLYLQMNSLRPEDTAVYYCTADQEFGYLRFGRSEYWGQGTLVTVSS 6 | 4,enristomig,EVQLLESGGGEVQPGGSLRLSCAASGGIFAIKPISWYRQAPGKQREWVSTTTSSGATNYAESVKGRFTISRDNAKNTLYLQMSSLRAEDTAVYYCNVFEYWGQGTLVTVKP 7 | 5,ozekibart,EVQLLESGGGEVQPGGSLRLSCAASGLTFPNYGMGWFRQAPGKEREFVSAIYWSGGTVYYAESVKGRFTISRDNAKNTLYLQMSSLRAEDTAVYYCAVTIRGAATQTWKYDYWGQGTLVTVKP 8 | 6,letolizumab,EVQLLESGGGLVQPGGSLRLSCAASGFTFNWELMGWARQAPGKGLEWVSGIEGPGDVTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCVKVGKDAKSDYRGQGTLVTVSS 9 | 7,sonelokimab2,EVQLVESGGGLVQPGGSLRLSCAASGRTYDAMGWLRQAPGKEREFVAAISGSGDDTYYADSVKGRFTISRDNSKNTLYLQMNSLRPEDTAVYYCATRRGLYYVWDANDYENWGQGTLVTVSS 10 | 8,lunsekimig1,DVQLVESGGGVVQPGGSLRLSCAASGRTFSSYRMGWFRQAPGKEREFVAALSGDGYSTYTANSVKGRFTISRDNSKNTVYLQMNSLRPEDTALYYCAAKLQYVSGWSYDYPYWGQGTLVTVSS 11 | 9,isecarosmab,DVQLVESGGGVVQPGGSLRLSCAASGRTVSSYAMGWFRQAPGKEREFVAGISRSAERTYYVDSLKGRFTISRDNSKNTVYLQMNSLRPEDTALYYCAADLDPNRIFSREEYAYWGQGTLVTVSS 12 | 10,lunsekimig2,EVQLVESGGGVVQPGGSLRLSCAASGSGFGVNILYWYRQAAGIERELIASITSGGITNYVDSVKGRFTISRDNSENTMYLQMNSLRAEDTGLYYCASRNIFDGTTEWGQGTLVTVSS 13 | 11,erfonrilimab,QVQLVESGGGLVQPGGSLRLSCAASGKMSSRRCMAWFRQAPGKERERVAKLLTTSGSTYLADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAADSFEDPTCTLVTSSGAFQYWGQGTLVTVSS 14 | 12,lunsekimig3,EVQLVESGGGVVQPGGSLRLSCAASGFTFADYDYDIGWFRQAPGKEREGVSCISNRDGSTYYADSVKGRFTISRDNSKNTVYLQMNSLRPEDTALYYCAVEIHCDDYGVENFDFDPWGQGTLVTVSS 15 | 13,porustobart,EVQLVESGGGLIQPGGSLRLSCAVSGFTVSKNYMSWVRQAPGKGLEWVSVVYSGGSKTYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARAVPHSPSSFDIWGQGTMVTVSS 16 | 14,rimteravimab,DVQLVESGGGLVQPGGSLRLSCAASGRTFSEYAMGWFRQAPGKEREFVATISWSGGATYYTDSVKGRFTISRDNAKNTVYLQMNSLRPEDTAVYYCAAAGLGTVVSEWDYDYDYWGQGTLVTVSS 17 | 15,caplacizumab,EVQLVESGGGLVQPGGSLRLSCAASGRTFSYNPMGWFRQAPGKGRELVAAISRTGGSTYYPDSVEGRFTISRDNAKRMVYLQMNSLRAEDTAVYYCAAAGVRAEDGRVRTLPSEYTFWGQGTQVTVSS 18 | 16,ozoralizumab,EVQLVESGGGLVQPGGSLRLSCAASGFTFSDYWMYWVRQAPGKGLEWVSEINTNGLITKYPDSVKGRFTISRDNAKNTLYLQMNSLRPEDTAVYYCARSPSGFNRGQGTLVTVSS 19 | 17,vobarilizumab,EVQLVESGGGLVQPGGSLRLSCAASGSVFKINVMAWYRQAPGKGRELVAGIISGGSTSYADSVKGRFTISRDNAKNTLYLQMNSLRPEDTAVYYCAFITTESDYDLGRRYWGQGTLVTVSS 20 | -------------------------------------------------------------------------------- /data/train_sh_file/camel_data/nano_download.sh: -------------------------------------------------------------------------------- 1 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544217_Heavy_Bulk.csv.gz 2 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544217_Heavy_IGHA.csv.gz 3 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544217_Heavy_IGHD.csv.gz 4 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544217_Heavy_IGHE.csv.gz 5 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544217_Heavy_IGHG.csv.gz 6 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544217_Heavy_IGHM.csv.gz 7 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544218_Heavy_Bulk.csv.gz 8 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544218_Heavy_IGHD.csv.gz 9 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544218_Heavy_IGHE.csv.gz 10 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544218_Heavy_IGHG.csv.gz 11 | wget 
https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544218_Heavy_IGHM.csv.gz 12 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544219_Heavy_Bulk.csv.gz 13 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544219_Heavy_IGHA.csv.gz 14 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544219_Heavy_IGHM.csv.gz 15 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544220_Heavy_Bulk.csv.gz 16 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544220_Heavy_IGHA.csv.gz 17 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544220_Heavy_IGHG.csv.gz 18 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544220_Heavy_IGHM.csv.gz 19 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544221_Heavy_Bulk.csv.gz 20 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544221_Heavy_IGHA.csv.gz 21 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544221_Heavy_IGHM.csv.gz 22 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544222_Heavy_Bulk.csv.gz 23 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544222_Heavy_IGHA.csv.gz 24 | wget https://opig.stats.ox.ac.uk/webapps/ngsdb/unpaired/Li_2017/csv/SRR3544222_Heavy_IGHM.csv.gz -------------------------------------------------------------------------------- /dataset/abnativ_alignment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentAI4S/HuDiff/bb7636f182699f98c37855dad05a5c6c61b576bd/dataset/abnativ_alignment/__init__.py -------------------------------------------------------------------------------- /dataset/abnativ_alignment/aho_consensus.py: -------------------------------------------------------------------------------- 1 | # (c) 2023 Sormannilab 2 | 3 | 4 | 5 | # List of idx in the AHo numbering 6 | cdr1_aho_indices, cdr2_aho_indices, cdr3_aho_indices = range(27,43), range(57,70), range(108,139) 7 | fr_aho_indices = list(range(1,27)) + list(range(43,57)) + list(range(70,108)) + list(range(139,150)) 8 | 9 | # Seed consensus 149-long 10 | VH_consensus_no_gaps = "QVQLVESSGGGLVQPGGSLRLSCAASGGFTFSSSTLSGYYMHWVRQAPGKGLEWVGYISPSAGNGGSTYYADSVKGRFTISRDNSKNTAYLQMNSLRSEDTAVYYCARDGGYYGSDGGVAYYAADFFGEYYYYYYFDYWGQGTLVTVSS" 11 | VH_conservation_index = [ 12 | 0.649, 13 | 0.871, 14 | 0.828, 15 | 0.946, 16 | 0.556, 17 | 0.725, 18 | 0.859, 19 | 0.256, 20 | 0.916, 21 | 0.525, 22 | 0.612, 23 | 0.709, 24 | 0.717, 25 | 0.602, 26 | 0.873, 27 | 0.758, 28 | 0.394, 29 | 0.76, 30 | 0.666, 31 | 0.536, 32 | 0.675, 33 | 0.739, 34 | 1.015, 35 | 0.473, 36 | 0.678, 37 | 0.835, 38 | 0.865, 39 | 0.254, 40 | 0.498, 41 | 0.484, 42 | 0.655, 43 | 0.421, 44 | 0.298, 45 | 0.241, 46 | 0.256, 47 | 0.256, 48 | 0.25, 49 | 0.229, 50 | 0.592, 51 | 0.252, 52 | 0.59, 53 | 0.359, 54 | 1.002, 55 | 0.676, 56 | 0.833, 57 | 0.937, 58 | 0.543, 59 | 0.929, 60 | 0.815, 61 | 0.699, 62 | 0.628, 63 | 0.797, 64 | 0.933, 65 | 0.849, 66 | 0.534, 67 | 0.534, 68 | 0.129, 69 | 0.867, 70 | 0.272, 71 | 0.333, 72 | 0.141, 73 | 0.243, 74 | 0.253, 75 | 0.244, 76 | 0.303, 77 | 0.519, 78 | 0.189, 79 | 0.59, 80 | 0.321, 81 | 0.903, 82 | 0.505, 83 | 0.466, 84 | 0.52, 85 | 0.545, 86 | 0.717, 87 | 0.663, 88 | 0.808, 89 | 0.556, 90 | 0.777, 91 | 0.716, 92 | 0.618, 93 | 0.522, 94 | 0.926, 95 | 0.472, 96 | 0.619, 97 | 0.528, 98 | 0.614, 99 | 0.658, 100 | 0.45, 
101 | 0.719, 102 | 0.747, 103 | 0.634, 104 | 0.643, 105 | 0.518, 106 | 0.708, 107 | 0.805, 108 | 0.508, 109 | 0.404, 110 | 0.754, 111 | 0.987, 112 | 0.786, 113 | 0.912, 114 | 0.601, 115 | 0.962, 116 | 0.823, 117 | 1.015, 118 | 0.727, 119 | 0.627, 120 | 0.161, 121 | 0.142, 122 | 0.139, 123 | 0.157, 124 | 0.126, 125 | 0.135, 126 | 0.166, 127 | 0.205, 128 | 0.213, 129 | 0.235, 130 | 0.247, 131 | 0.251, 132 | 0.257, 133 | 0.257, 134 | 0.258, 135 | 0.258, 136 | 0.257, 137 | 0.257, 138 | 0.252, 139 | 0.244, 140 | 0.231, 141 | 0.212, 142 | 0.189, 143 | 0.164, 144 | 0.156, 145 | 0.201, 146 | 0.214, 147 | 0.472, 148 | 0.691, 149 | 0.537, 150 | 0.981, 151 | 0.922, 152 | 0.814, 153 | 0.924, 154 | 0.897, 155 | 0.444, 156 | 0.863, 157 | 0.916, 158 | 0.945, 159 | 0.892, 160 | 0.795, 161 | ] 162 | 163 | # old mixed of kappa and lambda no longer sure of which species but likely human 164 | VL_consensus_no_gaps = "DIVMTQSPSSLSASPGDRVTISCRASSSQSISSSNNGNNYLAWYQQKPGQAPKLLIYGKSAAAADSASNRASGVPDRFSGSGSGSGTDFTLTISSLQAEDFATYYCQQYDSSLSSFAAAAAAAAAAAAAFGGESSPYTFGGGTKLEIKR" 165 | VL_conservation_index = [ 166 | 0.605, 167 | 0.634, 168 | 0.576, 169 | 0.667, 170 | 0.905, 171 | 0.968, 172 | 0.563, 173 | 0.708, 174 | 0.322, 175 | 0.563, 176 | 0.657, 177 | 0.635, 178 | 0.473, 179 | 0.68, 180 | 0.512, 181 | 0.927, 182 | 0.532, 183 | 0.548, 184 | 0.672, 185 | 0.682, 186 | 0.705, 187 | 0.54, 188 | 1.016, 189 | 0.553, 190 | 0.564, 191 | 0.627, 192 | 0.211, 193 | 0.222, 194 | 0.571, 195 | 0.435, 196 | 0.561, 197 | 0.209, 198 | 0.193, 199 | 0.208, 200 | 0.238, 201 | 0.259, 202 | 0.243, 203 | 0.217, 204 | 0.363, 205 | 0.525, 206 | 0.559, 207 | 0.415, 208 | 1.02, 209 | 0.852, 210 | 0.836, 211 | 0.881, 212 | 0.752, 213 | 0.846, 214 | 0.823, 215 | 0.499, 216 | 0.496, 217 | 0.914, 218 | 0.621, 219 | 0.765, 220 | 0.819, 221 | 0.918, 222 | 0.806, 223 | 0.201, 224 | 0.264, 225 | 0.262, 226 | 0.265, 227 | 0.265, 228 | 0.265, 229 | 0.265, 230 | 0.264, 231 | 0.229, 232 | 0.507, 233 | 0.657, 234 | 0.347, 235 | 0.615, 236 | 0.329, 237 | 0.604, 238 | 0.921, 239 | 0.767, 240 | 0.906, 241 | 0.464, 242 | 0.987, 243 | 0.993, 244 | 0.763, 245 | 0.915, 246 | 0.847, 247 | 0.663, 248 | 0.813, 249 | 0.765, 250 | 0.259, 251 | 0.243, 252 | 0.742, 253 | 0.598, 254 | 0.697, 255 | 0.644, 256 | 0.921, 257 | 0.661, 258 | 0.965, 259 | 0.633, 260 | 0.401, 261 | 0.626, 262 | 0.643, 263 | 0.464, 264 | 0.87, 265 | 0.993, 266 | 0.427, 267 | 0.792, 268 | 0.444, 269 | 0.975, 270 | 0.816, 271 | 1.016, 272 | 0.584, 273 | 0.65, 274 | 0.338, 275 | 0.246, 276 | 0.241, 277 | 0.2, 278 | 0.252, 279 | 0.264, 280 | 0.265, 281 | 0.265, 282 | 0.265, 283 | 0.265, 284 | 0.265, 285 | 0.265, 286 | 0.265, 287 | 0.265, 288 | 0.265, 289 | 0.265, 290 | 0.265, 291 | 0.265, 292 | 0.265, 293 | 0.265, 294 | 0.265, 295 | 0.265, 296 | 0.265, 297 | 0.265, 298 | 0.263, 299 | 0.221, 300 | 0.159, 301 | 0.569, 302 | 0.284, 303 | 0.698, 304 | 0.991, 305 | 0.935, 306 | 0.477, 307 | 0.938, 308 | 0.941, 309 | 0.876, 310 | 0.678, 311 | 0.733, 312 | 0.654, 313 | 0.766, 314 | 0.753, 315 | ] 316 | 317 | VKappa_consensus_no_gaps = "DIVMTQSPDSLSVSPGERATISCRAS-SQSISHSSNGKSYLAWYQQKPGQAPKLLIYYASNARFSLASTRASGVPSRFSGSGSGGGTDFTLTISSLEAEDFAVYYCQQYSSWPPFTP------------RDTPPLPLTFGQGTKVEIK-" 318 | VKappa_conservation_index = [ 319 | 0.684, 320 | 0.806, 321 | 0.684, 322 | 0.689, 323 | 0.911, 324 | 0.911, 325 | 0.752, 326 | 0.945, 327 | 0.362, 328 | 0.461, 329 | 0.695, 330 | 0.645, 331 | 0.625, 332 | 0.568, 333 | 0.629, 334 | 0.782, 335 | 0.635, 336 | 0.587, 337 | 0.67, 338 | 0.643, 339 | 0.756, 340 | 0.499, 341 | 
1.015, 342 | 0.725, 343 | 0.663, 344 | 0.635, 345 | 0.263, 346 | 0.238, 347 | 0.916, 348 | 0.649, 349 | 0.552, 350 | 0.339, 351 | 0.279, 352 | 0.264, 353 | 0.22, 354 | 0.3, 355 | 0.288, 356 | 0.259, 357 | 0.452, 358 | 0.543, 359 | 0.858, 360 | 0.493, 361 | 1.0, 362 | 0.918, 363 | 0.77, 364 | 0.93, 365 | 0.912, 366 | 0.94, 367 | 0.783, 368 | 0.677, 369 | 0.517, 370 | 0.882, 371 | 0.532, 372 | 0.806, 373 | 0.866, 374 | 0.939, 375 | 0.715, 376 | 0.281, 377 | 0.263, 378 | 0.263, 379 | 0.263, 380 | 0.263, 381 | 0.263, 382 | 0.263, 383 | 0.263, 384 | 0.225, 385 | 0.74, 386 | 0.784, 387 | 0.39, 388 | 0.566, 389 | 0.389, 390 | 0.553, 391 | 0.928, 392 | 0.731, 393 | 0.929, 394 | 0.459, 395 | 0.97, 396 | 0.958, 397 | 0.855, 398 | 0.929, 399 | 0.855, 400 | 0.928, 401 | 0.783, 402 | 0.705, 403 | 0.263, 404 | 0.289, 405 | 0.926, 406 | 0.816, 407 | 0.969, 408 | 0.92, 409 | 0.907, 410 | 0.741, 411 | 0.956, 412 | 0.66, 413 | 0.564, 414 | 0.681, 415 | 0.674, 416 | 0.533, 417 | 0.933, 418 | 0.976, 419 | 0.53, 420 | 0.766, 421 | 0.561, 422 | 0.97, 423 | 0.881, 424 | 1.015, 425 | 0.566, 426 | 0.914, 427 | 0.394, 428 | 0.273, 429 | 0.428, 430 | 0.244, 431 | 0.262, 432 | 0.263, 433 | 0.263, 434 | 0.263, 435 | 0.263, 436 | 0.263, 437 | 0.263, 438 | 0.263, 439 | 0.263, 440 | 0.263, 441 | 0.263, 442 | 0.263, 443 | 0.263, 444 | 0.263, 445 | 0.263, 446 | 0.263, 447 | 0.263, 448 | 0.263, 449 | 0.263, 450 | 0.263, 451 | 0.263, 452 | 0.246, 453 | 0.28, 454 | 0.638, 455 | 0.292, 456 | 0.843, 457 | 0.972, 458 | 0.942, 459 | 0.638, 460 | 0.94, 461 | 0.94, 462 | 0.879, 463 | 0.703, 464 | 0.887, 465 | 0.955, 466 | 0.985, 467 | 0.263, 468 | ] 469 | 470 | VLambda_consensus_no_gaps = "QSVLTQP-PSVSVSPGQTVTLTCTGSSAGSVGSDL-AGYYVSWYQQKPGQAPRLLIYENSGS-SDGDNNRPSGVPDRFSGSKSGSSNTASLTISGLQAEDEADYYCQSYDSSLSGLS-----------ADGFSLSAWVFGGGTKLTVLL" # ovwrite with edited 471 | VLambda_conservation_index = [ 472 | 0.736, 473 | 0.275, 474 | 0.591, 475 | 0.75, 476 | 0.904, 477 | 0.931, 478 | 0.631, 479 | 0.274, 480 | 0.63, 481 | 0.823, 482 | 0.501, 483 | 0.734, 484 | 0.422, 485 | 0.615, 486 | 0.763, 487 | 0.832, 488 | 0.437, 489 | 0.505, 490 | 0.653, 491 | 0.589, 492 | 0.648, 493 | 0.59, 494 | 1.018, 495 | 0.498, 496 | 0.461, 497 | 0.448, 498 | 0.495, 499 | 0.274, 500 | 0.407, 501 | 0.28, 502 | 0.483, 503 | 0.307, 504 | 0.336, 505 | 0.274, 506 | 0.274, 507 | 0.274, 508 | 0.274, 509 | 0.184, 510 | 0.511, 511 | 0.498, 512 | 0.478, 513 | 0.305, 514 | 0.987, 515 | 0.684, 516 | 0.94, 517 | 0.929, 518 | 0.424, 519 | 0.888, 520 | 0.812, 521 | 0.398, 522 | 0.599, 523 | 0.963, 524 | 0.484, 525 | 0.464, 526 | 0.635, 527 | 0.667, 528 | 0.74, 529 | 0.258, 530 | 0.277, 531 | 0.221, 532 | 0.274, 533 | 0.274, 534 | 0.274, 535 | 0.274, 536 | 0.304, 537 | 0.201, 538 | 0.385, 539 | 0.439, 540 | 0.396, 541 | 0.69, 542 | 0.537, 543 | 0.736, 544 | 0.793, 545 | 0.616, 546 | 0.805, 547 | 0.519, 548 | 0.987, 549 | 0.928, 550 | 0.855, 551 | 0.834, 552 | 0.852, 553 | 0.417, 554 | 0.488, 555 | 0.667, 556 | 0.239, 557 | 0.222, 558 | 0.609, 559 | 0.362, 560 | 0.721, 561 | 0.381, 562 | 0.944, 563 | 0.696, 564 | 0.848, 565 | 0.624, 566 | 0.75, 567 | 0.631, 568 | 0.805, 569 | 0.414, 570 | 0.81, 571 | 0.984, 572 | 0.993, 573 | 0.843, 574 | 0.863, 575 | 0.971, 576 | 0.952, 577 | 1.018, 578 | 0.35, 579 | 0.316, 580 | 0.592, 581 | 0.462, 582 | 0.459, 583 | 0.389, 584 | 0.25, 585 | 0.274, 586 | 0.274, 587 | 0.274, 588 | 0.274, 589 | 0.274, 590 | 0.274, 591 | 0.274, 592 | 0.274, 593 | 0.274, 594 | 0.274, 595 | 0.274, 596 | 0.274, 597 | 0.274, 598 | 0.274, 599 | 0.274, 600 | 0.274, 
601 | 0.274, 602 | 0.274, 603 | 0.274, 604 | 0.27, 605 | 0.201, 606 | 0.256, 607 | 0.204, 608 | 0.387, 609 | 0.905, 610 | 0.999, 611 | 0.918, 612 | 0.788, 613 | 0.917, 614 | 0.942, 615 | 0.956, 616 | 0.82, 617 | 0.934, 618 | 0.963, 619 | 0.955, 620 | 0.274, 621 | ] 622 | 623 | VHH_consensus_no_gaps = "QVQLQESGGGGLVQAGGSLRLSCAASGSRTFSSYFGDTYAMGWFRQAPGKEREFVAAISSSGSSGGSTYYADSVKGRFTISRDNAKNTVYLQMNSLKPEDTAVYYCAAGRGGSGSSGYCGVAAAAIHAAYTSPGEYDYWGQGTQVTVSS" 624 | VHH_conservation_index = [ 625 | 0.741, 626 | 0.917, 627 | 0.826, 628 | 0.957, 629 | 0.627, 630 | 0.826, 631 | 0.888, 632 | 0.26, 633 | 0.912, 634 | 0.901, 635 | 0.887, 636 | 0.713, 637 | 0.923, 638 | 0.913, 639 | 0.627, 640 | 0.902, 641 | 0.811, 642 | 0.906, 643 | 0.962, 644 | 0.87, 645 | 0.949, 646 | 0.871, 647 | 1.013, 648 | 0.713, 649 | 0.782, 650 | 0.874, 651 | 0.733, 652 | 0.26, 653 | 0.32, 654 | 0.512, 655 | 0.507, 656 | 0.474, 657 | 0.21, 658 | 0.255, 659 | 0.259, 660 | 0.26, 661 | 0.255, 662 | 0.252, 663 | 0.485, 664 | 0.244, 665 | 0.736, 666 | 0.543, 667 | 1.007, 668 | 0.643, 669 | 0.961, 670 | 0.923, 671 | 0.803, 672 | 0.965, 673 | 0.885, 674 | 0.863, 675 | 0.656, 676 | 0.819, 677 | 0.933, 678 | 0.425, 679 | 0.881, 680 | 0.649, 681 | 0.249, 682 | 0.802, 683 | 0.316, 684 | 0.304, 685 | 0.205, 686 | 0.257, 687 | 0.258, 688 | 0.249, 689 | 0.493, 690 | 0.463, 691 | 0.264, 692 | 0.706, 693 | 0.398, 694 | 0.903, 695 | 0.63, 696 | 0.835, 697 | 0.817, 698 | 0.858, 699 | 0.885, 700 | 0.874, 701 | 0.963, 702 | 0.97, 703 | 0.884, 704 | 0.9, 705 | 0.878, 706 | 0.762, 707 | 0.952, 708 | 0.782, 709 | 0.783, 710 | 0.834, 711 | 0.805, 712 | 0.825, 713 | 0.644, 714 | 0.811, 715 | 0.956, 716 | 0.853, 717 | 0.969, 718 | 0.839, 719 | 0.699, 720 | 0.941, 721 | 0.816, 722 | 0.84, 723 | 0.948, 724 | 0.987, 725 | 0.933, 726 | 0.892, 727 | 0.659, 728 | 0.969, 729 | 0.827, 730 | 1.013, 731 | 0.637, 732 | 0.511, 733 | 0.159, 734 | 0.132, 735 | 0.083, 736 | 0.101, 737 | 0.089, 738 | 0.113, 739 | 0.136, 740 | 0.177, 741 | 0.218, 742 | 0.247, 743 | 0.258, 744 | 0.259, 745 | 0.26, 746 | 0.26, 747 | 0.26, 748 | 0.26, 749 | 0.26, 750 | 0.26, 751 | 0.259, 752 | 0.253, 753 | 0.239, 754 | 0.208, 755 | 0.156, 756 | 0.141, 757 | 0.15, 758 | 0.106, 759 | 0.161, 760 | 0.432, 761 | 0.38, 762 | 0.639, 763 | 0.887, 764 | 0.904, 765 | 0.877, 766 | 0.912, 767 | 0.954, 768 | 0.858, 769 | 0.951, 770 | 0.955, 771 | 0.947, 772 | 0.898, 773 | 0.873, 774 | ] -------------------------------------------------------------------------------- /dataset/abnativ_alignment/align_and_clean.py: -------------------------------------------------------------------------------- 1 | # (c) 2023 Sormannilab and Aubin Ramon 2 | # 3 | # Alignment with ANARCI and cleaning of sequences. 
4 | #
5 | # ============================================================================
6 | 
7 | import sys
8 | from .mybio import Anarci_alignment, get_SeqRecords, get_antibodyVD_numbers, clean_anarci_alignment
9 | 
10 | 
11 | def anarci_alignments_of_Fv_sequences(fp_seq_not_align_list, seed: bool = True, dont_change_header: bool = True, scheme: str = 'AHo', isVHH: bool = False, clean: bool = True, minimum_added: int = 50, nb_N_gaps: int=1,
12 | add_C_term_missing_VH: str = 'SS', add_C_term_missing_VKappa: str = 'K', add_C_term_missing_VLambda: str = 'L', check_duplicates: bool = False, del_cyst_misalign=False, verbose: bool=True) :
13 |     '''
14 |     Align sequences using the ANARCI program, available from https://github.com/oxpig/ANARCI
15 | 
16 |     Careful: most light chains need at least tolerate_missing_termini_residues=1 in the AHo numbering
17 | 
18 |     Parameters
19 |     ----------
20 |     - fp_seq_not_align_list : str or list
21 |         Filepath to a fasta file with the sequences to align, a single sequence string, or a list of sequences/SeqRecords
22 |         e.g., 'seqs.fa' or 'QVQE...VSS'
23 |     - seed : bool
24 |         If True, start the numbering with a well defined seed in the AHo numbering scheme to get all AHo numbers
25 |         and then potentially discard unusual sequences by setting dont_change_header to True in add_sequence
26 |     - dont_change_header : bool
27 |         If True, discard unusual sequences based on the header of the first sequence
28 |     - scheme : str
29 |         Type of numbering scheme; cleaning only supports AHo numbering
30 |     - isVHH : bool
31 |         If True, flag the heavy chains as VHH sequences
32 |     - clean : bool
33 |         If True, clean sequences based on custom parameters
34 |     - minimum_added : int
35 |         Minimum number of aligned residues required for a sequence to be added
36 |     - nb_N_gaps : int
37 |         If not None, allow nb_N_gaps consecutive gaps at the N-terminus
38 |     - add_C_term_missing_VH : str
39 |         If not None, add the string motif if missing at the C-terminus (from position 149 backwards) for Heavy chains
40 |     - add_C_term_missing_VKappa : str
41 |         If not None, add the string motif if missing at the C-terminus (from position 148 backwards) for Kappa chains
42 |     - add_C_term_missing_VLambda : str
43 |         If not None, add the string motif if missing at the C-terminus (from position 148 backwards) for Lambda chains
44 |     - check_duplicates : bool
45 |         If True, remove duplicates within the same chain type (only!)
46 |     - del_cyst_misalign : bool
47 |         If True, remove sequences with misaligned cysteines (should be set to False if the sequence has been mutated at those cysteine positions).
48 |         Default is False for prediction.
49 |     - verbose: bool
50 | 
51 |     Returns
52 |     -------
53 |     - VH, VK, VL : Anarci_alignment classes (see mybio), one per chain type
54 |     - failed, mischtype : lists of tuples
55 |         failed and mischtype are lists of tuples like [(j,seq_name,seq,chtype),...];
56 |         mischtype collects chains that are not H, K, or L according to anarci.
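    Example
    -------
    An illustrative usage sketch (the sequences are abbreviated placeholders, not real inputs):

    >>> VH, VK, VL, failed, mischtype = anarci_alignments_of_Fv_sequences(
    ...     ['QVQLVESGG...VTVSS', 'DIVMTQSPS...EIKR'])
    >>> len(VH), len(VK), len(failed)   # aligned heavy chains, kappa chains, and failures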
57 | 
58 |     '''
59 |     if clean and scheme!='AHo' :
60 |         sys.stderr.write("**WARNING** in anarci_alignments_of_Fv_sequences clean requested with scheme=%s, but at present only supported for AHo - setting clean to False!"%(scheme))
61 |         clean=False
62 | 
63 |     Fv_Sequences = fp_seq_not_align_list
64 | 
65 |     # Fv_Sequences = get_SeqRecords(fp_seq_not_align_list)
66 | 
67 |     # start run
68 |     VH=Anarci_alignment(seed=seed,scheme=scheme,chain_type='H',isVHH=isVHH)
69 |     VK=Anarci_alignment(seed=seed,scheme=scheme,chain_type='K')
70 |     VL=Anarci_alignment(seed=seed,scheme=scheme,chain_type='L')
71 | 
72 |     failed, mischtype= list(), list()
73 |     try_to_fix_misalignedCys=False # True makes it much slower and at least for VHHs has no effect on failed rate
74 | 
75 |     for j,seq in enumerate(Fv_Sequences) :
76 |         if verbose and j%300==0 :
77 |             _=sys.stdout.write(' anarci_alignments_of_Fv_sequences done %d of %d -> %.2lf %% (len(alH)=%d len(failed)=%d len(mischtype)=%d)\n' %(j,len(Fv_Sequences),100.*j/len(Fv_Sequences),len(VH),len(failed),len(mischtype)))
78 |             sys.stdout.flush()
79 |         if hasattr(seq,'seq') : # seq record
80 |             seq_name=seq.id
81 |             seq=str(seq.seq)
82 |         else :
83 |             seq_name=str(j)
84 |         Fv_res = get_antibodyVD_numbers(seq, scheme=scheme, full_return=True, seqname=seq_name, print_warns=False, auto_detect_chain_type=True)
85 |         if Fv_res is None : # failed
86 |             failed+=[(j,seq_name,seq,None)]
87 |             continue
88 |         for chtype in Fv_res :
89 |             seqind_to_schemnum,schemnum_to_seqind,seqind_regions,warnings, info_dict, eval_table=Fv_res[chtype]
90 |             if chtype=='H' :
91 |                 ok=VH.add_processed_sequence(seq, seqind_to_schemnum, seqind_regions, seq_name, minimum_added=minimum_added, dont_change_header=dont_change_header, try_to_fix_misalignedCys=try_to_fix_misalignedCys)
92 |                 if ok is None or ok<minimum_added :
93 |                     failed+=[(j,seq_name,seq,chtype)]
94 |             elif chtype=='K' :
95 |                 ok=VK.add_processed_sequence(seq, seqind_to_schemnum, seqind_regions, seq_name, minimum_added=minimum_added, dont_change_header=dont_change_header, try_to_fix_misalignedCys=try_to_fix_misalignedCys)
96 |                 if ok is None or ok<minimum_added :
97 |                     failed+=[(j,seq_name,seq,chtype)]
98 |             elif chtype=='L' :
99 |                 ok=VL.add_processed_sequence(seq, seqind_to_schemnum, seqind_regions, seq_name, minimum_added=minimum_added, dont_change_header=dont_change_header, try_to_fix_misalignedCys=try_to_fix_misalignedCys)
100 |                 if ok is None or ok<minimum_added :
101 |                     failed+=[(j,seq_name,seq,chtype)]
102 |             else : # chain type not recognised as H, K, or L
103 |                 mischtype+=[(j,seq_name,seq,chtype)]
104 |     if verbose :
105 |         _=sys.stdout.write(' anarci_alignments_of_Fv_sequences done %d of %d -> %.2lf %% (len(alH)=%d len(failed)=%d len(mischtype)=%d)\n' %(j+1,len(Fv_Sequences),100.*(j+1)/len(Fv_Sequences),len(VH),len(failed),len(mischtype)))
106 |         sys.stdout.flush()
107 |     if clean :
108 |         if len(VH)> 0 :
109 |             if verbose: print("\n- Cleaning Heavy -")
110 |             VH, failed_idx = clean_anarci_alignment(VH, warn=verbose,cons_cys_HD=[23,106], del_cyst_misalign=del_cyst_misalign, add_Nterm_missing=None, add_C_term_missing=add_C_term_missing_VH, isVHH=isVHH, check_duplicates=check_duplicates, nb_N_gaps=nb_N_gaps, verbose=verbose)
111 |             if len(failed_idx) != 0:
112 |                 for k in failed_idx:
113 |                     failed+=[(k, None, None, None)]
114 |         if len(VK)> 0 :
115 |             if verbose: print("\n- Cleaning Kappa -")
116 |             VK, failed_idx = clean_anarci_alignment(VK, warn=verbose,cons_cys_HD=[23,106],del_cyst_misalign=del_cyst_misalign, add_Nterm_missing=None, add_C_term_missing=add_C_term_missing_VKappa, isVHH=False, check_duplicates=check_duplicates, nb_N_gaps=nb_N_gaps, verbose=verbose)
117 |             if len(failed_idx) != 0:
118 |                 for k in failed_idx:
119 |                     failed+=[(k, None, None, None)]
120 |         if len(VL)> 0 :
121 |             if verbose: print("\n- Cleaning Lambda -")
122 |             VL, failed_idx = clean_anarci_alignment(VL, warn=verbose,cons_cys_HD=[23,106],del_cyst_misalign=del_cyst_misalign,add_Nterm_missing=None, add_C_term_missing=add_C_term_missing_VLambda, isVHH=False, check_duplicates=check_duplicates, nb_N_gaps=nb_N_gaps, verbose=verbose)
123 |             if len(failed_idx) != 0:
124 |                 for k in failed_idx:
125 |                     failed+=[(k, None, None, None)]
126 |     return VH, VK, VL, failed, mischtype
127 | 
-------------------------------------------------------------------------------- /dataset/oas_dataset.py: --------------------------------------------------------------------------------
1 | 
"""The OAS dataset.""" 2 | 3 | import os 4 | import re 5 | import random 6 | import logging 7 | from collections import defaultdict 8 | from tqdm import tqdm 9 | 10 | from torch.utils.data import IterableDataset 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data import get_worker_info 13 | 14 | from tfold_utils.common_utils import tfold_init 15 | from tfold_utils.torch_utils import inspect_data 16 | from tfold_utils.prot_constants import RESD_NAMES_1C 17 | from utils.anti_numbering import get_regions 18 | 19 | 20 | def parse_csv_file(path, n_seqs_max=4096): 21 | """Parse the CSV file.""" 22 | 23 | # set of standard amino-acids 24 | resd_names = set(RESD_NAMES_1C) 25 | 26 | # parse the CSV file 27 | seq_list = [] 28 | headers = None 29 | with open(path, 'r', encoding='UTF-8') as i_file: 30 | for i_line in i_file: 31 | sub_strs = i_line.strip().split(',') 32 | if headers is None: 33 | headers = sub_strs 34 | else: 35 | seq_dict = {k: v for k, v in zip(headers, sub_strs)} 36 | # if len(set(seq_dict['chn']) - resd_names) != 0: 37 | # continue 38 | seq_list.append(seq_dict) 39 | if (n_seqs_max != -1) and (len(seq_list) >= n_seqs_max): 40 | break 41 | 42 | return seq_list 43 | 44 | 45 | class OasDataset(IterableDataset): 46 | """The OAS dataset.""" 47 | 48 | def __init__( 49 | self, 50 | csv_dpath=None, # directory path to CSV files 51 | pool_size=65536, # minimal number of candidate sequences in the pool 52 | n_seqs_max=4096, # maximal number of sequences parsed from a single CSV file 53 | spc_mode='human-only', # species mode (choices: 'no-human' / 'human-only') 54 | ): 55 | """Constructor function.""" 56 | 57 | super().__init__() 58 | 59 | # setup configurations 60 | self.csv_dpath = csv_dpath 61 | self.pool_size = pool_size 62 | self.n_seqs_max = n_seqs_max 63 | self.spc_mode = spc_mode 64 | 65 | # initialize the dataset 66 | self.__init_dataset() 67 | 68 | 69 | def __iter__(self): 70 | """Return an iterator of samples in the dataset.""" 71 | 72 | # initialization 73 | # n_files = len(self.file_list) 74 | 75 | # validate the worker information 76 | worker_info = get_worker_info() 77 | assert worker_info is None, 'only single-process data loading is supported' 78 | 79 | # initialize a new epoch 80 | # random.shuffle(self.file_list) 81 | logging.debug('=== start of epoch ===') 82 | 83 | # initialize the sequence pool 84 | idx_file = 0 85 | seq_pool = [] # list of (species, chn_type, seq_dict)-tuples 86 | while (len(seq_pool) < self.pool_size): 87 | # species, chn_type, csv_fpath = self.file_list[idx_file] 88 | csv_fpath = self.csv_dpath 89 | seq_list = parse_csv_file(csv_fpath, self.n_seqs_max) 90 | logging.debug('adding %d sequences from %s', len(seq_list), csv_fpath) 91 | seq_pool.extend([(x['ENTRY'], x['HSEQ'], x['LSEQ']) for x in seq_list]) 92 | random.shuffle(seq_pool) 93 | logging.debug('# of CSV files: %d (total)', len(seq_pool)) 94 | 95 | # parse CSV files and return an iterator of samples 96 | while len(seq_pool) > 0: 97 | # build an input dict 98 | species, H_seq, L_seq = seq_pool.pop() 99 | try: 100 | inputs = self.__build_inputs(species, H_seq, L_seq) 101 | yield inputs 102 | except: 103 | continue 104 | 105 | # early-exit if no expansion is needed 106 | if (len(seq_pool) >= self.pool_size): 107 | continue 108 | 109 | # indicate the end of current epoch 110 | logging.debug('=== end of epoch ===') 111 | 112 | 113 | def __init_dataset(self): 114 | """Initialize the dataset.""" 115 | logging.debug('=== CSV files ===') 116 | logging.debug('load CSV files 
{}'.format(self.csv_dpath)) 117 | 118 | 119 | @classmethod 120 | def __build_inputs(cls, species, H_seq, L_seq): 121 | """Build an input dict.""" 122 | 123 | inputs = { 124 | 'spec': species, 125 | 'hseq': H_seq, 126 | 'lseq': L_seq, 127 | 'hseq_cdr_region': get_regions(H_seq), 128 | 'lseq_cdr_region': get_regions(L_seq) 129 | } 130 | 131 | return inputs 132 | 133 | 134 | class OasDatasetShuf(IterableDataset): 135 | """The OAS dataset for CSV files containing randomly shuffled sequences.""" 136 | 137 | def __init__( 138 | self, 139 | csv_dpath=None, # directory path to CSV files 140 | spc_mode='no-human', # species mode (choices: 'no-human' / 'human-only') 141 | ): 142 | """Constructor function.""" 143 | 144 | super().__init__() 145 | 146 | # setup configurations 147 | self.csv_dpath = csv_dpath 148 | self.spc_mode = spc_mode 149 | 150 | # initialize the dataset 151 | self.__init_dataset() 152 | 153 | 154 | def __iter__(self): 155 | """Return an iterator of samples in the dataset.""" 156 | 157 | # initialization 158 | species = None 159 | chn_type = None 160 | 161 | # validate the worker information 162 | worker_info = get_worker_info() 163 | assert worker_info is None, 'only single-process data loading is supported' 164 | 165 | # initialize a new epoch 166 | logging.debug('=== start of epoch ===') 167 | 168 | # initialize the sequence pool 169 | idx_file = 0 170 | n_files = len(self.file_list) 171 | random.shuffle(self.file_list) 172 | species, chn_type, csv_fpath = self.file_list[idx_file] 173 | seq_list = parse_csv_file(csv_fpath, n_seqs_max=-1) 174 | logging.debug('adding %d sequences from %s', len(seq_list), csv_fpath) 175 | idx_seq = 0 176 | n_seqs = len(seq_list) 177 | random.shuffle(seq_list) 178 | 179 | # parse CSV files and return an iterator of samples 180 | while True: 181 | # build an input dict 182 | inputs = self.__build_inputs(species, chn_type, seq_list[idx_seq]) 183 | yield inputs 184 | 185 | # early-exit if no expansion is needed 186 | idx_seq += 1 187 | if idx_seq < n_seqs: 188 | continue 189 | 190 | # parse the next CSV file 191 | idx_file += 1 192 | if idx_file == n_files: # no available CSV files left 193 | break 194 | species, chn_type, csv_fpath = self.file_list[idx_file] 195 | seq_list = parse_csv_file(csv_fpath, n_seqs_max=-1) 196 | logging.debug('adding %d sequences from %s', len(seq_list), csv_fpath) 197 | idx_seq = 0 198 | n_seqs = len(seq_list) 199 | random.shuffle(seq_list) 200 | 201 | # indicate the end of current epoch 202 | logging.debug('=== end of epoch ===') 203 | 204 | 205 | def __init_dataset(self): 206 | """Initialize the dataset.""" 207 | 208 | self.file_list = [] 209 | for species in os.listdir(self.csv_dpath): 210 | # check whether the current species should be skipped 211 | if self.spc_mode == 'no-human': 212 | if species == 'human': 213 | continue 214 | elif self.spc_mode == 'human-only': 215 | if species != 'human': 216 | continue 217 | else: 218 | raise ValueError(f'unrecognized species mode: {self.spc_mode}') 219 | 220 | # enumerate all the CSV files for the current species 221 | for csv_fname in os.listdir(os.path.join(self.csv_dpath, species)): 222 | chn_type = 'hc' if csv_fname.startswith('heavy') else 'lc' 223 | csv_fpath = os.path.join(self.csv_dpath, species, csv_fname) 224 | self.file_list.append((species, chn_type, csv_fpath)) 225 | 226 | logging.debug('=== list of CSV files ===') 227 | logging.debug('\n'.join([str(x) for x in self.file_list])) 228 | 229 | 230 | @classmethod 231 | def __build_inputs(cls, species, chn_type, seq_dict): 
232 |         """Build an input dict."""
233 | 
234 |         inputs = {
235 |             'spec': species,
236 |             'type': chn_type,
237 |             'seq': seq_dict['chn'],
238 |             'regions': {k: v for k, v in seq_dict.items() if k != 'chn'},
239 |         }
240 | 
241 |         return inputs
242 | 
243 | 
244 | def main():
245 |     """Main entry."""
246 | 
247 |     # configurations
248 |     data_dir_csv = '/test.pkl'
249 | 
250 |     # initialization
251 |     tfold_init(verb_levl='DEBUG')
252 | 
253 |     # test w/ the OasDataset iterated directly
254 |     n_seqs_dict = defaultdict(int)
255 |     dataset = OasDataset(data_dir_csv, pool_size=512, n_seqs_max=256)  # reduced to 1/16
256 |     for inputs in dataset:
257 |         # n_seqs_dict[inputs['spec']] += 1
258 |         logging.info(inputs['spec'])
259 | 
260 | 
261 |     # test w/ a DataLoader built from the dataset (batch size: 16)
262 |     data_loader = DataLoader(dataset, batch_size=16, collate_fn=lambda x: x)
263 |     for inputs in tqdm(data_loader):
264 |         inspect_data(inputs, name='inputs')
265 | 
266 | if __name__ == '__main__':
267 |     main()
268 | 
-------------------------------------------------------------------------------- /environment.yaml: --------------------------------------------------------------------------------
1 | name: hudiff
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 |   - conda-forge
6 |   - bioconda
7 | dependencies:
8 |   - python=3.9
9 |   - pytorch=1.13.0
10 |   - torchvision=0.14.0
11 |   - torchaudio=0.13.0
12 |   - pytorch-cuda=11.6
13 |   - python-lmdb
14 |   - abnumber
15 |   - pip:
16 |     - pandas
17 |     - scipy
18 |     - scikit-learn
19 |     - tqdm
20 |     - tensorboard
21 |     - easydict
22 |     - pyyaml
23 |     - sequence-models
24 |     - einops
25 |     - matplotlib
26 |     - seaborn
-------------------------------------------------------------------------------- /evaluation/ABLSTM_eval.py: --------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | import pandas as pd
5 | import sys
6 | import seaborn as sns
7 | import matplotlib.pyplot as plt
8 | import tempfile
9 | 
10 | from ablstm import ModelLSTM
11 | 
12 | from anarci import anarci, number
13 | import re
14 | from abnumber import Chain
15 | 
16 | # Process the heavy-chain sequences.
17 | def seq_trans_to_aho(humanization_df):
18 |     """
19 |     Convert heavy-chain sequences to fixed-length (150) AHo-aligned sequences for ABLSTM scoring.
20 |     :param humanization_df: DataFrame
21 |     :return:
22 |     """
23 |     h_seq_df = humanization_df['hseq']
24 |     data = []
25 |     for idx, a_seq in enumerate(h_seq_df):
26 |         data.append((f'{idx}', a_seq))
27 | 
28 |     # single_data = [('1', single_h_seq), ('2', single_h_seq)]
29 |     h_results = anarci(data, scheme='aho', output=False)
30 |     h_aho_seq_list = []
31 |     h_seq_results = h_results[0]
32 |     for seq_list in h_seq_results:
33 |         re_seq = seq_list[0][0]
34 |         str_re_seq = str(re_seq)
35 |         matches = re.findall(r"'([A-Z\-])'", str_re_seq)
36 |         aho_seq = '-' + ''.join(matches)
37 |         if len(aho_seq) != 150:
38 |             pad_count = 150 - len(aho_seq)
39 |             aho_seq = aho_seq + '-' * pad_count
40 |         h_aho_seq_list.append(aho_seq)
41 | 
42 |     return h_aho_seq_list
43 | 
44 | def model_eval(aho_txt_fpath):
45 |     """
46 |     Predict the H-score with the pretrained LSTM model.
47 |     :param aho_txt_fpath: path to a text file with one AHo-aligned sequence per line.
48 |     :return:
49 |     """
50 |     model_data_path = '/model.npy'
51 |     pred_model = ModelLSTM(embedding_dim=64, hidden_dim=64, device='cuda', gapped=True, fixed_len=True)
52 |     pred_model.load(fn=model_data_path)
53 |     h_score = pred_model.eval(aho_txt_fpath)
54 |     return h_score
55 | 
56 | 
57 | def main(sample_fpath=None):
58 |     """
59 |     Get the ABLSTM score.
60 |     :return:
61 |     """
62 |     if sample_fpath is None:
63 |         sample_fpath = '/sample_humanization_result.csv'
64 | 
65 |     save_fpath = os.path.join(os.path.dirname(sample_fpath), 'sample_ablstm_score.pkl')
66 | 
67 |     sample_df = pd.read_csv(sample_fpath)
68 |     sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True)
69 |     aho_list = seq_trans_to_aho(sample_human_df)
70 | 
71 |     # Save h_aho_seq_list
72 |     tmp = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
73 |     tmp_fpath = tmp.name
74 |     with open(tmp_fpath, 'w') as f:
75 |         for seq in aho_list:
76 |             f.write(seq + '\n')
77 | 
78 |     h_score = model_eval(tmp_fpath)
79 |     os.remove(tmp_fpath)
80 | 
81 |     print(h_score)
82 |     with open(save_fpath, 'wb') as f:
83 |         pickle.dump(h_score, f)
84 | 
85 | 
86 | 
87 | if __name__ == '__main__':
88 |     main()
-------------------------------------------------------------------------------- /evaluation/Biophi_eval.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | import pandas as pd
5 | from abnumber import Chain
6 | from tqdm import tqdm
7 | 
8 | def save_pairs(heavy_chains, light_chains, path):
9 |     assert len(heavy_chains) == len(light_chains)
10 |     with open(path, 'w') as f:
11 |         for heavy, light in zip(heavy_chains, light_chains):
12 |             Chain.to_fasta(heavy, f, description='VH')
13 |             Chain.to_fasta(light, f, description='VL')
14 | 
15 | def trans_to_chain(df, save_path, version=None):
16 |     H_chain_list = [Chain(df.iloc[i]['hseq'], scheme='imgt') for i in df.index]
17 |     L_chain_list = [Chain(df.iloc[i]['lseq'], scheme='imgt') for i in df.index]
18 | 
19 |     assert version is not None, 'A specific version string must be given.'
20 |     for i, (h_chain, l_chain) in tqdm(enumerate(zip(H_chain_list, L_chain_list)), total=len(H_chain_list)):
21 |         name = version + 'human' + f'{i}'
22 |         h_chain.name = name
23 |         l_chain.name = name
24 | 
25 |     save_pairs(H_chain_list, L_chain_list, save_path)
26 | 
27 | 
28 | def main(sample_fpath=None):
29 |     """
30 |     Get the Biophi score.
31 |     :return:
32 |     """
33 |     if sample_fpath is None:
34 |         sample_fpath = 'sample_humanization_result.csv'
35 | 
36 |     save_fpath = os.path.join(os.path.dirname(sample_fpath), 'sample_identity.fa')
37 |     if os.path.exists(save_fpath):
38 |         print('Fasta file already exists. Skip!')
39 |         return save_fpath
40 | 
41 |     sample_df = pd.read_csv(sample_fpath)
42 |     sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True)
43 |     trans_to_chain(sample_human_df, save_fpath, version='exp')
44 | 
45 | if __name__ == '__main__':
46 |     main()
-------------------------------------------------------------------------------- /evaluation/T20_eval.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import requests
4 | import sys
5 | import pandas as pd
6 | import time
7 | from tqdm import tqdm
8 | import re
9 | from abnumber import Chain
10 | import concurrent.futures
11 | 
12 | T20_REGEX = re.compile('T20 Score:([0-9.]+)')
13 | def get_t20_online(seq, region=1):
14 |     if region == 1:
15 |         chain = Chain(seq, scheme='imgt')
16 |         chain_type = 'vh' if chain.chain_type == 'H' else ('vl' if chain.chain_type == 'L' else 'vk')
17 |     elif region == 2:
18 |         chain_type = 'vh'
19 |     else:
20 |         raise ValueError('Region type is not appropriate.')
21 | 
22 |     html = None
23 |     for retry in range(5):
24 |         url = f'https://sam.curiaglobal.com/t20/cgi-bin/blast.py?chain={chain_type}&region={region}&output=3&seqs={seq}'
25 |         try:
26 |             request = requests.get(url)
27 |             if request.ok:
28 |                 html = request.text
29 |                 break
30 |         except Exception as e:
31 |             print(e)
32 |         except KeyboardInterrupt:
33 |             raise
34 |         time.sleep(0.5 + retry * 5)
35 |         print('Retry', retry+1)
36 |     if not html:
37 |         sys.exit(1)
38 |     # print(html)
39 |     matches = T20_REGEX.findall(html)
40 |     time.sleep(1)
41 |     if not matches:
42 |         print(html)
43 |         # raise ValueError(f'Error calling url {url}')
44 |         return None, None
45 |     return float(matches[0]), chain_type
46 | 
47 | def get_pair_data_t20(h_seq, l_seq, region=1):
48 |     h_score, h_type = get_t20_online(h_seq, region)
49 |     l_score, l_type = get_t20_online(l_seq, region)
50 |     # print(h_score, l_score)
51 |     return [h_score, h_type, l_score, l_type, h_seq, l_seq]
52 | 
53 | 
54 | def get_one_chain_framework_t20(h_seq, region=2):
55 |     h_score, h_type = get_t20_online(h_seq, region)
56 |     return [h_score, h_type, h_seq]
57 | 
58 | 
59 | def process_line(line):
60 |     h_seq = line[1]['hseq']
61 |     l_seq = line[1]['lseq']
62 |     name = [line[1]['name']]
63 |     data = []
64 |     for retry in range(3):
65 |         try:
66 |             data = get_pair_data_t20(h_seq, l_seq)
67 |             if len(data) > 2:
68 |                 break
69 |         except:
70 |             time.sleep(5)
71 |             # continue
72 |     # if data is not None:
73 |     print(data)
74 |     if len(data) > 2:
75 |         new_data = name + data
76 |         new_line_df = pd.DataFrame([new_data], columns=['Raw_name', 'h_score', 'h_gene', 'l_score', 'l_gene', 'h_seq', 'l_seq'])
77 |         return new_line_df
78 |     else:
79 |         return None
80 | 
81 | def process_one_seq_and_frame_line(line):
82 |     h_seq = line[1]['hseq']
83 |     name = [line[1]['name']]
84 |     # name = ['vhhseq' + str(line[0])]
85 |     data = []
86 |     for retry in range(3):
87 |         try:
88 |             data = get_one_chain_framework_t20(h_seq, region=2)
89 |             if len(data) > 2:
90 |                 break
91 |         except:
92 |             time.sleep(5)
93 |             continue
94 |     # if data is not None:
95 |     print(data)
96 |     if len(data) > 2:
97 |         new_data = name + data
98 |         new_line_df = pd.DataFrame([new_data], columns=['Raw_name', 'h_score', 'h_gene', 'h_seq'])
99 |         return new_line_df
100 |     else:
101 |         return None
102 | 
103 | 
104 | def frame_main(sample_fpath=None):
105 |     if sample_fpath is None:
106 |         sample_fpath = '/sample_humanization_result.csv'
107 | 
108 | 
109 |     print(sample_fpath)
110 |     save_fpath = os.path.join(os.path.dirname(sample_fpath), 'sample_frame_t20_score.csv')
111 | 
if os.path.exists(save_fpath): 112 | return save_fpath 113 | 114 | sample_df = pd.read_csv(sample_fpath) 115 | 116 | sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True) 117 | with concurrent.futures.ProcessPoolExecutor() as executor: 118 | results = list(tqdm(executor.map(process_one_seq_and_frame_line, sample_human_df.iterrows()), total=len(sample_human_df))) 119 | 120 | save_frame_t20_df = pd.concat([result for result in results if result is not None], ignore_index=True) 121 | Not_successful_index = [i for i, result in enumerate(results) if result is None] 122 | 123 | print(Not_successful_index) 124 | save_frame_t20_df.to_csv(save_fpath, index=False) 125 | return save_fpath 126 | 127 | 128 | def main(sample_fpath=None): 129 | """ 130 | Gather the T20 score from the website. 131 | :return: 132 | """ 133 | if sample_fpath is None: 134 | sample_fpath = '/humanization_pair_data_filter.csv' 135 | 136 | save_fpath = os.path.join(os.path.dirname(sample_fpath), 'sample_t20_score.csv') 137 | if os.path.exists(save_fpath): 138 | return save_fpath 139 | 140 | sample_df = pd.read_csv(sample_fpath) 141 | 142 | sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True) 143 | 144 | save_t20_df = pd.DataFrame(columns=['Raw_name', 'h_score', 'h_gene', 'l_score', 'l_gene', 'h_seq', 'l_seq']) 145 | results = [] 146 | for line in sample_human_df.iterrows(): 147 | results.append(process_line(line=line)) 148 | 149 | save_t20_df = pd.concat([result for result in results if result is not None], ignore_index=True) 150 | Not_successful_index = [i for i, result in enumerate(results) if result is None] 151 | 152 | 153 | 154 | print(Not_successful_index) 155 | save_t20_df.to_csv(save_fpath, index=False) 156 | return save_fpath 157 | 158 | 159 | if __name__ == '__main__': 160 | main() -------------------------------------------------------------------------------- /evaluation/Zscore_eval.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import json 4 | import pandas as pd 5 | import time 6 | from tqdm import tqdm 7 | import re 8 | from abnumber import Chain 9 | import json 10 | from urllib.parse import urlencode 11 | import concurrent.futures 12 | 13 | import seaborn as sns 14 | import matplotlib.pyplot as plt 15 | 16 | import os 17 | 18 | SCORE_REGEX = re.compile('

The Z-score value of the Query sequence is: (-?[0-9.]+)

') 19 | def get_z_score_online(seq): 20 | chain = Chain(seq, scheme='imgt') 21 | chain_type = 'human_heavy' if chain.chain_type == 'H' else ('human_lambda' if chain.chain_type == 'L' else 'human_kappa') 22 | html = None 23 | for retry in range(5): 24 | url = f'http://www.bioinf.org.uk/abs/shab/shab.cgi?aa_sequence={seq}&DB={chain_type}' 25 | request = requests.get(url) 26 | time.sleep(0.5 + retry * 5) 27 | if request.ok: 28 | html = request.text 29 | break 30 | else: 31 | print('Retry', retry+1) 32 | if not html: 33 | raise ValueError('Z-score server is not accessible') 34 | matches = SCORE_REGEX.findall(html) 35 | if not matches: 36 | print(html) 37 | # raise ValueError(f'Error calling url {url}') 38 | return None, None 39 | return float(matches[0]), chain_type 40 | 41 | def get_pair_data_zscore(h_seq, l_seq): 42 | h_z_score, h_type = get_z_score_online(h_seq) 43 | l_z_score, l_type = get_z_score_online(l_seq) 44 | return [h_z_score, h_type, l_z_score, l_type, h_seq, l_seq] 45 | 46 | def process_z_score_line(line): 47 | h_seq = line[1]['hseq'] 48 | l_seq = line[1]['lseq'] 49 | name = [line[1]['name']] 50 | for retry in range(10): 51 | try: 52 | data = get_pair_data_zscore(h_seq, l_seq) 53 | if len(data) > 2: 54 | break 55 | except: 56 | time.sleep(5) 57 | continue 58 | if len(data) != 2: 59 | new_data = name + data 60 | new_line_df = pd.DataFrame([new_data], 61 | columns=['Raw_name', 'h_score', 'h_gene', 'l_score', 'l_gene', 'h_seq', 'l_seq']) 62 | return new_line_df 63 | else: 64 | return None 65 | 66 | 67 | def main(sample_fpath=None): 68 | """ 69 | Gathering the Z score info for eval. 70 | sample_fpath: file path of the sample result. 71 | :return: 72 | """ 73 | if sample_fpath is None: 74 | sample_fpath = '/sample_humanization_result.csv' 75 | 76 | print(sample_fpath) 77 | save_fpath = os.path.join(os.path.dirname(sample_fpath), 'sample_z_score.csv') 78 | 79 | sample_df = pd.read_csv(sample_fpath) 80 | sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True) 81 | 82 | save_z_df = pd.DataFrame(columns=['Raw_name', 'h_score', 'h_gene', 'l_score', 'l_gene', 'h_seq', 'l_seq']) 83 | 84 | results = [] 85 | for line in sample_human_df.iterrows(): 86 | results.append(process_z_score_line(line)) 87 | 88 | save_z_df = pd.concat([result for result in results if result is not None], ignore_index=True) 89 | Not_successful_index = [i for i, result in enumerate(results) if result is None] 90 | 91 | print(Not_successful_index) 92 | save_z_df.to_csv(save_fpath, index=False) 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /evaluation/humab_eval.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import json 4 | import pandas as pd 5 | import time 6 | from tqdm import tqdm 7 | import concurrent.futures 8 | import os 9 | 10 | from abnumber import Chain 11 | 12 | # Define deal out-of-order table. 13 | def regular_order_table(out_of_order_table): 14 | all_table_data = [] 15 | for table in out_of_order_table: 16 | table_data = [] 17 | for row in table.find_all('tr'): 18 | row_data = [] 19 | for cell in row.find_all(['th', 'td']): 20 | row_data.append(cell.text) 21 | table_data.append(row_data) 22 | all_table_data.append(table_data) 23 | return all_table_data[:2] # only the first two will be used, all is three. 24 | 25 | 26 | # Define extract data. 
26 | # Extract the data. We only want to know whether the sequence can be viewed as human.
27 | def extract_human_data(regular_table):
28 |     extracted_data = []
29 |     for table_data in regular_table:
30 |         table_header = table_data[0]
31 |         human_row = [None, None, None, None]
32 |         for row in table_data:
33 |             if row[-1] == 'HUMAN':
34 |                 human_row = row
35 |         extracted_data.extend(human_row)
36 |     return extracted_data
37 | 
38 | 
39 | # Define request process.
40 | def get_predict_result(job_name, h_seq, l_seq):
41 |     # Url path
42 |     humab_url = 'https://opig.stats.ox.ac.uk/webapps/sabdab-sabpred/sabpred/humab'
43 | 
44 |     data = {
45 |         'h_sequence_score': h_seq,
46 |         'l_sequence_score': l_seq,
47 |         'jobname_score': job_name,
48 |         'humanise': True,
49 |     }
50 |     response = requests.post(humab_url, data=data)
51 |     result_url = response.url
52 |     print(result_url)
53 | 
54 |     # Need to wait a moment.
55 |     time.sleep(15)
56 | 
57 |     # Get the result page.
58 |     result_response = requests.get(result_url)
59 |     extract_data = []
60 |     if result_response.status_code == 200:
61 |         soup = BeautifulSoup(result_response.text, 'html.parser')
62 |         tables = soup.find_all('table', {'class': 'table table-results'})
63 |         # print(tables)
64 | 
65 |         predict_table = regular_order_table(tables)
66 |         print(predict_table)
67 |         extract_data = extract_human_data(predict_table)
68 |         print(extract_data)
69 |     else:
70 |         print('Maybe the URL has a problem or a longer sleep time is needed.')
71 | 
72 |     sequence_list = [h_seq, l_seq]
73 |     return extract_data + sequence_list
74 | 
75 | 
76 | def process_line(line):
77 |     h_seq = line[1]['hseq']
78 |     l_seq = line[1]['lseq']
79 | 
80 |     l_chain_type = Chain(l_seq, scheme='imgt').chain_type
81 |     # if l_chain_type == 'L':
82 |     #     return True
83 |     data = []
84 |     name = [line[1]['name']]
85 |     job_name = line[1]['Specific'] + '_' + str(line[0])
86 |     for retry in range(50):
87 |         try:
88 |             data = get_predict_result(job_name, h_seq, l_seq)
89 |             if len(data) > 2:
90 |                 break
91 |         except Exception:
92 |             time.sleep(5)
93 |             continue
94 |     if len(data) > 2:
95 |         new_data = name + data + [l_chain_type]
96 |         new_line_df = pd.DataFrame([new_data],
97 |                                    columns=['Raw_name', 'h_v_gene', 'h_score', 'h_threshold', 'h_classification',
98 |                                             'l_v_gene', 'l_score', 'l_threshold', 'l_classification', 'h_seq', 'l_seq', 'l_chain_type'])
99 |         return new_line_df
100 |     else:
101 |         return None
102 | 
103 | 
104 | def main(sample_fpath=None):
105 |     """
106 |     Gather the Hu-mAb scores for evaluation.
107 |     :return:
108 |     """
109 |     if sample_fpath is None:
110 |         sample_fpath = 'sample_humanization_result.csv'
111 |     save_fpath = os.path.join(os.path.dirname(sample_fpath), 'sample_humab_score.csv')
112 | 
113 |     sample_df = pd.read_csv(sample_fpath)
114 |     sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True)
115 |     # print(sample_human_df)
116 | 
117 |     save_humab_df = pd.DataFrame(columns=['Raw_name', 'h_v_gene', 'h_score', 'h_threshold', 'h_classification',
118 |                                           'l_v_gene', 'l_score', 'l_threshold', 'l_classification', 'h_seq', 'l_seq'])
119 | 
120 |     results = []
121 |     for line in sample_human_df.iterrows():
122 |         results.append(process_line(line))
123 | 
124 |     save_humab_df = pd.concat([result for result in results if result is not None], ignore_index=True)
125 |     Not_successful_index = [i for i, result in enumerate(results) if result is None]
126 |     print(Not_successful_index)
127 | 
128 |     save_humab_df.to_csv(save_fpath, index=False)
129 | 
130 | 
131 | if __name__ == '__main__':
132 |     main()
133 | 
-------------------------------------------------------------------------------- /model/encoder/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentAI4S/HuDiff/bb7636f182699f98c37855dad05a5c6c61b576bd/model/encoder/__init__.py
-------------------------------------------------------------------------------- /model/encoder/cross_attention.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from inspect import isfunction
4 | 
5 | 
6 | import torch.nn.functional as F
7 | from typing import Optional, Tuple
8 | 
9 | 
10 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
11 |     """
12 |     Reshape frequency tensor for broadcasting it with another tensor.
13 | 
14 |     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
15 |     for the purpose of broadcasting the frequency tensor during element-wise operations.
16 | 
17 |     Args:
18 |         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
19 |         x (torch.Tensor): Target tensor for broadcasting compatibility.
20 | 
21 |     Returns:
22 |         torch.Tensor: Reshaped frequency tensor.
23 | 
24 |     Raises:
25 |         AssertionError: If the frequency tensor doesn't match the expected shape.
26 |         AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
27 |     """
28 |     ndim = x.ndim
29 |     assert 0 <= 1 < ndim
30 |     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
31 |     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
32 |     return freqs_cis.view(*shape)
33 | 
34 | 
35 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
36 |     """
37 |     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
38 | 
39 |     This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
40 |     and the end index 'end'. The 'theta' parameter scales the frequencies.
41 |     The returned tensor contains complex values in complex64 data type.
42 | 
43 |     Args:
44 |         dim (int): Dimension of the frequency tensor.
45 |         end (int): End index for precomputing frequencies.
46 |         theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
47 | 
48 |     Returns:
49 |         torch.Tensor: Precomputed frequency tensor with complex exponentials.
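    Example (illustrative sketch):
        >>> freqs_cis = precompute_freqs_cis(dim=8, end=16)
        >>> freqs_cis.shape, freqs_cis.dtype
        (torch.Size([16, 4]), torch.complex64)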
50 | 
51 |     """
52 |     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
53 |     t = torch.arange(end, device=freqs.device)  # type: ignore
54 |     freqs = torch.outer(t, freqs).float()  # type: ignore
55 |     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
56 |     return freqs_cis
57 | 
58 | 
59 | def apply_rotary_emb(
60 |     xq: torch.Tensor,
61 |     xk: torch.Tensor,
62 |     freqs_cis: torch.Tensor,
63 | ) -> Tuple[torch.Tensor, torch.Tensor]:
64 |     """
65 |     Apply rotary embeddings to input tensors using the given frequency tensor.
66 | 
67 |     This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
68 |     frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
69 |     is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
70 |     returned as real tensors.
71 | 
72 |     Args:
73 |         xq (torch.Tensor): Query tensor to apply rotary embeddings.
74 |         xk (torch.Tensor): Key tensor to apply rotary embeddings.
75 |         freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
76 | 
77 |     Returns:
78 |         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
79 | 
80 | 
81 | 
82 |     """
83 |     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
84 |     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
85 |     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
86 |     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
87 |     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
88 |     return xq_out.type_as(xq), xk_out.type_as(xk)
89 | 
90 | 
91 | def exists(val):
92 |     return val is not None
93 | 
94 | 
95 | def default(val, d):
96 |     if exists(val):
97 |         return val
98 |     return d() if isfunction(d) else d
99 | 
100 | # feedforward
101 | class GEGLU(nn.Module):
102 |     def __init__(self, dim_in, dim_out):
103 |         super().__init__()
104 |         self.proj = nn.Linear(dim_in, dim_out * 2)
105 | 
106 |     def forward(self, x):
107 |         x, gate = self.proj(x).chunk(2, dim=-1)
108 |         return x * F.gelu(gate)
109 | 
110 | 
111 | class FeedForward(nn.Module):
112 |     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
113 |         super().__init__()
114 |         inner_dim = int(dim * mult)
115 |         dim_out = default(dim_out, dim)
116 |         project_in = nn.Sequential(
117 |             nn.Linear(dim, inner_dim),
118 |             nn.GELU()
119 |         ) if not glu else GEGLU(dim, inner_dim)
120 | 
121 |         self.net = nn.Sequential(
122 |             project_in,
123 |             nn.Dropout(dropout),
124 |             nn.Linear(inner_dim, dim_out)
125 |         )
126 | 
127 |     def forward(self, x):
128 |         return self.net(x)
129 | 
130 | 
131 | class AttLayer(nn.Module):
132 | 
133 |     def __init__(self, d_model, att_model, nhead, length=152):
134 |         super().__init__()
135 |         self.nhead = nhead
136 |         self.dk = (att_model / nhead) ** 0.5
137 | 
138 |         self.query = nn.Linear(d_model, att_model)
139 |         self.key = nn.Linear(d_model, att_model)
140 |         self.value = nn.Linear(d_model, att_model)
141 |         self.softmax = nn.Softmax(dim=-1)
142 | 
143 |         self.out_put = nn.Linear(att_model, d_model)
144 | 
145 |         rope = precompute_freqs_cis(att_model//nhead, length)
146 |         self.register_buffer('rope', rope)
147 | 
148 | 
149 |     def forward(self, x, context=None, mask=None):
150 |         Q = self.query(x)
151 |         if context is None:
152 |             K = self.key(x)
153 |             V = self.value(x)
154 |         else:
155 |             K = self.key(context)
156 |             V = self.value(context)  # project the context with the value matrix
157 | 
158 |         Q = Q.view(Q.shape[0], Q.shape[1], self.nhead, -1)  #.permute(0, 2, 1, 3)
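        # Shape note (an added sketch, assuming x of shape (batch, seq_len, d_model)):
        # Q, K and V are each split into (batch, seq_len, nhead, att_model // nhead)
        # so that apply_rotary_emb can rotate every head's (even, odd) feature pairs
        # with the precomputed complex frequencies in self.rope, before the tensors
        # are transposed to (batch, nhead, seq_len, head_dim) for the attention matmul.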
159 |         K = K.view(K.shape[0], K.shape[1], self.nhead, -1)  #.permute(0, 2, 1, 3)
160 |         V = V.view(V.shape[0], V.shape[1], self.nhead, -1)  #.permute(0, 2, 1, 3)
161 | 
162 |         Q, K = apply_rotary_emb(Q, K, self.rope)
163 | 
164 |         Q = Q.transpose(1, 2)
165 |         K = K.transpose(1, 2)
166 |         V = V.transpose(1, 2)
167 | 
168 |         attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / self.dk
169 |         attn_weights = self.softmax(attn_weights)
170 |         output = torch.matmul(attn_weights, V)
171 |         output = output.permute(0, 2, 1, 3).contiguous().view(Q.shape[0], -1, Q.shape[1]*Q.shape[3])
172 |         output = self.out_put(output)
173 |         return output
174 | 
175 | 
176 | class TransformerBlock(nn.Module):
177 |     def __init__(self, d_model, nhead, att_model, dim_feedforward):
178 |         super(TransformerBlock, self).__init__()
179 |         self.attention = AttLayer(d_model, att_model, nhead)
180 |         self.norm1 = nn.LayerNorm(d_model)
181 |         self.norm2 = nn.LayerNorm(d_model)
182 |         self.ffn = nn.Sequential(
183 |             nn.Linear(d_model, dim_feedforward),
184 |             nn.ReLU(),
185 |             nn.Linear(dim_feedforward, d_model)
186 |         )
187 | 
188 |     def forward(self, x, mask=None):
189 |         attn_output = self.attention(x, mask=mask)  # mask passed by keyword; the second positional arg is the cross-attention context
190 |         x = x + attn_output
191 |         x = self.norm1(x)
192 | 
193 |         ffn_output = self.ffn(x)
194 |         x = x + ffn_output
195 |         x = self.norm2(x)
196 |         return x
197 | 
198 | 
199 | class TransformerNet(nn.Module):
200 |     def __init__(self, d_model, att_model, nhead, num_layers, dim_feedforward):
201 |         super(TransformerNet, self).__init__()
202 |         self.layers = nn.ModuleList([
203 |             TransformerBlock(d_model, nhead, att_model, dim_feedforward) for _ in range(num_layers)
204 |         ])
205 | 
206 |     def forward(self, x, mask=None):
207 |         for layer in self.layers:
208 |             x = layer(x, mask)
209 |         return x
210 | 
211 | 
212 | class CrossAttBlock(nn.Module):
213 | 
214 |     def __init__(self, d_model, att_model, dim_feedforward, nhead):
215 |         super().__init__()
216 |         self.attnh = AttLayer(d_model, att_model, nhead)
217 |         self.attn_hc = AttLayer(d_model, att_model, nhead)
218 |         self.attnl = AttLayer(d_model, att_model, nhead)
219 |         self.attn_lc = AttLayer(d_model, att_model, nhead)
220 | 
221 |         self.normh1 = nn.LayerNorm(d_model)
222 |         self.normh2 = nn.LayerNorm(d_model)
223 | 
224 |         self.norml1 = nn.LayerNorm(d_model)
225 |         self.norml2 = nn.LayerNorm(d_model)
226 | 
227 |         self.ffh = nn.Sequential(
228 |             nn.Linear(d_model, dim_feedforward),
229 |             nn.ReLU(),
230 |             nn.Linear(dim_feedforward, d_model)
231 |         )
232 |         self.ffl = nn.Sequential(
233 |             nn.Linear(d_model, dim_feedforward),
234 |             nn.ReLU(),
235 |             nn.Linear(dim_feedforward, d_model)
236 |         )
237 | 
238 |     def forward(self, h, l, mask=None):
239 |         """
240 | 
241 |         :param h: heavy-chain token representation
242 |         :param l: light-chain token representation
243 |         :param mask: optional attention mask
244 |         :return: the updated (h, l) pair
245 |         """
246 |         at_h = h + self.attnh(h)
247 |         at_l = l + self.attnl(l)
248 | 
249 |         at_h = at_h + self.attn_hc(self.normh1(h), l)
250 |         at_l = at_l + self.attn_lc(self.norml1(l), h)
251 | 
252 |         h = self.ffh(self.normh2(at_h)) + at_h
253 |         l = self.ffl(self.norml2(at_l)) + at_l
254 |         return h, l
255 | 
256 | 
257 | class SelfAttBlock(nn.Module):
258 | 
259 |     def __init__(self, d_model, att_model, dim_feedforward, nhead, rolength):
260 |         super().__init__()
261 |         self.attn_hl = AttLayer(d_model, att_model, nhead, rolength)
262 |         self.attn_hl_c = AttLayer(d_model, att_model, nhead, rolength)
263 | 
264 |         self.norm_hl1 = nn.LayerNorm(d_model)
265 |         self.norm_hl2 = nn.LayerNorm(d_model)
266 | 
267 |         self.ff_hl = nn.Sequential(
268 |             nn.Linear(d_model, dim_feedforward),
269 |             nn.ReLU(),
270 |             nn.Linear(dim_feedforward, d_model)
271 |         )
272 | 
273 |     def forward(self, h_l, mask=None):
274 |         """
275 | 
276 |         :param h_l: combined heavy/light sequence representation
277 |         :param mask: optional attention mask
278 |         :return: the updated representation
279 |         """
280 | 
281 |         at_hl = h_l + self.attn_hl(h_l)
282 | 
283 |         at_hl = at_hl + self.attn_hl_c(self.norm_hl1(at_hl))
284 | 
285 |         h_l = self.ff_hl(self.norm_hl2(at_hl)) + h_l
286 |         return h_l
287 | 
288 | 
289 | 
290 | class SelfAttNet(nn.Module):
291 | 
292 |     def __init__(self, d_model, att_model, dim_feedforward, nhead, rolength, num_cross_layers):
293 |         super().__init__()
294 |         self.layers = nn.ModuleList(
295 |             [SelfAttBlock(d_model, att_model, dim_feedforward, nhead, rolength) for _ in range(num_cross_layers)]
296 |         )
297 | 
298 |     def forward(self, h_l, mask=None):
299 |         """
300 | 
301 |         :param h_l: combined heavy/light sequence representation
302 |         :param mask: optional attention mask
303 |         :return: the representation after all self-attention blocks
304 |         """
305 |         # h_l = h_l + pos_emb
306 |         for layer in self.layers:
307 |             h_l = layer(h_l)
308 |         return h_l
309 | 
310 | 
311 | class CrossAttNet(nn.Module):
312 | 
313 |     def __init__(self, d_model, att_model, dim_feedforward, nhead, num_cross_layers):
314 |         super().__init__()
315 |         self.layers = nn.ModuleList(
316 |             [CrossAttBlock(d_model, att_model, dim_feedforward, nhead) for _ in range(num_cross_layers)]
317 |         )
318 | 
319 |     def forward(self, h, l, pos_emb, mask=None):
320 |         """
321 | 
322 |         :param h: heavy-chain token representation
323 |         :param l: light-chain token representation
324 |         :param pos_emb: positional embedding, split between the heavy and light parts
325 |         :param mask: optional attention mask
326 |         :return: the updated (h, l) pair
327 |         """
328 |         h = h + pos_emb[:h.size(0)]
329 |         l = l + pos_emb[h.size(0):]
330 |         for layer in self.layers:
331 |             h, l = layer(h, l)
332 |         return h, l
333 | 
-------------------------------------------------------------------------------- /model/encoder/diffusion.py: --------------------------------------------------------------------------------
1 | import enum
2 | import math
3 | 
4 | import numpy as np
5 | import torch as th
6 | 
7 | 
8 | ##########################################################################################
9 | 
10 | # DIFFUSION CODE BASE FOR PROTEIN SEQUENCE DIFFUSION WAS ADAPTED FROM LM-DIFFUSION       #
11 | 
12 | # (https://github.com/XiangLi1999/Diffusion-LM)                                          #
13 | 
14 | ##########################################################################################
15 | 
16 | class GaussianDiffusion_SEQDIFF:
17 |     """
18 |     T = number of timesteps to set up diffuser with
19 | 
20 |     schedule = type of noise schedule to use, e.g. 'linear' or the default 'sqrt'
21 | 
22 |     sample_distribution = type of distribution to sample noise from; DEFAULT - 'normal' (Gaussian)
23 | 
24 |     """
25 | 
26 |     def __init__(self,
27 |                  T=1000,
28 |                  schedule='sqrt',
29 |                  sample_distribution='normal',
30 |                  sample_distribution_gmm_means=[-1.0, 1.0],
31 |                  sample_distribution_gmm_variances=[1.0, 1.0],
32 |                  F=1,
33 |                  ):
34 | 
35 |         # Use float64 for accuracy.
36 |         betas = np.array(get_named_beta_schedule(schedule, T), dtype=np.float64)
37 |         self.betas = betas
38 |         assert len(betas.shape) == 1, "betas must be 1-D"
39 |         assert (betas > 0).all() and (betas <= 1).all()
40 | 
41 |         self.num_timesteps = int(betas.shape[0])
42 |         self.F = F
43 | 
44 |         alphas = 1.0 - betas
45 |         self.alphas_cumprod = np.cumprod(alphas, axis=0)
46 |         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
47 |         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
48 |         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
49 | 
50 |         # calculations for posterior q(x_{t-1} | x_t, x_0)
51 |         self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
52 |         # log calculation clipped because the posterior variance is 0 at the
53 |         # beginning of the diffusion chain.
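        # For reference, the posterior terms computed in this block follow the standard
        # closed forms (writing alpha_bar_t for the cumulative product of the alphas):
        #   Var[x_{t-1} | x_t, x_0] = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        #   E[x_{t-1} | x_t, x_0]   = coef1 * x_0 + coef2 * x_t, with
        #   coef1 = beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t)
        #   coef2 = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)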
54 |         self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
55 |         self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))
56 |         self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod))
57 | 
58 |         # calculations for diffusion q(x_t | x_{t-1}) and others
59 |         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
60 |         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
61 |         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
62 |         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
63 | 
64 |         # sample_distribution_params
65 |         self.sample_distribution = sample_distribution
66 |         self.sample_distribution_gmm_means = [float(mean) for mean in sample_distribution_gmm_means]
67 |         self.sample_distribution_gmm_variances = [float(variance) for variance in sample_distribution_gmm_variances]
68 | 
69 |         if self.sample_distribution == 'normal':
70 |             self.noise_function = th.randn_like
71 |         else:
72 |             self.noise_function = self.randnmixture_like
73 | 
74 |     def q_mean_variance(self, x_start, t):
75 |         """
76 |         Get the distribution q(x_t | x_0).
77 |         :param x_start: the [N x C x ...] tensor of noiseless inputs.
78 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
79 |         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
80 |         """
81 |         mean = (
82 |             _extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
83 |         )
84 |         variance = _extract(1.0 - self.alphas_cumprod, t, x_start.shape)
85 |         log_variance = _extract(
86 |             self.log_one_minus_alphas_cumprod, t, x_start.shape
87 |         )
88 |         return mean, variance, log_variance
89 | 
90 |     def q_sample(self, x_start, t, mask=None, DEVICE=None):
91 |         """
92 |         Diffuse the data for a given number of diffusion steps.
93 |         In other words, sample from q(x_t | x_0).
94 |         :param x_start: the initial data batch.
95 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
96 |         :param mask: if specified, positions selected by mask keep their original x_start values; DEVICE optionally moves the noise tensor.
97 |         :return: A noisy version of x_start.
98 | """ 99 | 100 | # noise_function is determined in init depending on type of noise specified 101 | noise = self.noise_function(x_start) * (self.F ** 2) 102 | if DEVICE != None: 103 | noise = noise.to(DEVICE) 104 | 105 | assert noise.shape == x_start.shape 106 | x_sample = ( 107 | _extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 108 | + _extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 109 | * noise) 110 | 111 | if mask is not None: 112 | x_sample[mask] = x_start[mask] 113 | 114 | return x_sample 115 | 116 | def q_posterior_mean_variance(self, x_start, x_t, t): 117 | """ 118 | Compute the mean and variance of the diffusion posterior: 119 | q(x_{t-1} | x_t, x_0) 120 | """ 121 | assert x_start.shape == x_t.shape 122 | 123 | posterior_mean = (_extract(self.posterior_mean_coef1, t, x_t.shape) * x_start 124 | + _extract(self.posterior_mean_coef2, t, x_t.shape) * x_t) 125 | 126 | posterior_variance = _extract(self.posterior_variance, t, x_t.shape) 127 | 128 | posterior_log_variance_clipped = _extract(self.posterior_log_variance_clipped, t, x_t.shape) 129 | 130 | assert ( 131 | posterior_mean.shape[0] 132 | == posterior_variance.shape[0] 133 | == posterior_log_variance_clipped.shape[0] 134 | == x_start.shape[0] 135 | ) 136 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 137 | 138 | def randnmixture_like(self, tensor_like, number_normal=3, weights_normal=None): 139 | 140 | if self.sample_distribution_gmm_means and self.sample_distribution_gmm_variances: 141 | assert len(self.sample_distribution_gmm_means) == len(self.sample_distribution_gmm_variances) 142 | 143 | if not weights_normal: 144 | mix = th.distributions.Categorical(th.ones(len(self.sample_distribution_gmm_means))) # number_normal 145 | else: 146 | assert len(weights_normal) == number_normal 147 | mix = th.distributions.Categorical(weights_normal) 148 | # comp = torch.distributions.Normal(torch.randn(number_normal), torch.rand(number_normal)) 149 | comp = th.distributions.Normal(th.tensor(self.sample_distribution_gmm_means), 150 | th.tensor(self.sample_distribution_gmm_variances)) 151 | # comp = torch.distributions.Normal([-3, 3], [1, 1]) 152 | # comp = torch.distributions.Normal([-3, 0, 3], [1, 1, 1]) 153 | # comp = torch.distributions.Normal([-3, 0, 3], [1, 1, 1]) 154 | gmm = th.distributions.mixture_same_family.MixtureSameFamily(mix, comp) 155 | return th.tensor([gmm.sample() for _ in range(np.prod(tensor_like.shape))]).reshape(tensor_like.shape) 156 | 157 | 158 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 159 | """ 160 | Get a pre-defined beta schedule for the given name. 161 | The beta schedule library consists of beta schedules which remain similar 162 | in the limit of num_diffusion_timesteps. 163 | Beta schedules may be added, but should not be removed or changed once 164 | they are committed to maintain backwards compatibility. 165 | """ 166 | if schedule_name == "linear": 167 | # Linear schedule from Ho et al, extended to work for any number of 168 | # diffusion steps. 
169 | scale = 1000 / num_diffusion_timesteps 170 | beta_start = scale * 0.0001 171 | beta_end = scale * 0.02 172 | return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 173 | 174 | elif schedule_name == "cosine": 175 | return betas_for_alpha_bar(num_diffusion_timesteps, 176 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) 177 | 178 | elif schedule_name == 'sqrt': 179 | return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: 1 - np.sqrt(t + 0.0001), ) 180 | 181 | else: 182 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 183 | 184 | 185 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 186 | """ 187 | Create a beta schedule that discretizes the given alpha_t_bar function, 188 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 189 | :param num_diffusion_timesteps: the number of betas to produce. 190 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 191 | produces the cumulative product of (1-beta) up to that 192 | part of the diffusion process. 193 | :param max_beta: the maximum beta to use; use values lower than 1 to 194 | prevent singularities. 195 | """ 196 | betas = [] 197 | for i in range(num_diffusion_timesteps): 198 | t1 = i / num_diffusion_timesteps 199 | t2 = (i + 1) / num_diffusion_timesteps 200 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 201 | return np.array(betas) 202 | 203 | 204 | def _extract(arr, timesteps, broadcast_shape): 205 | """ 206 | Extract values from a 1-D numpy array for a batch of indices. 207 | :param arr: the 1-D numpy array. 208 | :param timesteps: a tensor of indices into the array to extract. 209 | :param broadcast_shape: a larger shape of K dimensions with the batch 210 | dimension equal to the length of timesteps. 211 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 212 | """ 213 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 214 | while len(res.shape) < len(broadcast_shape): 215 | res = res[..., None] 216 | return res.expand(broadcast_shape) -------------------------------------------------------------------------------- /model/encoder/rotary_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | 10 | 11 | def rotate_half(x): 12 | x1, x2 = x.chunk(2, dim=-1) 13 | return torch.cat((-x2, x1), dim=-1) 14 | 15 | 16 | def apply_rotary_pos_emb(x, cos, sin): 17 | cos = cos[:, : x.shape[-2], :] 18 | sin = sin[:, : x.shape[-2], :] 19 | 20 | return (x * cos) + (rotate_half(x) * sin) 21 | 22 | 23 | class RotaryEmbedding(torch.nn.Module): 24 | """ 25 | The rotary position embeddings from RoFormer_ (Su et. al). 26 | A crucial insight from the method is that the query and keys are 27 | transformed by rotation matrices which depend on the relative positions. 28 | Other implementations are available in the Rotary Transformer repo_ and in 29 | GPT-NeoX_, GPT-NeoX was an inspiration 30 | .. _RoFormer: https://arxiv.org/abs/2104.09864 31 | .. _repo: https://github.com/ZhuiyiTechnology/roformer 32 | .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox 33 | .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative
34 |     (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
35 |     """
36 | 
37 |     def __init__(self, dim: int, *_, **__):
38 |         super().__init__()
39 |         # Generate and save the inverse frequency buffer (non-trainable)
40 |         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
41 |         self.register_buffer("inv_freq", inv_freq)
42 | 
43 |         self._seq_len_cached = None
44 |         self._cos_cached = None
45 |         self._sin_cached = None
46 | 
47 |     def _update_cos_sin_tables(self, x, seq_dimension=1):
48 |         seq_len = x.shape[seq_dimension]
49 | 
50 |         # Reset the tables if the sequence length has changed,
51 |         # or if we're on a new device (possibly due to tracing for instance)
52 |         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
53 |             self._seq_len_cached = seq_len
54 |             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
55 |             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
56 |             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
57 | 
58 |             self._cos_cached = emb.cos()[None, :, :]
59 |             self._sin_cached = emb.sin()[None, :, :]
60 | 
61 |         return self._cos_cached, self._sin_cached
62 | 
63 |     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
64 |         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
65 | 
66 |         return (
67 |             apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
68 |             apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
69 |         )
70 | 
71 | 
72 | 
--------------------------------------------------------------------------------
/model/nanoencoder/abnativ_model.py:
--------------------------------------------------------------------------------
1 | # (c) 2023 Sormannilab and Aubin Ramon
2 | #
3 | # AbNatiV model, PyTorch version
4 | #
5 | # ============================================================================
6 | 
7 | from typing import Tuple
8 | import math
9 | 
10 | from .abnativ_vq import VectorQuantize
11 | from .abnativ_utils import find_optimal_cnn1d_padding, find_out_padding_cnn1d_transpose
12 | 
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 | 
17 | # import pytorch_lightning as pl
18 | from einops.layers.torch import Rearrange
19 | 
20 | 
21 | class PositionalEncoding(nn.Module):
22 |     def __init__(self, d_embedding, max_len):
23 |         super(PositionalEncoding, self).__init__()
24 | 
25 |         position = torch.arange(max_len).unsqueeze(1)
26 |         div_term = torch.exp(torch.arange(0, d_embedding, 2) * (-math.log(10000.0) / d_embedding))
27 |         pe = torch.zeros(max_len, d_embedding)
28 | 
29 |         # apply sin to even indices in the array; 2i
30 |         pe[:, 0::2] = torch.sin(position * div_term)
31 | 
32 |         # apply cos to odd indices in the array; 2i+1
33 |         pe[:, 1::2] = torch.cos(position * div_term)
34 |         self.register_buffer('pe', pe)
35 | 
36 |     def forward(self, x) -> torch.Tensor:
37 |         """
38 |         Args:
39 |             x: Tensor, shape [batch_size, input_seq_len, d_embedding]
40 |         """
41 |         x = x + self.pe[:x.size(1)]
42 |         return x
43 | 
44 | 
45 | class MHAEncoderBlock(nn.Module):
46 |     def __init__(self, d_embedding, num_heads, d_ff, dropout):
47 |         super(MHAEncoderBlock, self).__init__()
48 | 
49 |         self.self_MHA = torch.nn.MultiheadAttention(d_embedding, num_heads, batch_first=True)
50 | 
51 |         self.MLperceptron = nn.Sequential(
52 |             nn.Linear(d_embedding, d_ff),
53 |             nn.Dropout(dropout),
54 |             nn.ReLU(inplace=True),
55
| nn.Linear(d_ff, d_embedding)) 56 | 57 | self.layernorm1 = nn.LayerNorm(d_embedding, eps=1e-6) 58 | self.layernorm2 = nn.LayerNorm(d_embedding, eps=1e-6) 59 | 60 | self.dropout = nn.Dropout(dropout) 61 | 62 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """ 64 | Args: 65 | x: Tensor, shape [batch_size, input_seq_len, d_embedding] 66 | """ 67 | # Attention 68 | attn_output, attn_output_weights = self.self_MHA(x, x, x) # (batch_size, input_seq_len, d_embedding) 69 | x = x + self.dropout(attn_output) 70 | x = self.layernorm1(x) 71 | 72 | # MLP 73 | linear_output = self.MLperceptron(x) 74 | x = x + self.dropout(linear_output) 75 | x = self.layernorm2(x) # (batch_size, input_seq_len, d_embedding) + residual 76 | 77 | return x, attn_output_weights 78 | 79 | 80 | class Encoder(nn.Module): 81 | def __init__(self, d_embedding, kernel, stride, num_heads, num_mha_layers, d_ff, 82 | length_seq, alphabet_size, dropout=0): 83 | super(Encoder, self).__init__() 84 | 85 | # CNN1d embedding 86 | self.l_red, self.padding = find_optimal_cnn1d_padding(L_in=length_seq, K=kernel, S=stride) 87 | self.cnn_embedding = nn.Sequential(Rearrange('b l r -> b r l'), 88 | nn.Conv1d(alphabet_size, d_embedding, kernel_size=kernel, stride=stride, 89 | padding=self.padding), 90 | Rearrange('b r l -> b l r')) 91 | 92 | # Positional encoding 93 | self.en_pos_encoding = PositionalEncoding(d_embedding, max_len=self.l_red) 94 | self.en_dropout = nn.Dropout(dropout) 95 | 96 | # MHA blocks 97 | self.en_MHA_blocks = nn.ModuleList([MHAEncoderBlock(d_embedding, num_heads, d_ff, dropout) 98 | for _ in range(num_mha_layers)]) 99 | 100 | def forward(self, x) -> torch.Tensor: 101 | """ 102 | Args: 103 | x: Tensor, shape [batch_size, input_seq_len, alphabet_size] 104 | """ 105 | # CNN1d Embedding 106 | h = self.cnn_embedding(x) # (batch_size, l_red, d_embedding) 107 | 108 | # Positional encoding 109 | h = self.en_pos_encoding(h) 110 | h = self.en_dropout(h) 111 | 112 | #  MHA blocks 113 | for i, l in enumerate(self.en_MHA_blocks): 114 | h, attn_enc_weights = self.en_MHA_blocks[i](h) # (batch_size, l_red, d_embedding) 115 | 116 | return h 117 | 118 | 119 | class Decoder(nn.Module): 120 | def __init__(self, d_embedding, kernel, stride, num_heads, num_mha_layers, d_ff, 121 | length_seq, alphabet_size, dropout=0): 122 | super(Decoder, self).__init__() 123 | 124 | # Positional encoding 125 | self.l_red, self.padding = find_optimal_cnn1d_padding(L_in=length_seq, K=kernel, S=stride) 126 | self.de_pos_encoding = PositionalEncoding(d_embedding, max_len=self.l_red) 127 | self.de_dropout = nn.Dropout(dropout) 128 | 129 | # MHA blocks 130 | self.de_MHA_blocks = nn.ModuleList([MHAEncoderBlock(d_embedding, num_heads, d_ff, dropout) 131 | for _ in range(num_mha_layers)]) 132 | 133 | # Dense reconstruction 134 | self.dense_to_alphabet = nn.Linear(d_embedding, alphabet_size) 135 | self.dense_reconstruction = nn.Linear(alphabet_size * self.l_red, length_seq * alphabet_size) 136 | 137 | # CNN1d reconstruction 138 | self.out_pad = find_out_padding_cnn1d_transpose(L_obj=length_seq, L_in=self.l_red, K=kernel, S=stride, 139 | P=self.padding) 140 | self.cnn_reconstruction = nn.Sequential(Rearrange('b l r -> b r l'), 141 | nn.ConvTranspose1d(d_embedding, alphabet_size, kernel_size=kernel, 142 | stride=stride, 143 | padding=self.padding, output_padding=self.out_pad), 144 | Rearrange('b r l -> b l r')) 145 | 146 | def forward(self, q) -> torch.Tensor: 147 | """ 148 | Args: 149 | q: Tensor, shape [batch_size, l_red, d_embedding] 150 | """ 151 | # 
Positional encoding 152 | z = self.de_pos_encoding(q) 153 | z = self.de_dropout(z) 154 | 155 | #  MHA blocks 156 | for i, l in enumerate(self.de_MHA_blocks): 157 | z, attn_dec_weights = self.de_MHA_blocks[i](z) # (batch_size, l_red, d_embedding) 158 | 159 | # CNN reconstruction 160 | z = self.cnn_reconstruction(z) # (batch_size, input_seq_len, alphabet_size) 161 | z_recon = F.softmax(z, dim=-1) 162 | 163 | return z_recon 164 | 165 | 166 | class AbNatiV_Model(nn.Module): 167 | def __init__(self, hparams: dict): 168 | super(AbNatiV_Model, self).__init__() 169 | 170 | self.encoder = Encoder(hparams['d_embedding'], hparams['kernel'], hparams['stride'], hparams['num_heads'], 171 | hparams['num_mha_layers'], hparams['d_ff'], hparams['length_seq'], 172 | hparams['alphabet_size'], dropout=hparams['drop']) 173 | 174 | self.decoder = Decoder(hparams['d_embedding'], hparams['kernel'], hparams['stride'], hparams['num_heads'], 175 | hparams['num_mha_layers'], hparams['d_ff'], hparams['length_seq'], 176 | hparams['alphabet_size'], dropout=hparams['drop']) 177 | 178 | self.vqvae = VectorQuantize( 179 | dim=hparams['d_embedding'], 180 | codebook_size=hparams['num_embeddings'], 181 | codebook_dim=hparams['embedding_dim_code_book'], 182 | decay=hparams['decay'], 183 | kmeans_init=True, 184 | commitment_weight=hparams['commitment_cost'] 185 | ) 186 | 187 | self.learning_rate = hparams['learning_rate'] 188 | # self.save_hyperparameters() 189 | 190 | def forward(self, data) -> dict: 191 | inputs = data 192 | m_inputs = data.clone().detach() 193 | 194 | x = self.encoder(m_inputs) 195 | vq_outputs = self.vqvae(x) 196 | x_recon = self.decoder(vq_outputs['quantize_projected_out']) 197 | 198 | # Loss computing 199 | recon_error_pres_pposi = F.mse_loss(x_recon, inputs, reduction='none') 200 | recon_error_pposi = torch.mean(recon_error_pres_pposi, dim=-1) 201 | recon_error_pbe = torch.mean(recon_error_pposi, dim=1) 202 | 203 | loss_pbe = torch.add(recon_error_pbe, vq_outputs['loss_vq_commit_pbe']) 204 | 205 | return { 206 | 'inputs': inputs, # (batch_size, input_seq_len, alphabet_size) 207 | 'x_recon': x_recon, # (batch_size, input_seq_len, alphabet_size) 208 | 'recon_error_pres_pposi': recon_error_pres_pposi, # (batch_size, input_seq_len, alphabet_size) 209 | 'recon_error_pposi': recon_error_pposi, # (batch_size, input_seq_len) 210 | 'recon_error_pbe': recon_error_pbe, # (batch_size) 211 | 'loss_pbe': loss_pbe, # (batch_size) 212 | **vq_outputs 213 | } 214 | 215 | def configure_optimizers(self): 216 | optim_groups = list(self.encoder.parameters()) + \ 217 | list(self.decoder.parameters()) + \ 218 | list(self.vqvae.parameters()) 219 | 220 | return torch.optim.AdamW(optim_groups, lr=self.learning_rate) 221 | 222 | def training_step(self, batch, batch_idx) -> torch.float32: 223 | vqvae_output = self(batch) 224 | 225 | loss_vqvae = torch.mean(vqvae_output['loss_pbe']) 226 | self.log("train_loss_vqvae", loss_vqvae, on_step=True, prog_bar=True, logger=True) 227 | 228 | loss_vq_commit = torch.mean(vqvae_output['loss_vq_commit_pbe']) 229 | self.log("train_loss_vq_commit", loss_vq_commit, on_step=True, prog_bar=True, logger=True) 230 | 231 | nmse_accuracy = torch.mean(vqvae_output['recon_error_pbe']) 232 | self.log("train_loss_nmse_recons", nmse_accuracy, on_step=True, prog_bar=True, logger=True) 233 | 234 | perplexity = vqvae_output['perplexity'] 235 | self.log("train_perplexity", perplexity, on_step=True, prog_bar=True, logger=True) 236 | 237 | return loss_vqvae 238 | 239 | def validation_step(self, batch, batch_idx) -> 
dict:
240 |         model_output = self(batch)
241 |         return {'val_loss': torch.mean(model_output['loss_pbe']), 'model_output': model_output}
242 | 
243 |     def validation_epoch_end(self, outputs) -> dict:
244 |         val_losses = torch.Tensor([out['val_loss'] for out in outputs])
245 |         total_val_loss = torch.mean(val_losses)
246 |         self.log('val_loss', total_val_loss, on_epoch=True, logger=True)
247 | 
248 |         val_accuracies = torch.Tensor([torch.mean(out['model_output']['recon_error_pbe']) for out in outputs])
249 |         total_val_accuracy = torch.mean(val_accuracies)
250 |         self.log('val_nmse_accuracy', total_val_accuracy, on_epoch=True, logger=True)
251 | 
252 |         val_perplexities = torch.Tensor([out['model_output']['perplexity'] for out in outputs])
253 |         total_val_perplexity = torch.mean(val_perplexities)
254 |         self.log('val_perplexity', total_val_perplexity, on_epoch=True, logger=True)
255 | 
256 |         return {'val_loss': total_val_loss, 'val_nmse_accuracy': total_val_accuracy,
257 |                 'val_perplexity': total_val_perplexity}
258 | 
259 | 
260 | 
--------------------------------------------------------------------------------
/model/nanoencoder/abnativ_onehot.py:
--------------------------------------------------------------------------------
1 | # (c) 2023 Sormannilab and Aubin Ramon
2 | #
3 | # AbNatiV BERT-style masking OneHotEncoder Iterator.
4 | #
5 | # ============================================================================
6 | 
7 | from typing import Tuple
8 | import numpy as np
9 | import math
10 | import random
11 | import pandas as pd
12 | from pandas.api.types import CategoricalDtype
13 | 
14 | from Bio import SeqIO
15 | import torch
16 | 
17 | alphabet = ['A', 'C', 'D', 'E', 'F', 'G','H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '-']
18 | 
19 | def data_loader_masking_bert_onehot_fasta(fp_data: str, batch_size: int, perc_masked_residues: float,
20 |                                           is_masking: bool) -> torch.utils.data.DataLoader:
21 |     '''
22 |     Generate a Torch dataloader iterator from fp_data.
23 | 
24 |     Parameters
25 |     ----------
26 |     fp_data: str
27 |         Path to the FASTA file of AHo-aligned sequences.
28 |     batch_size: int
29 |     perc_masked_residues: float
30 |         Ratio of residues to apply the BERT masking on (between 0 and 1).
31 |     is_masking: bool
32 | 
33 |     '''
34 |     iterator = IterableMaskingBertOnehotDatasetFasta(fp_data, perc_masked_residues=perc_masked_residues, is_masking=is_masking)
35 |     loader = torch.utils.data.DataLoader(iterator, batch_size=batch_size, num_workers=0, shuffle=is_masking)
36 |     return loader
37 | 
38 | 
39 | class IterableMaskingBertOnehotDatasetFasta(torch.utils.data.IterableDataset):
40 |     '''
41 |     BERT-style masking onehot generator for all sequences given a fasta file.
42 |     '''
43 |     def __init__(self, fp_seq_list, perc_masked_residues=0.0, is_masking=False):
44 |         self.fp_seq = fp_seq_list
45 |         self.perc_masked_residues = perc_masked_residues
46 |         self.is_masking = is_masking
47 | 
48 |     def __iter__(self) -> torch.utils.data.IterableDataset:
49 |         for record in SeqIO.parse(self.fp_seq, 'fasta'):
50 |             if len(str(record.seq)) != 149:
51 |                 raise Exception(
52 |                     f'Sequence {record.id} is not 149 characters long. All sequences must be aligned with the AHo scheme.')
53 |             yield torch_masking_BERT_onehot(str(record.seq), perc_masked_residues=self.perc_masked_residues,
54 |                                             is_masking=self.is_masking)
55 | 
56 | def torch_masking_BERT_onehot(seq: str, perc_masked_residues: float=0.0,
57 |                               is_masking: bool=False, alphabet: list=alphabet) -> Tuple[torch.Tensor, torch.Tensor] or torch.Tensor:
58 |     '''
59 |     BERT-style masking on a one-hot encoding input. When a residue is masked, it is replaced
60 |     by the dummy vector [1/21,...,1/21] of size 21. 80% of perc_masked_residues are masked,
61 |     10% are replaced by another residue, 10% are left as they are.
62 | 
63 |     Parameters
64 |     ----------
65 |     seq: str
66 |     perc_masked_residues: float
67 |         Ratio of residues to apply the BERT masking on (between 0 and 1).
68 |     is_masking: bool
69 |         False for evaluation.
70 |     alphabet: list
71 |         List of string of the alphabet of residues used in the one hot encoder
72 | 
73 |     Returns
74 |     -------
75 |     onehot_seq: tensor
76 |         One hot encoded input.
77 |     m_tf_onehot_seq: tensor
78 |         BERT masked one hot encoded input.
79 | 
80 |     '''
81 | 
82 |     # The module-level alphabet above is the default; a custom alphabet may be passed in.
83 | 
84 |     # One Hot Encoding
85 |     onehot_seq = np.array((pd.get_dummies(pd.Series(list(seq)).astype(CategoricalDtype(categories=alphabet))))).astype(float)
86 |     onehot_seq = torch.tensor(onehot_seq, dtype=torch.float32)
87 |     ln_seq = len(onehot_seq)
88 | 
89 |     m_tf_onehot_seq = onehot_seq.clone().detach()
90 | 
91 |     if is_masking:
92 |         if perc_masked_residues > 1:
93 |             raise ValueError('Masking percentage should be between 0 and 1.')
94 | 
95 |         # the onehot vector of the masked residue
96 |         len_alphabet = len(alphabet)
97 |         masked_letter = [1/len_alphabet]*len_alphabet
98 | 
99 |         # MASKING
100 |         nb_masking = math.floor(ln_seq * perc_masked_residues)
101 |         nb_to_mask = math.floor(nb_masking*0.8)  # 80% replace with mask token
102 |         nb_to_replace = math.floor(nb_masking*0.1)  # 10% replace with random residue
103 | 
104 |         if nb_to_mask != 0:
105 | 
106 |             rd_ids = torch.Tensor(random.sample(range(ln_seq),ln_seq)[:nb_to_mask+nb_to_replace]).type(torch.int64)
107 | 
108 |             rd_alphabet_selection_to_replace = random.choices(alphabet, k=nb_to_replace)
109 |             dummies_to_replace = np.array((pd.get_dummies(pd.Series(rd_alphabet_selection_to_replace).astype(CategoricalDtype(categories=alphabet)))))
110 | 
111 |             updates = np.array([masked_letter]*nb_to_mask)
112 |             updates = torch.Tensor(np.concatenate((updates,dummies_to_replace)))
113 | 
114 |             m_tf_onehot_seq[rd_ids] = updates
115 | 
116 |     if is_masking:
117 |         return onehot_seq, m_tf_onehot_seq
118 |     else:
119 |         return onehot_seq
--------------------------------------------------------------------------------
/model/nanoencoder/abnativ_utils.py:
--------------------------------------------------------------------------------
1 | # (c) 2023 Sormannilab and Aubin Ramon
2 | #
3 | # Diverse functions to run AbNatiV.
4 | #
5 | # ============================================================================
6 | 
7 | from typing import Tuple
8 | import math
9 | 
10 | import torch
11 | from torch import nn, einsum
12 | import torch.nn.functional as F
13 | import torch.distributed as distributed
14 | 
15 | from einops import rearrange, repeat
16 | 
17 | def is_protein(seq, aa_list):
18 |     """
19 |     Check if a str corresponds to a protein sequence
20 |     return bool
21 |     """
22 |     for aa in seq:
23 |         if aa not in aa_list:
24 |             return False
25 |     return True
26 | 
27 | def l_out_cnn1d(L_in:int,K:int,S:int,P:int,D:int=1) -> float:
28 |     '''Formula to find the L_out dimension of an input (dim=L_in)
29 |     in cnn_1d.'''
30 |     return (L_in+2*P-D*(K-1)-1)/S + 1
31 | 
32 | def find_optimal_cnn1d_padding(L_in:int,K,S:int) -> Tuple[int,int]:
33 |     '''Find the minimal padding given the kernel size K and stride S
34 |     for a CNN1D without losing any piece of information.'''
35 |     P=0
36 |     L_out = l_out_cnn1d(L_in,K,S,P)
37 | 
38 |     assert L_in>=K, 'Kernel size larger than input dimension, the conv1d will not work'
39 | 
40 |     while not L_out.is_integer() and 2*P<=S:
41 |         L_out = l_out_cnn1d(L_in,K,S,P)
42 |         P+=1
43 | 
44 |     if 2*P>=S: P-=1
45 |     return math.floor(L_out), P
46 | 
47 | def l_out_cnn1d_transpose(L_in:int,K:int,S:int,P:int,D:int=1) -> int:
48 |     '''Formula to find the L_out dimension of an input (dim=L_in)
49 |     in a transposed cnn_1d.'''
50 |     return (L_in-1)*S -2*P + D*(K-1) + 1
51 | 
52 | def find_out_padding_cnn1d_transpose(L_obj:int,L_in:int,K:int,S:int,P:int) -> int:
53 |     '''Find the minimal output padding given the kernel size K and stride S
54 |     to add after a CNN1D transpose layer to reach L_obj (objective).'''
55 |     L_out = l_out_cnn1d_transpose(L_in,K,S,P)
56 |     assert L_obj>=L_out, 'Make sure the padding is correct, the output \
57 |     of the CNN1D transpose is larger than expected'
58 |     return L_obj-L_out
59 | 
60 | # From the enhancing VQ (https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py)
61 | # Copyright (c) 2020 Phil Wang (MIT Licensed)
62 | 
63 | def exists(val):
64 |     return val is not None
65 | 
66 | def default(val, d):
67 |     return val if exists(val) else d
68 | 
69 | def noop(*args, **kwargs):
70 |     pass
71 | 
72 | def l2norm(t):
73 |     return F.normalize(t, p = 2, dim = -1)
74 | 
75 | def log(t, eps = 1e-20):
76 |     return torch.log(t.clamp(min = eps))
77 | 
78 | def uniform_init(*shape):
79 |     t = torch.empty(shape)
80 |     nn.init.kaiming_uniform_(t)
81 |     return t
82 | 
83 | def gumbel_noise(t):
84 |     noise = torch.zeros_like(t).uniform_(0, 1)
85 |     return -log(-log(noise))
86 | 
87 | def gumbel_sample(t, temperature = 1., dim = -1):
88 |     if temperature == 0:
89 |         return t.argmax(dim = dim)
90 | 
91 |     return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
92 | 
93 | def ema_inplace(moving_avg, new, decay):
94 |     moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
95 | 
96 | def laplace_smoothing(x, n_categories, eps = 1e-5):
97 |     return (x + eps) / (x.sum() + n_categories * eps)
98 | 
99 | def sample_vectors(samples, num):
100 |     num_samples, device = samples.shape[0], samples.device
101 |     if num_samples >= num:
102 |         indices = torch.randperm(num_samples, device = device)[:num]
103 |     else:
104 |         indices = torch.randint(0, num_samples, (num,), device = device)
105 | 
106 |     return samples[indices]
107 | 
108 | def batched_sample_vectors(samples, num):
109 |     return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)
110 | 
111 | def
pad_shape(shape, size, dim = 0): 112 | return [size if i == dim else s for i, s in enumerate(shape)] 113 | 114 | def sample_multinomial(total_count, probs): 115 | device = probs.device 116 | probs = probs.cpu() 117 | 118 | total_count = probs.new_full((), total_count) 119 | remainder = probs.new_ones(()) 120 | sample = torch.empty_like(probs, dtype = torch.long) 121 | 122 | for i, p in enumerate(probs): 123 | s = torch.binomial(total_count, p / remainder) 124 | sample[i] = s 125 | total_count -= s 126 | remainder -= p 127 | 128 | return sample.to(device) 129 | 130 | def all_gather_sizes(x, dim): 131 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 132 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 133 | distributed.all_gather(all_sizes, size) 134 | return torch.stack(all_sizes) 135 | 136 | def all_gather_variably_sized(x, sizes, dim = 0): 137 | rank = distributed.get_rank() 138 | all_x = [] 139 | 140 | for i, size in enumerate(sizes): 141 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 142 | distributed.broadcast(t, src = i, async_op = True) 143 | all_x.append(t) 144 | 145 | distributed.barrier() 146 | return all_x 147 | 148 | def sample_vectors_distributed(local_samples, num): 149 | local_samples = rearrange(local_samples, '1 ... -> ...') 150 | 151 | rank = distributed.get_rank() 152 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 153 | 154 | if rank == 0: 155 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 156 | else: 157 | samples_per_rank = torch.empty_like(all_num_samples) 158 | 159 | distributed.broadcast(samples_per_rank, src = 0) 160 | samples_per_rank = samples_per_rank.tolist() 161 | 162 | local_samples = sample_vectors(local_samples, samples_per_rank[rank]) 163 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 164 | out = torch.cat(all_samples, dim = 0) 165 | 166 | return rearrange(out, '... -> 1 ...') 167 | 168 | def batched_bincount(x, *, minlength): 169 | batch, dtype, device = x.shape[0], x.dtype, x.device 170 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 171 | values = torch.ones_like(x) 172 | target.scatter_add_(-1, x, values) 173 | return target 174 | 175 | def kmeans( 176 | samples, 177 | num_clusters, 178 | num_iters = 10, 179 | use_cosine_sim = False, 180 | sample_fn = batched_sample_vectors, 181 | all_reduce_fn = noop 182 | ): 183 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 184 | 185 | means = sample_fn(samples, num_clusters) 186 | 187 | for _ in range(num_iters): 188 | if use_cosine_sim: 189 | dists = samples @ rearrange(means, 'h n d -> h d n') 190 | else: 191 | dists = -torch.cdist(samples, means, p = 2) 192 | 193 | buckets = torch.argmax(dists, dim = -1) 194 | bins = batched_bincount(buckets, minlength = num_clusters) 195 | all_reduce_fn(bins) 196 | 197 | zero_mask = bins == 0 198 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 199 | 200 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 201 | 202 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 203 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 204 | all_reduce_fn(new_means) 205 | 206 | if use_cosine_sim: 207 | new_means = l2norm(new_means) 208 | 209 | means = torch.where( 210 | rearrange(zero_mask, '... -> ... 
1'), 211 | means, 212 | new_means 213 | ) 214 | 215 | return means, bins 216 | 217 | def batched_embedding(indices, embeds): 218 | batch, dim = indices.shape[1], embeds.shape[-1] 219 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 220 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 221 | return embeds.gather(2, indices) 222 | 223 | # regularization losses 224 | 225 | def orthogonal_loss_fn(t): 226 | # eq (2) from https://arxiv.org/abs/2112.00384 227 | h, n = t.shape[:2] 228 | normed_codes = l2norm(t) 229 | cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes) 230 | return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n) 231 | -------------------------------------------------------------------------------- /model/nanoencoder/abnativ_vq.py: -------------------------------------------------------------------------------- 1 | # (c) 2023 Sormannilab and Aubin Ramon 2 | # 3 | # Vector-Quantization of the latent space in the AbNatiV model. 4 | # 5 | # Modified from (https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py) 6 | # under the copyright (c) 2020 Phil Wang (MIT Licenced) 7 | # ============================================================================ 8 | 9 | 10 | from einops import rearrange 11 | import pandas as pd 12 | import torch 13 | from torch import nn, einsum 14 | import torch.nn.functional as F 15 | import torch.distributed as distributed 16 | from torch.cuda.amp import autocast 17 | 18 | from .abnativ_utils import uniform_init, kmeans, sample_vectors_distributed, noop, batched_embedding, batched_sample_vectors 19 | from .abnativ_utils import l2norm, gumbel_sample, ema_inplace, default 20 | 21 | 22 | class CosineSimCodebook(nn.Module): 23 | def __init__( 24 | self, 25 | dim, 26 | codebook_size, 27 | num_codebooks = 1, 28 | kmeans_init = False, 29 | kmeans_iters = 10, 30 | sync_kmeans = True, 31 | decay = 0.8, 32 | eps = 1e-5, 33 | threshold_ema_dead_code = 3, 34 | use_ddp = False, 35 | learnable_codebook = False, 36 | sample_codebook_temp = 0. 
37 | ): 38 | super().__init__() 39 | self.decay = decay 40 | 41 | if not kmeans_init: 42 | embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) 43 | else: 44 | embed = torch.zeros(num_codebooks, codebook_size, dim) 45 | 46 | self.codebook_size = codebook_size 47 | self.num_codebooks = num_codebooks 48 | 49 | self.kmeans_iters = kmeans_iters 50 | self.eps = eps 51 | self.threshold_ema_dead_code = threshold_ema_dead_code 52 | self.sample_codebook_temp = sample_codebook_temp 53 | 54 | self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors 55 | self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop 56 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 57 | 58 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 59 | self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 60 | 61 | self.learnable_codebook = learnable_codebook 62 | if learnable_codebook: 63 | self.embed = nn.Parameter(embed) 64 | else: 65 | self.register_buffer('embed', embed) 66 | 67 | @torch.jit.ignore 68 | def init_embed_(self, data): 69 | if self.initted: 70 | return 71 | 72 | embed, cluster_size = kmeans( 73 | data, 74 | self.codebook_size, 75 | self.kmeans_iters, 76 | use_cosine_sim = True, 77 | sample_fn = self.sample_fn, 78 | all_reduce_fn = self.kmeans_all_reduce_fn 79 | ) 80 | 81 | self.embed.data.copy_(embed) 82 | self.cluster_size.data.copy_(cluster_size) 83 | self.initted.data.copy_(torch.Tensor([True])) 84 | 85 | def replace(self, batch_samples, batch_mask): 86 | batch_samples = l2norm(batch_samples) 87 | 88 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 89 | if not torch.any(mask): 90 | continue 91 | 92 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 93 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') 94 | 95 | def expire_codes_(self, batch_samples): 96 | if self.threshold_ema_dead_code == 0: 97 | return 98 | 99 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 100 | 101 | if not torch.any(expired_codes): 102 | return 103 | 104 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 105 | self.replace(batch_samples, batch_mask = expired_codes) 106 | 107 | @autocast(enabled = False) 108 | def forward(self, x): 109 | needs_codebook_dim = x.ndim < 4 110 | 111 | x = x.float() 112 | 113 | if needs_codebook_dim: 114 | x = rearrange(x, '... -> 1 ...') 115 | 116 | shape, dtype = x.shape, x.dtype 117 | 118 | flatten = rearrange(x, 'h ... d -> h (...) d') 119 | flatten = l2norm(flatten) 120 | 121 | self.init_embed_(flatten) 122 | 123 | embed = self.embed if not self.learnable_codebook else self.embed.detach() 124 | embed = l2norm(embed) 125 | 126 | dist = einsum('h n d, h c d -> h n c', flatten, embed) 127 | embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) 128 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 129 | embed_ind = embed_ind.view(*shape[:-1]) 130 | 131 | quantize = batched_embedding(embed_ind, self.embed) 132 | 133 | if self.training: 134 | bins = embed_onehot.sum(dim = 1) 135 | self.all_reduce_fn(bins) 136 | 137 | ema_inplace(self.cluster_size, bins, self.decay) 138 | 139 | zero_mask = (bins == 0) 140 | bins = bins.masked_fill(zero_mask, 1.) 
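            # EMA codebook update, matching ema_inplace in abnativ_utils:
            #   cluster_size <- decay * cluster_size + (1 - decay) * bins
            #   embed        <- decay * embed + (1 - decay) * per-code batch means
            # Codes whose EMA cluster size falls below threshold_ema_dead_code
            # are resampled from the current batch by expire_codes_ below.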
141 | 142 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 143 | self.all_reduce_fn(embed_sum) 144 | 145 | embed_normalized = embed_sum / rearrange(bins, '... -> ... 1') 146 | embed_normalized = l2norm(embed_normalized) 147 | 148 | embed_normalized = torch.where( 149 | rearrange(zero_mask, '... -> ... 1'), 150 | embed, 151 | embed_normalized 152 | ) 153 | 154 | ema_inplace(self.embed, embed_normalized, self.decay) 155 | self.expire_codes_(x) 156 | 157 | if needs_codebook_dim: 158 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) 159 | 160 | return quantize, embed_ind 161 | 162 | 163 | class VectorQuantize(nn.Module): 164 | def __init__( 165 | self, 166 | dim, 167 | codebook_size, 168 | codebook_dim, 169 | heads = 1, 170 | separate_codebook_per_head = False, 171 | decay = 0.8, 172 | eps = 1e-5, 173 | kmeans_init = True, 174 | kmeans_iters = 10, 175 | sync_kmeans = True, 176 | threshold_ema_dead_code = 3, 177 | commitment_weight = 1., 178 | orthogonal_reg_weight = 0., 179 | orthogonal_reg_active_codes_only = False, 180 | orthogonal_reg_max_codes = None, 181 | sample_codebook_temp = 0., 182 | sync_codebook = False 183 | ): 184 | super().__init__() 185 | self.heads = heads 186 | self.separate_codebook_per_head = separate_codebook_per_head 187 | 188 | codebook_dim = default(codebook_dim, dim) 189 | codebook_input_dim = codebook_dim * heads 190 | 191 | requires_projection = codebook_input_dim != dim 192 | self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() 193 | self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() 194 | 195 | self.eps = eps 196 | self.commitment_weight = commitment_weight 197 | 198 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 199 | self.orthogonal_reg_weight = orthogonal_reg_weight 200 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 201 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 202 | 203 | codebook_class = CosineSimCodebook 204 | 205 | self._codebook = codebook_class( 206 | dim = codebook_dim, 207 | num_codebooks = heads if separate_codebook_per_head else 1, 208 | codebook_size = codebook_size, 209 | kmeans_init = kmeans_init, 210 | kmeans_iters = kmeans_iters, 211 | sync_kmeans = sync_kmeans, 212 | decay = decay, 213 | eps = eps, 214 | threshold_ema_dead_code = threshold_ema_dead_code, 215 | use_ddp = sync_codebook, 216 | learnable_codebook = has_codebook_orthogonal_loss, 217 | sample_codebook_temp = sample_codebook_temp 218 | ) 219 | 220 | self.codebook_size = codebook_size 221 | 222 | @property 223 | def codebook(self): 224 | codebook = self._codebook.embed 225 | if self.separate_codebook_per_head: 226 | return codebook 227 | 228 | return rearrange(codebook, '1 ... 
-> ...')
229 | 
230 |     def forward(self, x):
231 |         shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size
232 | 
233 |         x = self.project_in(x)
234 | 
235 |         if is_multiheaded:
236 |             ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
237 |             x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
238 | 
239 |         quantize, embed_ind = self._codebook(x)
240 | 
241 |         if self.training:
242 |             quantize = x + (quantize - x).detach()
243 | 
244 | 
245 |         detached_inputs = x.detach()
246 |         loss = F.mse_loss(quantize, detached_inputs, reduction='none')
247 |         loss_pbe = torch.mean(loss, dim=(1,2))  # (batch_size)
248 | 
249 |         if self.commitment_weight > 0:
250 |             detached_quantize = quantize.detach()
251 |             commit_loss = F.mse_loss(detached_quantize, x, reduction='none')
252 | 
253 |             loss_pbe = loss_pbe + torch.mean(commit_loss * self.commitment_weight, dim=(1,2))  # (batch_size)
254 | 
255 |         if is_multiheaded:
256 |             if self.separate_codebook_per_head:
257 |                 quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads)
258 |                 embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads)
259 |             else:
260 |                 quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads)
261 |                 embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads)
262 | 
263 |         quantize_latent = quantize.detach().clone()
264 |         quantize = self.project_out(quantize)
265 | 
266 |         avg_probs = torch.mean(F.one_hot(embed_ind, self.codebook_size).type(torch.float32).view((-1, self.codebook_size)), 0)
267 |         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
268 | 
269 |         return {
270 |             'quantize_projected_in': x,  # (batch_size, l_r, codebook_dim)
271 |             'quantize_latent': quantize_latent,  # (batch_size, l_r, codebook_dim)
272 |             'quantize_projected_out': quantize,  # (batch_size, l_r, dim)
273 |             'loss_vq_commit_pbe': loss_pbe,  # (batch_size)
274 |             'perplexity': perplexity,  # (batch_size)
275 |             'encoding_indices': embed_ind  # (batch_size, l_r)
276 |         }
277 | 
278 | 
--------------------------------------------------------------------------------
/model/nanoencoder/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | 
6 | from ..encoder.cross_attention import AttLayer
7 | 
8 | 
9 | class TransformerBlock(nn.Module):
10 |     def __init__(self, d_model, nhead, att_model, dim_feedforward):
11 |         super(TransformerBlock, self).__init__()
12 |         self.attention = AttLayer(d_model, att_model, nhead)
13 |         self.norm1 = nn.LayerNorm(d_model)
14 |         self.norm2 = nn.LayerNorm(d_model)
15 |         self.ffn = nn.Sequential(
16 |             nn.Linear(d_model, dim_feedforward),
17 |             nn.ReLU(),
18 |             nn.Linear(dim_feedforward, d_model)
19 |         )
20 | 
21 |     def forward(self, x, mask=None):
22 |         attn_output = self.attention(x, mask)
23 |         x = x + attn_output
24 |         x = self.norm1(x)
25 | 
26 |         ffn_output = self.ffn(x)
27 |         x = x + ffn_output
28 |         x = self.norm2(x)
29 |         return x
30 | 
31 | 
32 | 
33 | class SelfAttBlock(nn.Module):
34 | 
35 |     def __init__(self, d_model, att_model, dim_feedforward, nhead):
36 |         super().__init__()
37 |         self.attn_h = AttLayer(d_model, att_model, nhead)
38 |         self.attn_h_c = AttLayer(d_model, att_model, nhead)
39 | 
40 |         self.norm_h1 = nn.LayerNorm(d_model)
41 |         self.norm_h2 = nn.LayerNorm(d_model)
42 | 
43 |         self.ff_h = nn.Sequential(
44 |             nn.Linear(d_model, dim_feedforward),
45 |             nn.ReLU(),
46 |             nn.Linear(dim_feedforward, d_model)
47 |         )
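    # Minimal usage sketch (hypothetical sizes; assumes AttLayer's optional
    # constructor arguments take their defaults):
    #   blk = SelfAttBlock(d_model=256, att_model=256, dim_feedforward=1024, nhead=8)
    #   h = torch.randn(2, 128, 256)   # (batch, seq_len, d_model)
    #   out = blk(h)                   # residual self-attention + FFN, same shape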
48 | 
49 |     def forward(self, h, mask=None):
50 |         """
51 | 
52 |         :param h: input sequence features.
53 |         :param mask: optional attention mask (currently unused).
54 |         :return: updated features of the same shape.
55 | 
56 |         """
57 | 
58 |         at_h = h + self.attn_h(h)
59 | 
60 |         at_h = at_h + self.attn_h_c(self.norm_h1(at_h))
61 | 
62 |         h = self.ff_h(self.norm_h2(at_h)) + h
63 |         return h
64 | 
65 | 
66 | 
67 | class SelfAttNet(nn.Module):
68 | 
69 |     def __init__(self, d_model, att_model, dim_feedforward, nhead, num_cross_layers):
70 |         super().__init__()
71 |         self.layers = nn.ModuleList(
72 |             [SelfAttBlock(d_model, att_model, dim_feedforward, nhead) for _ in range(num_cross_layers)]
73 |         )
74 | 
75 |     def forward(self, h, mask=None):
76 |         """
77 | 
78 |         :param h: input sequence features.
79 |         :param mask: optional attention mask (currently unused).
80 |         :return: updated features.
81 | 
82 |         """
83 |         # h_l = h_l + pos_emb
84 |         for layer in self.layers:
85 |             h = layer(h)
86 |         return h
--------------------------------------------------------------------------------
/nanobody_scripts/nano_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from tqdm import tqdm
4 | from abnumber import Chain
5 | import numpy as np
6 | 
7 | import pandas as pd
8 | from evaluation.T20_eval import frame_main as tframemain
9 | from utils.misc import get_logger
10 | 
11 | 
12 | def cal_fr_preservation(chain1, chain2):
13 |     identity = 0
14 |     fr_sum = 0
15 |     align = chain1.align(chain2)
16 |     for pos in align.positions:
17 |         if not pos.is_in_cdr():
18 |             a1, a2 = align[pos]
19 |             if a1 == a2:
20 |                 identity += 1
21 |             fr_sum += 1
22 |     return identity / fr_sum
23 | 
24 | 
25 | def cal_group_fr_germline_identity(df, scheme='imgt'):
26 |     identity_fr_ratio_list = []
27 |     for idx in tqdm(df.index):
28 |         try:
29 |             h_seq = df.iloc[idx]['h_seq']
30 |             h_chain = Chain(h_seq, scheme=scheme)
31 |             h_chain_graft = h_chain.graft_cdrs_onto_human_germline()
32 |             fr_h_ratio = cal_fr_preservation(h_chain, h_chain_graft)
33 |             identity_fr_ratio_list.append(fr_h_ratio)
34 |         except Exception:
35 |             continue
36 | 
37 |     return identity_fr_ratio_list
38 | 
39 | 
40 | def cal_group_fr_germline_identity_for_sample(df, scheme='imgt'):
41 |     identity_fr_ratio_list = []
42 |     for idx in tqdm(df.index):
43 |         try:
44 |             h_seq = df.iloc[idx]['hseq']
45 |             h_chain = Chain(h_seq, scheme=scheme)
46 |             h_chain_graft = h_chain.graft_cdrs_onto_human_germline()
47 |             fr_h_ratio = cal_fr_preservation(h_chain, h_chain_graft)
48 |             identity_fr_ratio_list.append(fr_h_ratio)
49 |         except Exception:
50 |             continue
51 | 
52 |     return identity_fr_ratio_list
53 | 
54 | 
55 | def cal_mean(vh_dir, vhh_dir):
56 |     vh_fpath = os.path.join(vh_dir, 'sample_nano_vh_abnativ_seq_scores.csv')
57 |     vhh_fpath = os.path.join(vhh_dir, 'sample_nano_vhh_abnativ_seq_scores.csv')
58 | 
59 |     sample_vh_df = pd.read_csv(vh_fpath)
60 |     sample_vh_score = sample_vh_df['AbNatiV VH Score']
61 | 
62 |     sample_vhh_df = pd.read_csv(vhh_fpath)
63 |     sample_vhh_score = sample_vhh_df['AbNatiV VHH Score']
64 | 
65 |     ref_vh_score = 0.7378085839359757
66 |     ref_vhh_score = 0.9143594023426274
67 | 
68 |     dev_vh_score = sample_vh_score.mean() - ref_vh_score
69 |     print('Raw sample result: {}'.format(sample_vh_score.mean()))
70 | 
71 |     return dev_vh_score, sample_vh_score.mean(), sample_vhh_score.mean()
72 | 
73 | 
74 | def get_raw_frame_t20_score():
75 |     raw_frame_t20_fpath = '/sample_frame_t20_score.csv'
76 |     raw_frame_t20_df = pd.read_csv(raw_frame_t20_fpath)
77 |     raw_frame_t20_mean = raw_frame_t20_df['h_score'].mean()
78 |     return raw_frame_t20_df, raw_frame_t20_mean
79 | 
80 | 
81 | def main(root_path=None):
82 | 
83 |     if root_path is None:
84 | 
85 |         root_path =
'sample_humanization_result.csv' 86 | 87 | 88 | logdir = os.path.dirname(root_path) 89 | logger = get_logger('sample', logdir, log_name='eval_log.txt') 90 | 91 | # abnativ path. 92 | exec_path = 'bin/abnativ' 93 | input_fa_path = os.path.join(os.path.dirname(root_path), 'sample_identity.fa') 94 | output_vh_dir = os.path.join(os.path.dirname(root_path), 'sample_nano_vh/') 95 | output_vhh_dir = os.path.join(os.path.dirname(root_path), 'sample_nano_vhh/') 96 | if not os.path.exists(output_vh_dir): 97 | print('Eval vh score ......') 98 | subprocess.Popen([exec_path, 'score', '-nat', 'VH', '-i', input_fa_path, '-odir', output_vh_dir, '-oid', 'sample_nano_vh', '-align'], 99 | stderr=subprocess.PIPE, stdout=subprocess.PIPE).communicate() 100 | else: 101 | print('VH exists, Skip!') 102 | 103 | if not os.path.exists(output_vhh_dir): 104 | print('Eval vhh score ......') 105 | subprocess.Popen([exec_path, 'score', '-nat', 'VHH', '-i', input_fa_path, '-odir', output_vhh_dir, '-oid', 'sample_nano_vhh', '-align', '-isVHH'], 106 | stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL).communicate() 107 | else: 108 | print('VHH exists, Skip!') 109 | 110 | dev_vh, sample_ab_mean, sample_ab_vhh_mean = cal_mean(output_vh_dir, output_vhh_dir) 111 | 112 | # T20 113 | t20_frame_save_fpath = tframemain(root_path) 114 | sample_t20_frame_df = pd.read_csv(t20_frame_save_fpath) 115 | sample_frame_t20 = sample_t20_frame_df['h_score'].mean() 116 | 117 | raw_frame_t20_df, raw_frame_t20 = get_raw_frame_t20_score() 118 | 119 | # Identity 120 | sample_result_df = pd.read_csv(root_path) 121 | sample_human_df = sample_result_df[sample_result_df['Specific'] == 'humanization'].reset_index(drop=True) 122 | sample_identity = cal_group_fr_germline_identity_for_sample(sample_human_df, scheme='imgt') 123 | raw_identity = cal_group_fr_germline_identity(raw_frame_t20_df, scheme='imgt') 124 | 125 | 126 | logger.info('Raw Frame t20 score {}'.format(raw_frame_t20)) 127 | logger.info('Sample Frame t20 score {}'.format(sample_frame_t20)) 128 | logger.info('Improve Frame t20 score {}'.format(sample_frame_t20-raw_frame_t20)) 129 | 130 | logger.info('Eval path is {}'.format(root_path)) 131 | logger.info('VH score improve: {}'.format(dev_vh)) 132 | logger.info('Sample VH score: {}'.format(sample_ab_mean)) 133 | logger.info('Sample VHH score: {}'.format(sample_ab_vhh_mean)) 134 | 135 | logger.info('Sample Germline identity(seq_num): {}({})'.format(np.array(sample_identity).mean(), len(sample_identity))) 136 | logger.info('Raw Germline identity(seq_num): {}({})'.format(np.array(raw_identity).mean(), len(raw_identity))) 137 | 138 | 139 | if __name__ == '__main__': 140 | current_ld_library_path = os.getenv("LD_LIBRARY_PATH", "") 141 | os.environ['LD_LIBRARY_PATH'] = 'anaconda3/envs/abnativ/lib:' + current_ld_library_path 142 | 143 | main() -------------------------------------------------------------------------------- /nanobody_scripts/nanofinetune_run.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | echo "Using GPU: $CUDA_VISIBLE_DEVICES" 3 | VHH_DATA_PATH='oas_vhh_data/vhh_nano_idx.pt' 4 | CONFIG_PATH='configs/training_nano_framework.yml' 5 | LOG_PATH='tmp/nano_finetune_log/' 6 | python nanobody_scripts/nanofinetune.py \ 7 | --vhh_data_fpath $VHH_DATA_PATH \ 8 | --config_path $CONFIG_PATH \ 9 | --log_path $LOG_PATH \ -------------------------------------------------------------------------------- /nanobody_scripts/nanotrain.py: 
--------------------------------------------------------------------------------
1 | import os.path
2 | import pickle
3 | import numpy as np
4 | import argparse
5 | import yaml
6 | from easydict import EasyDict
7 | import shutil
8 | import sys
9 | current_dir = os.path.dirname(os.path.dirname(__file__))
10 | sys.path.append(current_dir)
11 | 
12 | import torch
13 | import torch.utils.tensorboard
14 | from tqdm import tqdm
15 | from torch.nn.utils import clip_grad_norm_
16 | 
17 | from dataset.oas_unpair_dataset_new import OasHeavyMaskCollater
18 | from torch.utils.data import DataLoader
19 | from utils.train_utils import model_selected, optimizer_selected, scheduler_selected
20 | from utils.misc import seed_all, get_new_log_dir, get_logger, inf_iterator, count_parameters
21 | from utils.loss import MaskedAccuracy, OasMaskedHeavyCrossEntropyLoss
22 | from utils.train_utils import get_dataset
23 | 
24 | 
25 | def convert_multi_gpu_checkpoint_to_single_gpu(checkpoint):
26 |     if 'module' in list(checkpoint['model'].keys())[0]:
27 |         new_state_dict = {}
28 |         for key, value in checkpoint['model'].items():
29 |             new_key = key.replace('module.', '')  # Remove 'module.' prefix
30 |             new_state_dict[new_key] = value
31 |         checkpoint['model'] = new_state_dict
32 |     return checkpoint['model']
33 | 
34 | 
35 | def freeze_parameters(block):
36 |     for x in block:
37 |         x.requires_grad = False
38 | 
39 | def unfreeze_parameters(block):
40 |     for x in block:
41 |         x.requires_grad = True
42 | 
43 | def train(it):
44 | 
45 |     H_sum_loss, H_sum_nll = 0., 0.
46 |     H_sum_cdr_loss = 0.
47 |     sum_loss = 0
48 |     H_sum_acc_loss = 0.
49 |     sum_roc_auc = 0.
50 | 
51 |     model.train()
52 |     for _ in range(config.train.batch_acc):
53 |         optimizer.zero_grad()
54 | 
55 |         (H_src, H_tgt, H_region, chain_type,
56 |          H_masks, H_cdr_masks,
57 |          H_timesteps) = next(train_iterator)
58 |         H_src, H_tgt = H_src.to(device), H_tgt.to(device)
59 |         H_region = H_region.to(device)
60 |         chain_type = chain_type.to(device)
61 |         H_masks, H_cdr_masks = H_masks.to(device), H_cdr_masks.to(device)
62 |         H_timesteps = H_timesteps.to(device)
63 | 
64 |         H_pred = model(H_src, H_region, chain_type)
65 | 
66 |         H_loss, H_nll, H_cdr_loss = cross_loss(
67 |             H_pred,
68 |             H_tgt,
69 |             H_masks,
70 |             H_cdr_masks,
71 |             H_timesteps
72 |         )
73 |         if args.train_loss == 'fr':
74 |             loss = H_loss
75 |         elif args.train_loss == 'all':
76 |             loss = H_loss + H_cdr_loss
77 |         else:
78 |             raise ValueError("Unknown train loss type {!r}; expected 'fr' or 'all'.".format(args.train_loss))
79 | 
80 | 
81 |         loss = loss.mean()
82 |         loss.backward()
83 |         optimizer.step()
84 | 
85 |         sum_loss += loss
86 |         H_sum_loss += H_loss
87 |         H_sum_nll += H_nll
88 |         H_sum_cdr_loss += H_cdr_loss
89 | 
90 |         # These metrics measure how often masked predictions match the targets (1.0 is perfect).
91 |         # They are not used for backpropagation.
92 |         H_acc_loss, roc_auc = mask_acc_loss(H_pred, H_tgt, H_masks)
93 |         H_sum_acc_loss += H_acc_loss
94 |         sum_roc_auc += roc_auc
95 | 
96 | 
97 |     mean_loss = sum_loss / config.train.batch_acc
98 |     mean_H_loss = H_sum_loss / config.train.batch_acc
99 |     mean_H_nll = H_sum_nll / config.train.batch_acc
100 | 
101 |     mean_H_cdr_loss = H_sum_cdr_loss / config.train.batch_acc
102 | 
103 |     # Not used for backpropagation.
104 |     mean_H_acc_loss = H_sum_acc_loss / config.train.batch_acc
105 |     mean_roc_auc = sum_roc_auc / config.train.batch_acc
106 | 
107 |     logger.info('Training iter {}, Loss is: {:.6f} | H_loss: {:.6f} | H_nll: {:.6f} '
108 |                 '| H_cdr_loss: {:.6f} | H_acc: {:.6f} | ROC_AUC: {:.6f}'.
109 |                 format(it, mean_loss, mean_H_loss, mean_H_nll, mean_H_cdr_loss, mean_H_acc_loss, mean_roc_auc))
110 |     writer.add_scalar('train/loss', mean_loss, it)
111 |     writer.add_scalar('train/H_loss', mean_H_loss, it)
112 |     writer.add_scalar('train/H_nll', mean_H_nll, it)
113 |     writer.add_scalar('train/H_cdr_loss', mean_H_cdr_loss, it)
114 |     writer.add_scalar('train/H_acc', mean_H_acc_loss, it)
115 |     writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it)
116 |     writer.add_scalar('train/roc_auc', mean_roc_auc, it)
117 |     writer.flush()
118 | 
119 | 
120 | 
121 | def valid(it, valid_type):
122 |     H_sum_loss, H_sum_nll = 0., 0.
123 |     H_sum_cdr_loss = 0.
124 |     H_sum_acc_loss = 0.
125 |     sum_valid_loss = 0.
126 |     sum_roc_auc = 0.
127 |     model.eval()
128 | 
129 | 
130 |     val_sum = len(val_loader)
131 |     with torch.no_grad():
132 |         for batch in tqdm(val_loader, desc='Val', total=len(val_loader)):
133 |             (H_src, H_tgt, H_region, chain_type,
134 |              H_masks, H_cdr_masks,
135 |              H_timesteps) = batch
136 |             H_src, H_tgt = H_src.to(device), H_tgt.to(device)
137 |             H_region = H_region.to(device)
138 |             chain_type = chain_type.to(device)
139 |             H_masks, H_cdr_masks = H_masks.to(device), H_cdr_masks.to(device)
140 |             H_timesteps = H_timesteps.to(device)
141 | 
142 |             H_pred = model(H_src, H_region, chain_type)
143 | 
144 |             H_loss, H_nll, H_cdr_loss = cross_loss(
145 |                 H_pred,
146 |                 H_tgt,
147 |                 H_masks,
148 |                 H_cdr_masks,
149 |                 H_timesteps
150 |             )
151 | 
152 |             if args.train_loss == 'fr':
153 |                 loss = H_loss
154 |             elif args.train_loss == 'all':
155 |                 loss = H_loss + H_cdr_loss
156 |             else:
157 |                 raise ValueError("Unknown train loss type {!r}; expected 'fr' or 'all'.".format(args.train_loss))
158 | 
159 | 
160 |             sum_valid_loss += loss
161 | 
162 |             H_sum_loss += H_loss
163 |             H_sum_nll += H_nll
164 |             H_sum_cdr_loss += H_cdr_loss
165 | 
166 |             # These metrics measure how often masked predictions match the targets (1.0 is perfect).
167 |             H_acc_loss, roc_auc = mask_acc_loss(H_pred, H_tgt, H_masks)
168 | 
169 |             # Not used for backpropagation.
170 |             H_sum_acc_loss += H_acc_loss
171 |             sum_roc_auc += roc_auc
172 | 
173 |     mean_loss = sum_valid_loss / val_sum
174 |     mean_H_loss = H_sum_loss / val_sum
175 |     mean_H_nll = H_sum_nll / val_sum
176 |     mean_H_cdr_loss = H_sum_cdr_loss / val_sum
177 | 
178 |     # Not used for backpropagation.
179 |     mean_H_acc_loss = H_sum_acc_loss / val_sum
180 |     mean_roc_auc = sum_roc_auc / val_sum
181 | 
182 |     scheduler.step(mean_loss)
183 | 
184 |     logger.info('Validation iter {}, Loss is: {:.6f} | H_loss: {:.6f} | H_nll: {:.6f} '
185 |                 '| H_cdr_loss: {:.6f} | H_acc: {:.6f} | ROC_AUC: {:.6f}'.
186 |                 format(it, mean_loss, mean_H_loss, mean_H_nll,
187 |                        mean_H_cdr_loss, mean_H_acc_loss, mean_roc_auc))
188 |     writer.add_scalar('val/loss', mean_loss, it)
189 |     writer.add_scalar('val/H_loss', mean_H_loss, it)
190 |     writer.add_scalar('val/H_nll', mean_H_nll, it)
191 |     writer.add_scalar('val/H_cdr_loss', mean_H_cdr_loss, it)
192 |     writer.add_scalar('val/H_acc', mean_H_acc_loss, it)
193 |     writer.add_scalar('val/roc_auc', mean_roc_auc, it)
194 |     writer.flush()
195 | 
196 |     return mean_loss
197 | 
198 | 
199 | if __name__ == '__main__':
200 |     # Required args.
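    # Example invocation (mirroring nanobody_scripts/nanotrain_run.sh):
    #   python nanobody_scripts/nanotrain.py \
    #       --unpair_data_path oas_heavy_human_data/heavy_nano_idx.pt \
    #       --config_path configs/heavy_train.yml \
    #       --log_path tmp/heavy_train_log/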


if __name__ == '__main__':
    # Required args.
    parser = argparse.ArgumentParser()
    parser.add_argument('--unpair_data_path', type=str,
                        default=None)
    parser.add_argument('--data_name', type=str,
                        default='heavy')
    parser.add_argument('--train_model', type=str,
                        default='heavy')
    parser.add_argument('--data_version', type=str,
                        default='test')
    parser.add_argument('--train_loss', type=str,
                        default='all', choices=['fr', 'all'])
    parser.add_argument('--config_path', type=str,
                        default=None)
    parser.add_argument('--log_path', type=str,
                        default=None)
    parser.add_argument('--resume', type=eval,
                        default=False, choices=[True, False])
    parser.add_argument('--checkpoint', type=str,
                        default=None)

    args = parser.parse_args()

    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Config parameters.
    if not args.resume:
        with open(args.config_path, 'r') as f:
            config = EasyDict(yaml.safe_load(f))
    else:
        assert args.checkpoint is not None, "Need a specified checkpoint."
        ckpt_path = args.checkpoint
        ckpt = torch.load(ckpt_path, map_location='cpu')
        config = ckpt['config']

    version = f'{args.data_name}_{args.train_model}_{args.data_version}_{args.train_loss}'
    # Create the log dir.
    log_dir = get_new_log_dir(
        root=args.log_path,
        prefix=version
    )

    # Checkpoints dir.
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)

    # Logger and writer.
    logger = get_logger('train', log_dir)
    writer = torch.utils.tensorboard.SummaryWriter(log_dir)

    logger.info(args)
    logger.info(config)

    # Copy files for later checking (the config may be absent when resuming).
    if args.config_path is not None:
        shutil.copyfile(args.config_path, os.path.join(log_dir, os.path.basename(args.config_path)))
    shutil.copyfile('./nanobody_scripts/nanotrain.py', os.path.join(log_dir, 'nanotrain.py'))
    shutil.copytree('./model', os.path.join(log_dir, 'model'))

    # Fix the random seed.
    seed_all(config.train.seed)

    # Create the dataloaders.
    h_subsets = get_dataset(args.unpair_data_path, args.data_name, args.data_version)
    train_dataset, val_dataset = h_subsets['train'], h_subsets['val']
    collater = OasHeavyMaskCollater()

    # Only the heavy chain is considered.
    train_iterator = inf_iterator(DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        num_workers=config.train.num_workers,
        shuffle=True,
        collate_fn=collater
    ))
    logger.info(f'Training: {len(train_dataset)} Validation: {len(val_dataset)}')
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.train.batch_size,
        num_workers=config.train.num_workers,
        collate_fn=collater
    )
    logger.info('Dataloader has been created!')

    # Build the model.
    logger.info('Building model and initializing!')

    model = model_selected(config).to(device)
    if args.resume:
        ckpt_model = convert_multi_gpu_checkpoint_to_single_gpu(ckpt)
        model.load_state_dict(ckpt_model)

    # Build the optimizer and scheduler.
    optimizer = optimizer_selected(config.train.optimizer, model)
    scheduler = scheduler_selected(config.train.scheduler, optimizer)

    # Configure the losses.
    cross_loss = OasMaskedHeavyCrossEntropyLoss()
    mask_acc_loss = MaskedAccuracy()  # Not used for backprop; only tracks accuracy on the masked positions.
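    # Note: cross_loss returns a framework-region masked cross-entropy
    # (H_loss), an NLL diagnostic (H_nll), and a CDR-region masked
    # cross-entropy (H_cdr_loss); the --train_loss flag decides whether the
    # CDR term joins the optimized loss ('all') or only the framework term
    # is used ('fr').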

    if args.resume:
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        # Do not blindly reuse the checkpoint optimizer if some layers have been frozen since.
        it_sum = ckpt['iteration']
        logger.info('Resuming from iteration {}'.format(it_sum))

    logger.info(f'# trainable parameters: {count_parameters(model) / 1e6:.4f} M')
    logger.info('Training...')
    best_val_loss = torch.inf
    best_iter = 0
    for it in range(0, config.train.max_iter + 1):
        train(it)
        if it % config.train.valid_step == 0 or it == config.train.max_iter:
            valid_loss = valid(it, valid_type=args.data_name)
            if valid_loss < best_val_loss:
                best_val_loss, best_iter = valid_loss, it
                logger.info(f'Best validation loss achieved: {best_val_loss:.6f}')
                ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
                torch.save({
                    'config': config,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'iteration': it,
                }, ckpt_path)
            else:
                logger.info(f'[Validate] Val loss is not improved. '
                            f'Best val loss: {best_val_loss:.6f} at iter {best_iter}')

--------------------------------------------------------------------------------
/nanobody_scripts/nanotrain_run.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=2
echo "Using GPU: $CUDA_VISIBLE_DEVICES"
UNPAIR_DATA_PATH='oas_heavy_human_data/heavy_nano_idx.pt'
CONFIG_PATH='configs/heavy_train.yml'
LOG_PATH='tmp/heavy_train_log/'
python nanobody_scripts/nanotrain.py \
    --unpair_data_path $UNPAIR_DATA_PATH \
    --config_path $CONFIG_PATH \
    --log_path $LOG_PATH
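
# To resume training from a saved checkpoint instead of starting fresh, the
# same entry point accepts --resume and --checkpoint (the checkpoint path
# below is illustrative, not a file shipped with the repo):
# python nanobody_scripts/nanotrain.py \
#     --unpair_data_path $UNPAIR_DATA_PATH \
#     --log_path $LOG_PATH \
#     --resume True \
#     --checkpoint tmp/heavy_train_log/<run_name>/checkpoints/<iter>.pt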
--------------------------------------------------------------------------------
/nanobody_scripts/sample_for_nano_cdr.py:
--------------------------------------------------------------------------------
""" This script only considers nanobodies. """
import os.path

import numpy as np
import torch
from tqdm import tqdm
import argparse
import pandas as pd
from abnumber import Chain
from anarci import anarci, number
from copy import deepcopy
import re
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

from nanosample import (batch_input_element, save_nano, seqs_to_fasta,
                        compare_length, get_diff_region_aa_seq, get_pad_seq,
                        get_input_element, get_nano_line, out_humanization_df,
                        save_seq_to_fasta, split_fasta_for_save,
                        get_multi_model_state
                        )
from utils.tokenizer import Tokenizer
from utils.train_utils import model_selected
from utils.misc import get_new_log_dir, get_logger, seed_all

# Finetuning packages.
from model.nanoencoder.abnativ_model import AbNatiV_Model
from model.nanoencoder.model import NanoAntiTFNet


def get_nano_seq_from_fasta(fpath):
    """
    Extract the nanobody chain from the raw FASTA file.
    :param fpath: the raw FASTA file path.
    :return: the nanobody (VHH) sequence.
    """
    nano_chain = None
    sequences = SeqIO.parse(fpath, 'fasta')
    for seq in sequences:
        if 'Nanobody' in seq.description:
            nano_chain = str(seq.seq)
        else:
            continue
    assert nano_chain is not None, "Reading the fasta has a problem."
    return nano_chain


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="This program is designed to humanize non-human nanobodies.")
    parser.add_argument('--ckpt', type=str,
                        default=None,
                        help='Path of the pretrained checkpoint.'
                        )
    parser.add_argument('--nano_complex_fasta', type=str,
                        default=None,
                        help='FASTA file of the nanobody.'
                        )
    parser.add_argument('--batch_size', type=int,
                        default=10,
                        help='The batch size used for sampling.'
                        )
    parser.add_argument('--sample_number', type=int,
                        default=100,
                        help='The total number of samples.'
                        )
    parser.add_argument('--seed', type=int,
                        default=42
                        )
    parser.add_argument('--sample_order', type=str,
                        default='shuffle')
    parser.add_argument('--sample_method', type=str,
                        default='gen', choices=['gen', 'rl_gen'])
    parser.add_argument('--length_limit', type=str,
                        default='not_equal')
    parser.add_argument('--model', type=str,
                        default='finetune_vh', choices=['pretrain', 'finetune_vh'])
    parser.add_argument('--fa_version', type=str,
                        default='v_nano')
    parser.add_argument('--inpaint_sample', type=eval,
                        default=True)
    parser.add_argument('--structure', type=eval,
                        default=False)
    args = parser.parse_args()

    batch_size = args.batch_size
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # seed_all(args.seed)

    # Determine the name of the sample log.
    pdb_name = os.path.basename(args.nano_complex_fasta).split('.')[0]
    sample_tag = f'{pdb_name}_{args.model}_vhh'

    # Log dir.
    log_path = os.path.dirname(args.nano_complex_fasta)
    log_dir = get_new_log_dir(
        root=log_path,
        prefix=sample_tag
    )
    logger = get_logger('test', log_dir)

    # Here we use the finetuned model to generate the humanized sequences.
    ckpt = torch.load(args.ckpt)
    config = ckpt['config']
    abnativ_state, _, infilling_state = get_multi_model_state(ckpt)
    # AbNatiV model.
    hparams = ckpt['abnativ_params']
    abnativ_model = AbNatiV_Model(hparams)
    abnativ_model.load_state_dict(abnativ_state)
    abnativ_model.to(device)
    # Infilling model.
    infilling_params = ckpt['infilling_params']
    infilling_model = NanoAntiTFNet(**infilling_params)
    infilling_model.load_state_dict(infilling_state)
    infilling_model.to(device)

    # Careful: temporary config overrides used only for sampling.
    config.model['equal_weight'] = True
    config.model['vhh_nativeness'] = False
    config.model['human_threshold'] = None
    config.model['human_all_seq'] = False
    config.model['temperature'] = False

    model_dict = {
        'abnativ': abnativ_model,
        'infilling': infilling_model,
        'target_infilling': infilling_model
    }
    framework_model = model_selected(config, pretrained_model=model_dict, tokenizer=Tokenizer())
    model = framework_model.infilling_pretrain
    model.eval()

    logger.info(args.ckpt)
    logger.info(args.seed)

    # Read the FASTA file of the nanobody.
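    # get_nano_seq_from_fasta() picks the record whose description contains
    # the word 'Nanobody'. An illustrative input record (not a file shipped
    # with the repo):
    #   >7x2l_B Nanobody
    #   QVQLQESGGGLVQPGGSLRLSCAAS...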
    nano_chain = get_nano_seq_from_fasta(args.nano_complex_fasta)

    # Save path.
    save_fpath = os.path.join(log_dir, 'sample_humanization_result.csv')
    origin = 'Nano'
    with open(save_fpath, 'a', encoding='UTF-8') as f:
        f.write('Specific,name,hseq,\n')
        f.write(f'{origin},{pdb_name},{nano_chain}\n')

    wrong_idx_list = []
    length_not_equal_list = []
    sample_number = args.sample_number

    try:
        nano_pad_token, nano_pad_region, nano_loc, ms_tokenizer = batch_input_element(
            nano_chain,
            inpaint_sample=args.inpaint_sample,
            batch_size=batch_size
        )
    except Exception:
        logger.info('Encoding this nanobody may have a problem, please check!')
        raise

    if args.sample_order == 'shuffle':
        np.random.shuffle(nano_loc)

    duplicated_set = set()

    while sample_number > 0:
        all_token = ms_tokenizer.toks
        with torch.no_grad():
            # Order-agnostic infilling: visit the masked positions one by one,
            # sample a residue from the model's softmax at that position, and
            # write it back into the input before predicting the next position.
            for i in tqdm(nano_loc, total=len(nano_loc), desc='Nanobody Humanization Process'):
                nano_prediction = model(
                    nano_pad_token.to(device),
                    nano_pad_region.to(device),
                    H_chn_type=None
                )

                nano_pred = nano_prediction[:, i, :len(all_token) - 1]
                nano_soft = torch.nn.functional.softmax(nano_pred, dim=1)
                nano_sample = torch.multinomial(nano_soft, num_samples=1)
                nano_pad_token[:, i] = nano_sample.squeeze()

        nano_untokenized = [ms_tokenizer.idx2seq(s) for s in nano_pad_token]
        for _, g_h in enumerate(nano_untokenized):
            if sample_number == 0:
                break

            with open(save_fpath, 'a', encoding='UTF-8') as f:
                sample_origin = 'humanization'
                sample_name = str(pdb_name)
                # Keep a sequence only if it has not been seen before;
                # Chain() raises if it cannot be parsed as an antibody chain,
                # acting as a validity check.
                if g_h not in duplicated_set:
                    test_chain = Chain(g_h, scheme='imgt')
                    f.write(f'{sample_origin},{sample_name},{g_h}\n')
                    duplicated_set.add(g_h)
                    sample_number -= 1
                    logger.info('Sampled {} sequences so far'.format(args.sample_number - sample_number))
                    logger.info('Sampled Heavy Chain Seq: {}'.format(g_h))
                else:
                    sample_number -= 1

    # Save as FASTA for BioPhi OASis.
    fasta_save_fpath = os.path.join(log_dir, 'sample_identity.fa')
    logger.info('Save fasta fpath: {}'.format(fasta_save_fpath))
    sample_df = pd.read_csv(save_fpath)
    sample_human_df = sample_df[sample_df['Specific'] == 'humanization'].reset_index(drop=True)
    seqs_to_fasta(sample_human_df, fasta_save_fpath, version=args.fa_version)
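
    # The resulting sample_identity.fa is the input expected by the BioPhi
    # OASis humanness evaluation (see evaluation/Biophi_eval.py), while the
    # CSV written above keeps the parental sequence on its first data row for
    # side-by-side comparison.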

    # Split and save as FASTA files for structure prediction.
    if args.structure:
        split_fasta_for_save(save_fpath)

    logger.info('Length did not equal list: {}'.format(length_not_equal_list))
    logger.info('Wrong idx: {}'.format(wrong_idx_list))

--------------------------------------------------------------------------------
/start_docker.sh:
--------------------------------------------------------------------------------
#!/bin/bash
docker run -it --gpus all --rm --hostname=local-env --network=host \
    -v `pwd`:/opt/ml/env -w /opt/ml/env \
    DOCKER_PATH bash

--------------------------------------------------------------------------------
/utils/anti_numbering.py:
--------------------------------------------------------------------------------
import subprocess
import torch


def get_regions(aa_seq, env_name='antidiff'):
    '''
    :param aa_seq: the antibody sequence.
    :return: the per-residue region labels, the region lengths, and the chain type.
    '''
    cmd_str = f'ANARCI --sequence {aa_seq}'
    complete_cmd = f'eval "$(conda shell.bash hook)" && conda activate {env_name} && {cmd_str}'
    cmd_out = subprocess.check_output(complete_cmd, shell=True)
    line_strs = cmd_out.decode('utf-8').split('\n')
    assert line_strs[2] == '# Domain 1 of 1'

    sub_strs = line_strs[5].split('|')
    chn_type = sub_strs[2]
    if chn_type == 'K':
        chn_type = 'L'
    idx_resd_beg = int(sub_strs[5])  # inclusive
    idx_resd_end = int(sub_strs[6])  # inclusive

    idx_resd = idx_resd_beg
    labl_vec = torch.zeros(len(aa_seq), dtype=torch.int8)  # 0: framework
    # IMGT position -> region mapping (inclusive bounds):
    # FR1 < 27, CDR1 27-38, FR2 39-55, CDR2 56-65, FR3 66-104, CDR3 105-117, FR4 > 117.
    fv1, cdr1, fv2, cdr2, fv3, cdr3, fv4 = 0, 0, 0, 0, 0, 0, 0
    for line_str in line_strs:
        if not line_str.startswith(chn_type):
            continue
        if line_str.endswith('-'):
            continue
        idx_resd_imgt = int(line_str.split()[1])
        if idx_resd_imgt < 27:
            fv1 += 1
        elif 27 <= idx_resd_imgt <= 38:
            labl_vec[idx_resd] = 1  # CDR-1
            cdr1 += 1
        elif 38 < idx_resd_imgt < 56:
            fv2 += 1
        elif 56 <= idx_resd_imgt <= 65:
            labl_vec[idx_resd] = 2  # CDR-2
            cdr2 += 1
        elif 65 < idx_resd_imgt < 105:
            fv3 += 1
        elif 105 <= idx_resd_imgt <= 117:
            labl_vec[idx_resd] = 3  # CDR-3
            cdr3 += 1
        else:
            fv4 += 1
        idx_resd += 1
    sum_length = fv1 + fv2 + fv3 + fv4 + cdr1 + cdr2 + cdr3
    assert idx_resd == idx_resd_end + 1, f'{idx_resd} {idx_resd_beg} {idx_resd_end} {chn_type} {cmd_out}'
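    # ANARCI numbers residues only up to idx_resd_end; if the input sequence
    # extends past the numbered region, the unnumbered C-terminal residues are
    # folded into the FR4 count below so the region lengths sum to len(aa_seq).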
    if not len(aa_seq) == sum_length:
        assert len(aa_seq) > sum_length, 'AA seq is shorter than the ANARCI-numbered length.'
        fv4 += len(aa_seq) - sum_length

    true_chain_type = sub_strs[2]

    return labl_vec, [fv1, cdr1, fv2, cdr2, fv3, cdr3, fv4], true_chain_type


def get_seq_list_from_SeqRecords(seqrecord):
    seq_list = []
    for seq_info in seqrecord:
        seq = str(seq_info.seq)
        seq_list.append(seq)
    return seq_list


if __name__ == '__main__':
    seq = 'EVQLVESGGGLVQPGGSLRLSSAISGFSISSTSIDWVRQAPGKGLEWVARISPSSGSTSYADSVKGRFTISADTSKNTVYLQMNSLRAEDTAVYYTGRPLPEMGFFTQIPAMVDYRGQGTLVTVSS'
    labels = get_regions(seq)
    print(labels)

--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
import time
import torch
import numpy as np
import random
import os
import logging


def get_new_log_dir(root='./logs', prefix='', tag=''):
    """
    :param root: the dir path of the log.
    :param prefix: the prefix of the log dir name.
    :param tag: the tag appended to the log dir name.
    :return: the path of the log dir.
    """
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    log_dir = os.path.join(root, fn)
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def seed_all(seed):
    """Seed all random number generators."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_logger(name, log_dir=None, log_name=None):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_dir is not None:
        if log_name is not None:
            file_handler = logging.FileHandler(os.path.join(log_dir, log_name))
        else:
            file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


def inf_iterator(iterable):
    """Yield batches forever, restarting the iterator whenever it is exhausted."""
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

--------------------------------------------------------------------------------
/utils/tokenizer.py:
--------------------------------------------------------------------------------
"""The tokenizer for amino-acid sequences."""

import numpy as np
import torch
from torch import nn

# Constants.

# Mapping between 1-char & 3-char residue names.
RESD_MAP_1TO3 = {
    'A': 'ALA',
    'R': 'ARG',
    'N': 'ASN',
    'D': 'ASP',
    'C': 'CYS',
    'Q': 'GLN',
    'E': 'GLU',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'L': 'LEU',
    'K': 'LYS',
    'M': 'MET',
    'F': 'PHE',
    'P': 'PRO',
    'S': 'SER',
    'T': 'THR',
    'W': 'TRP',
    'Y': 'TYR',
    'V': 'VAL',
}
RESD_MAP_3TO1 = {v: k for k, v in RESD_MAP_1TO3.items()}
RESD_NAMES_1C = sorted(list(RESD_MAP_1TO3.keys()))
RESD_NAMES_3C = sorted(list(RESD_MAP_1TO3.values()))
RESD_NUM = len(RESD_NAMES_1C)  # := 20.
RESD_WITH_X = RESD_NAMES_1C + ['X']
RESD_ORDER_WITH_X = {restype: i for i, restype in enumerate(RESD_WITH_X)}
N_ATOMS_PER_RESD = 14  # TRP
N_ANGLS_PER_RESD = 7   # TRP (omega, phi, psi, chi1, chi2, chi3, and chi4)


class Tokenizer():
    """The tokenizer for amino-acid sequences."""

    def __init__(self, has_bos=False, has_eos=False):
        """Constructor function."""

        # Setup configurations.
        self.has_bos = has_bos
        self.has_eos = has_eos

        # Additional configurations. The special-token strings are assumed
        # reconstructions; any distinct non-residue strings work here.
        # self.tok_bos = '<bos>'
        self.tok_eos = '<eos>'
        self.tok_msk = '<mask>'
        # self.tok_unk = '<unk>'
        self.tok_pad = '-'
        self.toks = [*RESD_WITH_X, self.tok_pad, self.tok_msk]
        self.tok2idx_dict = {tok: idx for idx, tok in enumerate(self.toks)}
        self.idx_msk = self.tok2idx(self.tok_msk)
        self.idx_pad = self.tok2idx(self.tok_pad)

    @property
    def n_toks(self):
        """Get the number of tokens."""

        return len(self.toks)

    def tok2idx(self, tok):
        """Convert a single token into its index."""

        return self.tok2idx_dict[tok]

    def seq2idx(self, aa_seq):
        """Convert the amino-acid sequence into a 1-D vector of token indices."""

        aa_seq_ext = [*aa_seq]
        if self.has_bos:
            aa_seq_ext = [self.tok_bos] + aa_seq_ext
        if self.has_eos:
            aa_seq_ext = aa_seq_ext + [self.tok_eos]
        idx_vec = torch.tensor([self.tok2idx_dict[x] for x in aa_seq_ext])

        return idx_vec

    def seq2idx_batch(self, aa_seq_list):
        """Convert amino-acid sequences into token indices in the batch mode."""

        idx_vec_list = [self.seq2idx(x) for x in aa_seq_list]
        idx_mat = nn.utils.rnn.pad_sequence(
            idx_vec_list, batch_first=True, padding_value=self.idx_pad)

        return idx_mat

    def idx2seq(self, idx_vec):
        """Convert a 1-D vector of token indices into an amino-acid sequence, dropping padding."""

        aa_seq_ext = [self.toks[x] for x in idx_vec.tolist() if x != self.idx_pad]
        if self.has_bos:
            aa_seq_ext = aa_seq_ext[1:]  # skip the bos token
        if self.has_eos:
            aa_seq_ext = aa_seq_ext[:-1]  # skip the eos token
        aa_seq = ''.join(aa_seq_ext)

        return aa_seq

    def idx2seq_pad(self, idx_vec):
        """Convert a 1-D vector of token indices into an amino-acid sequence, keeping padding."""

        aa_seq_ext = [self.toks[x] for x in idx_vec.tolist()]
        if self.has_bos:
            aa_seq_ext = aa_seq_ext[1:]  # skip the bos token
        if self.has_eos:
            aa_seq_ext = aa_seq_ext[:-1]  # skip the eos token
        aa_seq = ''.join(aa_seq_ext)

        return aa_seq

    def idx2seq_pad_batch(self, idx_mat):
        """Convert token indices into amino-acid sequences in the batch mode, keeping padding."""

        n_seqs = idx_mat.shape[0]
        aa_seq_list = [self.idx2seq_pad(idx_mat[x]) for x in range(n_seqs)]

        return aa_seq_list

    def idx2seq_batch(self, idx_mat):
        """Convert token indices into amino-acid sequences in the batch mode."""

        n_seqs = idx_mat.shape[0]
        aa_seq_list = [self.idx2seq(idx_mat[x]) for x in range(n_seqs)]

        return aa_seq_list

    def chain_type_idx(self, chain):
        if chain == 'H':
            return 0
        elif chain == 'L':
            return 1
        elif chain == 'K':
            return 2
        else:
            raise TypeError('Unknown chain type: %s' % chain)
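

# A small sketch of how the mask token is typically used to build
# masked-infilling model inputs (an assumed workflow for illustration; the
# actual masking collaters live under dataset/):
def _sketch_mask_positions(tokenizer, aa_seq, positions):
    """Replace the tokens at `positions` with the mask token index."""
    idx_vec = tokenizer.seq2idx(aa_seq)
    for i in positions:
        idx_vec[i] = tokenizer.idx_msk
    return idx_vec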


def main():
    """Main entry."""

    # Test samples.
    aa_seq_list = [
        'EVQLVESGGGLVQPGGSLRLSSAISGFSISSTSIDWVRQAPGKGLEWVARISPSSGSTSYADSVKGRFTISADTSKNTVYLQMNSLRAEDTAVYYTGRPLPEMGFFTQIPAMVDYRGQGTLVTVSS',
        'QVQLQESGGGLVQPGGSLRLSCAASGFTFSSAIMTWVRQAPGKGREWVSTIGSDGSITTYADSVKGRFTISRDNARNTLYLQMNSLKPEDTAVYYCTSAGRRGPGTQVTVSS',
    ]

    # Initialization.
    tokenizer = Tokenizer()
    print(f'# of tokens: {tokenizer.n_toks}')

    # Test w/ seq2idx_batch().
    idx_mat = tokenizer.seq2idx_batch(aa_seq_list)
    print(f'idx_mat: {idx_mat.shape}')

    # Test w/ idx2seq_batch().
    aa_seq_list_out = tokenizer.idx2seq_batch(idx_mat)
    print(f'sequences: {aa_seq_list_out}')
    for aa_seq, aa_seq_out in zip(aa_seq_list, aa_seq_list_out):
        assert aa_seq == aa_seq_out, f'mismatched amino-acid sequences: {aa_seq} vs. {aa_seq_out}'


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
from model.encoder.model import ByteNetLMTime, AntiTFNet, AntiFrameWork
from model.nanoencoder.model import NanoAntiTFNet, NanoInfillingFramework
from dataset.oas_unpair_dataset_new import OasUnPairDataset
from dataset.oas_pair_dataset_new import OasPairDataset
from utils.warmup import GradualWarmupScheduler

from torch.utils.data import Dataset
from torch.utils.data import Subset
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, LambdaLR
from torch.optim.lr_scheduler import _LRScheduler

import torch


class WarmupPolyLR(_LRScheduler):
    """Linear warmup to max_lr, then polynomial decay back towards the base LR, floored at min_lr."""

    def __init__(self, optimizer, warmup_iters, max_iters, max_lr, min_lr, power=2, last_epoch=-1):
        self.warmup_iters = warmup_iters
        self.max_iters = max_iters
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.power = power
        self.last_decay_lr = [group['lr'] for group in optimizer.param_groups]
        super(WarmupPolyLR, self).__init__(optimizer, last_epoch, verbose=True)

    def get_lr(self):
        if self._step_count < self.warmup_iters:
            self.last_decay_lr = [base_lr + (self.max_lr - base_lr) * (self._step_count / self.warmup_iters) for base_lr in self.base_lrs]
        elif self._step_count < self.max_iters:
            decay_factor = (1 - (self._step_count - self.warmup_iters) / (self.max_iters - self.warmup_iters)) ** self.power
            self.last_decay_lr = [self.max_lr * decay_factor + (1 - decay_factor) * base_lr for base_lr in self.base_lrs]
            if min(self.last_decay_lr) <= self.min_lr:
                self.last_decay_lr = [self.min_lr for _ in self.base_lrs]
        return self.last_decay_lr
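

# A usage sketch for WarmupPolyLR (the hyper-parameters are illustrative):
# the LR climbs linearly from the optimizer's base LR to max_lr during the
# warmup phase, then decays polynomially back towards the base LR, never
# falling below min_lr.
def _demo_warmup_poly_lr():
    demo_model = torch.nn.Linear(8, 8)
    opt = Adam(demo_model.parameters(), lr=1e-7)
    sched = WarmupPolyLR(opt, warmup_iters=100, max_iters=1000,
                         max_lr=1e-3, min_lr=1e-6)
    for _ in range(1000):
        opt.step()
        sched.step()
    return opt.param_groups[0]['lr']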


def warmup(n_warmup_steps):
    def get_lr(step):
        return min((step + 1) / n_warmup_steps, 1.0)
    return get_lr


def model_selected(config, pretrained_model=None, tokenizer=None):
    if config.name == 'evo_oadm':
        return ByteNetLMTime(**config.model)
    elif config.name == 'trans_oadm':
        return AntiTFNet(**config.model)
    elif config.name == 'antibody_finetune':
        return AntiFrameWork(config.model, pretrained_model, tokenizer)
    elif config.name == 'nano':
        return NanoAntiTFNet(**config.model)
    elif config.name == 'infilling':
        return NanoInfillingFramework(config.model, pretrained_model, tokenizer)
    else:
        raise NotImplementedError('Unknown model: %s' % config.name)


def optimizer_selected(optimizer, model):
    if optimizer.type == 'Adam':
        return Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=optimizer.lr,
            weight_decay=optimizer.weight_decay
        )
    elif optimizer.type == 'AdamW':
        return AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=optimizer.lr,
            weight_decay=optimizer.weight_decay
        )
    else:
        raise NotImplementedError('Unknown optimizer: %s' % optimizer.type)


def scheduler_selected(scheduler, optimizer):
    if scheduler.type == 'plateau':
        return ReduceLROnPlateau(
            optimizer,
            factor=scheduler.factor,
            patience=scheduler.patience,
            min_lr=scheduler.min_lr
        )
    elif scheduler.type == 'cosine_annal':
        return CosineAnnealingLR(
            optimizer,
            T_max=scheduler.T_max,
        )
    elif scheduler.type == 'warm_up':
        return WarmupPolyLR(
            optimizer,
            warmup_iters=scheduler.warmup_steps,
            max_iters=scheduler.max_steps,
            max_lr=scheduler.max_lr,
            min_lr=scheduler.min_lr
        )
    else:
        raise NotImplementedError('Unknown scheduler: %s' % scheduler.type)


def split_data(path, dataset):
    split = torch.load(path)
    subsets = {k: Subset(dataset, indices=v) for k, v in split.items()}
    return subsets


def get_dataset(root, name, version, split=True):
    if name == 'pair':
        dataset = OasPairDataset(root, version=version)
        split_path = dataset.index_path
        if split:
            return split_data(split_path, dataset)
        else:
            return dataset

    elif name == 'unpair':
        h_dataset = OasUnPairDataset(data_dpath=root, chaintype='heavy')
        l_dataset = OasUnPairDataset(data_dpath=root, chaintype='light')
        h_split_path = h_dataset.index_path
        l_split_path = l_dataset.index_path
        if split:
            h_subsets = split_data(h_split_path, h_dataset)
            l_subsets = split_data(l_split_path, l_dataset)
            return h_subsets, l_subsets
        else:
            return h_dataset, l_dataset

    elif name == 'mouse':
        dataset = OasPairDataset(root, version=version, mouse=True)
        split_path = dataset.index_path
        if split:
            return split_data(split_path, dataset)
        else:
            return dataset

    elif name == 'heavy':
        h_dataset = OasUnPairDataset(data_dpath=root, chaintype='heavy')
        h_split_path = h_dataset.index_path
        if split:
            h_subsets = split_data(h_split_path, h_dataset)
            return h_subsets
        else:
            return h_dataset

    elif name == 'vhh':
        vhh_dataset = OasUnPairDataset(data_dpath=root, chaintype='vhh')
        vhh_split_path = vhh_dataset.index_path
        if split:
            vhh_subsets = split_data(vhh_split_path, vhh_dataset)
            return vhh_subsets
        else:
            return vhh_dataset

    else:
        raise NotImplementedError('Unknown dataset: %s' % name)
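

# A usage sketch mirroring nanotrain.py (the data path is hypothetical): the
# 'heavy' dataset name returns a dict of Subset objects keyed by split name,
# e.g. 'train' and 'val'.
def _demo_get_heavy_subsets(root='oas_heavy_human_data'):
    subsets = get_dataset(root, 'heavy', version='test')
    return subsets['train'], subsets['val']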

--------------------------------------------------------------------------------
/utils/warmup.py:
--------------------------------------------------------------------------------
"""
MIT License

Copyright (c) 2019 Ildoo Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import _LRScheduler


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, the lr starts from 0 and ends up at the base_lr.
        total_epoch: the target learning rate is reached at total_epoch, gradually.
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                    self.base_lrs]
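
    # Note: when `after_scheduler` is a ReduceLROnPlateau instance, step()
    # must be called as scheduler.step(metrics) so that
    # step_ReduceLROnPlateau below can forward the validation metric once the
    # warmup phase has finished.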
    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                         self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, metrics=None, epoch=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------