├── .gitignore
├── structural-evolution-overview.png
├── LICENSE
├── bin
│   ├── mpnn_ab_benchmarking.sh
│   ├── if_ab_benchmarking.sh
│   ├── generate_dms.py
│   ├── dms_utils.py
│   ├── parse_abysis.py
│   ├── eval_ablang.py
│   ├── recommend.py
│   ├── score_log_likelihoods.py
│   ├── esm1v_ab_benchmarking.sh
│   ├── multichain_util.py
│   ├── plot_mpnn_benchmarks.py
│   ├── util.py
│   ├── dms_enrichment.py
│   ├── plot_esm1v_benchmarks.py
│   └── plot_ab-binding_benchmarks.py
├── README.md
└── environment.yml
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | output 3 | *.pyc 4 | *.png 5 | *.log 6 | .tar
-------------------------------------------------------------------------------- /structural-evolution-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/varun-shanker/structural-evolution/HEAD/structural-evolution-overview.png
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 varun-shanker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /bin/mpnn_ab_benchmarking.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | run_protein_mpnn() { 4 | python ../ProteinMPNN/protein_mpnn_run.py \ 5 | --path_to_fasta "$1" \ 6 | --score_only 1 \ 7 | --save_score 1 \ 8 | --pdb_path_chains "$2" \ 9 | --out_folder "$3" \ 10 | --pdb_path "$4" 11 | } 12 | 13 | # G6 Light Chain 14 | run_protein_mpnn \ 15 | "data/ab_mutagenesis_expts/g6/g6_2fjg_lc_lib.fasta" \ 16 | "L" \ 17 | "output/ab_mutagenesis_expts/g6/proteinMpnnLC" \ 18 | "data/ab_mutagenesis_expts/g6/2fjg_vlh_fvar.pdb" 19 | 20 | # G6 Heavy Chain 21 | run_protein_mpnn \ 22 | "data/ab_mutagenesis_expts/g6/g6_2fjg_hc_lib.fasta" \ 23 | "H" \ 24 | "output/ab_mutagenesis_expts/g6/proteinMpnnHC" \ 25 | "data/ab_mutagenesis_expts/g6/2fjg_vlh_fvar.pdb" 26 | 27 | # CR6261 Heavy Chain 28 | run_protein_mpnn \ 29 | "data/ab_mutagenesis_expts/cr6261/cr6261_3gbn_hc_lib.fasta" \ 30 | "H" \ 31 | "output/ab_mutagenesis_expts/cr6261/mpnn" \ 32 | "data/ab_mutagenesis_expts/cr6261/3gbn_ablh_fvar.pdb" 33 | 34 | # CR9114 Heavy Chain 35 | run_protein_mpnn \ 36 | "data/ab_mutagenesis_expts/cr9114/cr9114_4fqi_hc_lib.fasta" \ 37 | "H" \ 38 | "output/ab_mutagenesis_expts/cr9114/mpnn" \ 39 | "data/ab_mutagenesis_expts/cr9114/4fqi_ablh_fvar.pdb" -------------------------------------------------------------------------------- /bin/if_ab_benchmarking.sh: -------------------------------------------------------------------------------- 1 | abs=("cr6261" "cr9114" "g6" "g6") 2 | ab_fastas=("cr6261_3gbn_hc_lib.fasta" "cr9114_4fqi_hc_lib.fasta" "g6_2fjg_hc_lib.fasta" "g6_2fjg_lc_lib.fasta") 3 | data_path="data/ab_mutagenesis_expts/" 4 | out_prefix="output/ab_mutagenesis_expts/" 5 | 6 | for ((i=0; i<${#abs[@]}; i++)); do 7 | ab="${abs[i]}" 8 | ab_fasta="${ab_fastas[i]}" 9 | ab_dir_path="${data_path}${ab}/" 10 | struc_list=("${ab_dir_path}"*.pdb) 11 | ab_out_dir="${out_prefix}${ab}/" 12 | 13 | # Set the default chain value 14 | chain="H" 15 | # Special handling for 'g6' antibody 16 | if [[ "$ab" == "g6" ]]; then 17 | [[ "$ab_fasta" == *"lc"* ]] && chain="L" 18 | fi 19 | 20 | # gather pdbs and filter the hc/lc only structure from being scored by library for the other chain 21 | if [[ "$ab" == "g6" && "$chain" == "L" ]]; then 22 | struc_list=($(echo "${struc_list[@]}" | tr ' ' '\n' | grep -v '_h_' | tr '\n' ' ')) 23 | elif [[ "$ab" == "g6" && "$chain" == "H" ]]; then 24 | struc_list=($(echo "${struc_list[@]}" | tr ' ' '\n' | grep -v '_l_' | tr '\n' ' ')) 25 | fi 26 | 27 | mkdir -p "$ab_out_dir" 28 | 29 | for struc in "${struc_list[@]}"; do 30 | out_file="${ab_out_dir}${struc##*/}" 31 | if [[ "$ab" == "g6" ]]; then 32 | chain_modeled="$([ "$chain" == "H" ] && echo "hc" || echo "lc")" 33 | out_file="${out_file%_fvar.pdb}_${chain_modeled}_scores.csv" 34 | else 35 | out_file="${out_file%_fvar.pdb}_scores.csv" 36 | fi 37 | 38 | if [[ ! -f "$out_file" ]]; then 39 | python bin/score_log_likelihoods.py "$struc" --chain "$chain" --seqpath "${ab_dir_path}${ab_fasta}" --outpath "$out_file" 40 | else 41 | echo "$out_file already exists. Skipping..." 
42 | fi 43 | done 44 | done
-------------------------------------------------------------------------------- /bin/generate_dms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dms_utils import deep_mutational_scan 3 | from pathlib import Path 4 | import numpy as np 5 | 6 | import esm 7 | from util import load_structure, extract_coords_from_structure 8 | import biotite.structure 9 | from collections import defaultdict 10 | 11 | 12 | def get_native_seq(pdbfile, chain): 13 | structure = load_structure(pdbfile, chain) 14 | _ , native_seq = extract_coords_from_structure(structure) 15 | return native_seq 16 | 17 | def write_dms_lib(args): 18 | '''Writes a deep mutational scanning library, including the native/wildtype (wt) of the 19 | indicated target chain in the structure, to an output Fasta file''' 20 | 21 | sequence = get_native_seq(args.pdbfile, args.chain) 22 | Path(args.dmspath).parent.mkdir(parents=True, exist_ok=True) 23 | with open(args.dmspath, 'w') as f: 24 | f.write('>wt\n') 25 | f.write(sequence+'\n') 26 | for pos, wt, mt in deep_mutational_scan(sequence): 27 | assert(sequence[pos] == wt) 28 | mut_seq = sequence[:pos] + mt + sequence[(pos + 1):] 29 | 30 | f.write('>' + str(wt) + str(pos+1) + str(mt) + '\n') 31 | f.write(mut_seq + '\n') 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser( 36 | description='Create a DMS library based on target chain in the structure.' 37 | ) 38 | parser.add_argument( 39 | 'pdbfile', type=str, 40 | help='input filepath, either .pdb or .cif', 41 | ) 42 | parser.add_argument( 43 | '--dmspath', type=str, 44 | help='output filepath for dms library', 45 | ) 46 | parser.add_argument( 47 | '--chain', type=str, 48 | help='chain id for the chain of interest', default='A', 49 | ) 50 | 51 | args = parser.parse_args() 52 | 53 | if args.dmspath is None: 54 | args.dmspath = f'predictions/{args.pdbfile[:-4]}-{args.chain}_dms.fasta' 55 | 56 | write_dms_lib(args) 57 | 58 | if __name__ == '__main__': 59 | 60 | main() 61 |
-------------------------------------------------------------------------------- /bin/dms_utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import datetime 3 | from dateutil.parser import parse as dparse 4 | import errno 5 | import numpy as np 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import warnings 11 | from Bio import pairwise2 12 | from Bio import BiopythonWarning 13 | warnings.simplefilter('ignore', BiopythonWarning) 14 | from Bio import Seq, SeqIO 15 | 16 | 17 | np.random.seed(1) 18 | random.seed(1) 19 | 20 | AAs = [ 21 | 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 22 | 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 23 | ] 24 | 25 | def tprint(string): 26 | string = str(string) 27 | sys.stdout.write(str(datetime.datetime.now()) + ' | ') 28 | sys.stdout.write(string + '\n') 29 | sys.stdout.flush() 30 | 31 | def mkdir_p(path): 32 | try: 33 | os.makedirs(path) 34 | except OSError as exc: # Python >2.5 35 | if exc.errno == errno.EEXIST and os.path.isdir(path): 36 | pass 37 | else: 38 | raise 39 | 40 | def deep_mutational_scan(sequence, exclude_noop=True): 41 | for pos, wt in enumerate(sequence): 42 | for mt in AAs: 43 | if exclude_noop and wt == mt: 44 | continue 45 | yield (pos, wt, mt) 46 | 47 | 48 | def make_mutations(seq, mutations): 49 | mut_seq = [ char for char in seq ] 50 | for mutation in mutations: 51 | wt, pos, mt = mutation[0], int(mutation[1:-1])
- 1, mutation[-1] 52 | assert(seq[pos] == wt) 53 | mut_seq[pos] = mt 54 | mut_seq = ''.join(mut_seq).replace('-', '') 55 | return mut_seq 56 | 57 | def find_mutations(seq1, seq2): 58 | alignment = pairwise2.align.globalms( 59 | seq1, seq2, 5, -4, -6, -.1, one_alignment_only=True, 60 | )[0] 61 | 62 | mutation_set = [] 63 | pos1, pos2, pos_map = 0, 0, {} 64 | for wt, mt in zip(alignment[0], alignment[1]): 65 | if wt != '-': 66 | pos1 += 1 67 | if mt != '-': 68 | pos2 += 1 69 | if wt != mt and wt != '-' and mt != '-': 70 | mut_str = f'{wt}{pos1}{mt}' 71 | mutation_set.append(mut_str) 72 | 73 | return mutation_set -------------------------------------------------------------------------------- /bin/parse_abysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import os 4 | 5 | AAs = set('ACDEFGHIKLMNPQRSTVWY') 6 | 7 | seqs_abs = { 8 | 'cr6261_vh': 'EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSS', 9 | 'cr9114_vh': 'QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSS', 10 | 'g6_vh': 'EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTV', 11 | 'g6_vl': 'DIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK', 12 | 'beb_vh': 'QITLKESGPTLVKPTQTLTLTCTFSGFSLSISGVGVGWLRQPPGKALEWLALIYWDDDKRYSPSLKSRLTISKDTSKNQVVLKMTNIDPVDTATYYCAHHSISTIFDHWGQGTLVTVSS', 13 | 'beb_vl' : 'QSALTQPASVSGSPGQSITISCTATSSDVGDYNYVSWYQQHPGKAPKLMIFEVSDRPSGISNRFSGSKSGNTASLTISGLQAEDEADYYCSSYTTSSAVFGGGTKLTVL', 14 | 'sa58_vh': 'QVQLAQSGSELRKPGASVKVSCDTSGHSFTSNAIHWVRQAPGQGLEWMGWINTDTGTPTYAQGFTGRFVFSLDTSARTAYLQISSLKADDTAVFYCARERDYSDYFFDYWGQGTLVTVSS', 15 | 'sa58_vl': 'EVVMTQSPASLSVSPGERATLSCRARASLGISTDLAWYQQRPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDSAVYYCQQYSNWPLTFGGGTKVEIK', 16 | } 17 | 18 | 19 | def parse_abysis(seq, seq_name): 20 | fname = f'data/abysis/abysis_counts_{seq_name}.json' 21 | with open(fname) as f: 22 | json_data = json.load(f) 23 | 24 | data = [] 25 | for idx, freq_table in enumerate(json_data['frequencies']): 26 | wt = seq[idx] 27 | total = freq_table['total'] 28 | for entry in freq_table['counts']: 29 | if entry['aa'] == seq[idx]: 30 | wt_frac = entry['c'] / total 31 | aa_to_freq, seen = {}, set() 32 | for entry in freq_table['counts']: 33 | mt = entry['aa'] 34 | seen.add(mt) 35 | counts = entry['c'] 36 | frac = entry['c'] / total 37 | ratio = frac / wt_frac 38 | data.append([ idx + 1, wt, mt, counts, frac, ratio ]) 39 | for mt in AAs - seen: 40 | data.append([ idx + 1, wt, mt, 0, 0., 0. 
]) 41 | 42 | df = pd.DataFrame(data, columns=[ 43 | 'pos', 44 | 'wt', 45 | 'mt', 46 | 'counts', 47 | 'fraction', 48 | 'likelihood_ratio', 49 | ]) 50 | 51 | if any(prefix in seq_name for prefix in ['cr6261', 'cr9114', 'g6']): 52 | subdirectory = seq_name.split('_')[0] 53 | else: 54 | subdirectory = '' 55 | 56 | output_dir = os.path.join('output', 'ab_mutagenesis_expts', subdirectory) 57 | os.makedirs(output_dir, exist_ok=True) 58 | 59 | output_path = os.path.join(output_dir, f'abysis_counts_{seq_name}.txt') 60 | df.to_csv(output_path, sep='\t') 61 | 62 | if __name__ == '__main__': 63 | for seq_name in seqs_abs: 64 | seq = seqs_abs[seq_name] 65 | parse_abysis(seq, seq_name) -------------------------------------------------------------------------------- /bin/eval_ablang.py: -------------------------------------------------------------------------------- 1 | import ablang 2 | import numpy as np 3 | import pandas as pd 4 | import scipy.special 5 | 6 | heavy_ablang = ablang.pretrained("heavy") 7 | heavy_ablang.freeze() 8 | 9 | light_ablang = ablang.pretrained("light") 10 | light_ablang.freeze() 11 | 12 | ab_dict = dict( 13 | cr6261= ('EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSS', 14 | 'QSVLTQPPSVSAAPGQKVTISCSGSSSNIGNDYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEANYYCATWDRRPTAYVVFGGGTKLTVL'), 15 | cr9114= ('QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSS', 16 | 'QSALTQPPAVSGTPGQRVTISCSGSDSNIGRRSVNWYQQFPGTAPKLLIYSNDQRPSVVPDRFSGSKSGTSASLAISGLQSEDEAEYYCAAWDDSLKGAVFGGGTQLTVL'), 17 | g6 = ('EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTV', 18 | 'DIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK') 19 | 20 | ) 21 | 22 | def eval_ablang(s, ab, chain): 23 | if chain == 'hc': 24 | fname = 'output/ab_mutagenesis_expts/'+ab+'/' + ab +'_hc_ablangScores.csv' 25 | log_likelihoods = heavy_ablang(s, mode = 'likelihood')[0][1:-1] 26 | alphabet = heavy_ablang.tokenizer.vocab_to_aa 27 | elif chain == 'lc': 28 | fname = 'output/ab_mutagenesis_expts/'+ab+'/' + ab +'_lc_ablangScores.csv' 29 | log_likelihoods = light_ablang(s, mode = 'likelihood')[0][1:-1] 30 | alphabet = light_ablang.tokenizer.vocab_to_aa 31 | 32 | assert (log_likelihoods.shape)[0] == len(s) 33 | 34 | filt_alphabet = {key: value for key, value in alphabet.items() if value.isalpha()} 35 | log_likelihood_ratio = [] 36 | for i,res_log_likelihoods in enumerate(log_likelihoods): 37 | wt_res = s[i] 38 | wt_index = list(filt_alphabet.values()).index(wt_res) 39 | wt_log_likelihood = res_log_likelihoods[wt_index] 40 | log_likelihood_ratio.extend(res_log_likelihoods-wt_log_likelihood) 41 | 42 | res_order = [alphabet[key] for key in range(1, 21)] #extract order of residues in likelihood 43 | mt = res_order * len(s) 44 | wt = [char for char in s for i in range(len(res_order))] 45 | pos = [i+1 for i in range(len(s)) for j in range(len(res_order))] 46 | data = { 47 | 'pos': pos, 48 | 'wt': wt, 49 | 'mt': mt, 50 | 'log_likelihood' : log_likelihoods.flatten(), 51 | 'log_likelihood_ratio' : log_likelihood_ratio, 52 | } 53 | df = pd.DataFrame(data) 54 | df.to_csv(fname, index = False) 55 | 56 | 57 | def main(): 58 | 59 | for ab in ab_dict: 60 | for s,chain in zip(ab_dict[ab], ('hc','lc')): 61 | eval_ablang(s, ab, chain) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 
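66 | 67 | # Usage sketch: running `python bin/eval_ablang.py` scores every position of each chain in 68 | # ab_dict above and writes per-residue AbLang log-likelihoods and wildtype-normalized 69 | # log-likelihood ratios to files such as output/ab_mutagenesis_expts/cr6261/cr6261_hc_ablangScores.csv. 70 | # Note that df.to_csv() assumes the output directories already exist; when running this script 71 | # standalone, create them first (e.g., os.makedirs('output/ab_mutagenesis_expts/cr6261', exist_ok=True)).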
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # structural-evolution 2 | 3 | This repository contains scripts for running the analysis described in the paper ["Unsupervised evolution of protein and antibody complexes with a structure-informed language model"](https://www.science.org/stoken/author-tokens/ST-1968/full). 4 | 5 | ![structural-evolution-overview](structural-evolution-overview.png) 6 | 7 |
8 | 9 | ## Setup/Installation 10 | 1. Clone this repository 11 | ``` 12 | git clone https://github.com/varun-shanker/structural-evolution.git 13 | ``` 14 | 2. Install and Activate Conda Environment with Required Dependencies 15 | ``` 16 | conda env create -f environment.yml 17 | conda activate struct-evo 18 | ``` 19 | 3. Download and unzip the model weights from [here](https://zenodo.org/records/12631662), then place them in the torch hub checkpoints directory. 20 | ``` 21 | wget -P ~/.cache/torch/hub/checkpoints https://zenodo.org/records/12631662/files/esm_if1_20220410.zip 22 | unzip ~/.cache/torch/hub/checkpoints/esm_if1_20220410.zip -d ~/.cache/torch/hub/checkpoints 23 | ``` 24 | 4. Navigate to the repository 25 | ``` 26 | cd structural-evolution 27 | ``` 28 | 29 | ## Generating Predictions 30 | 31 | To evaluate this model on a new protein or protein complex structure, run 32 | ```bash 33 | python bin/recommend.py [pdb/cif file] --chain [X] 34 | ``` 35 | where `[pdb/cif file]` is the file path to the pdb/cif structure file of the protein or protein complex and `[X]` is the target chain you wish to evolve. By default, the script outputs the top `n`=10 predicted substitutions at unique residue positions (`maxrep=1`); `n` and `maxrep` can be modified using the arguments below. 36 | 37 | To recommend mutations to antibody variable domain sequences, we simply run the above script separately on the heavy and light chains. 38 | 39 | Additional arguments: 40 | 41 | ``` 42 | --seqpath: filepath where the fasta with the dms library should be saved (defaults to a new subdirectory in the output directory) 43 | --outpath: output filepath for scores of variant sequences (defaults to a new subdirectory in the output directory) 44 | --chain: chain id for the chain of interest 45 | --n: number of top recommendations to be output (default: n=10) 46 | --maxrep: maximum representation of a single site in the output recommendations (default: maxrep = 1 yields a unique set of recommendations in which each wildtype residue position is mutated at most once) 47 | --upperbound: only residue positions less than the user-defined upperbound are considered for recommendation in the final output (but all positions are still conditioned on for scoring) 48 | --order: for multichain conditioning, provides the option to specify the order of chains 49 | --offset: integer offset or adjustment for the labeling of residue indices encoded in the structure file 50 | --multichain-backbone: use the backbones of all chains in the input for conditioning (default: True) 51 | --singlechain-backbone: use the backbone of only the target chain in the input for conditioning 52 | --nogpu: Do not use GPU even if available 53 | ``` 54 | 55 | For example, to generate mutations for the heavy chain of LY-CoV1404, we would simply run the following: 56 | 57 | ``` 58 | python bin/recommend.py examples/7mmo_abc_fvar.pdb \ 59 | --chain A --seqpath examples/7mmo_chainA_lib.fasta \ 60 | --outpath examples/7mmo_chainA_scores.csv \ 61 | --upperbound 109 --offset 1 62 | ``` 63 | In this example, we use a pdb structure file with the variable regions of both chains of the antibody in complex with the antigen, the SARS-CoV-2 receptor binding domain (RBD). To obtain recommendations specifically for the heavy chain, we specify chain A. The fasta file containing the library screened *in silico* and the corresponding output scores file are saved at the indicated paths.
64 | To limit the recommendations that are output and exclude mutations predicted in the final framework region, we set the upper bound to 109 (this value will vary for each antibody). Since the first residue is not included in the structure, we specify an offset of 1 to ensure the returned mutations are correctly indexed. 65 | 66 | ## Paper analysis scripts 67 | 68 | To reproduce the analysis in the paper, first download and extract data with the commands: 69 | ```bash 70 | wget https://zenodo.org/record/11260318/files/data.tar.gz 71 | tar xvf data.tar.gz 72 | ``` 73 | To evaluate alternate sequence-only and structure-based scoring methods, follow directions [here](https://github.com/facebookresearch/esm?tab=readme-ov-file#zs_variant) and [here](https://github.com/dauparas/ProteinMPNN) for installation instructions. 74 | 75 | ## Citation 76 | 77 | Please cite the following publication when referencing this work. 78 | 79 | ``` 80 | @article {Shanker-struct-evo, 81 | author = {Shanker, Varun and Bruun, Theodora and Hie, Brian and Kim, Peter}, 82 | title = {Unsupervised evolution of protein and antibody complexes with a structure-informed language model}, 83 | year = {2024}, 84 | doi = {10.1126/science.adk8946}, 85 | publisher = {American Association for the Advancement of Science}, 86 | URL = {https://www.science.org/doi/10.1126/science.adk8946}, 87 | journal = {Science} 88 | } 89 | ``` 90 | 91 | ## License 92 | This project is licensed under the MIT License - see the LICENSE file for details. 93 | -------------------------------------------------------------------------------- /bin/recommend.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from biotite.sequence.io.fasta import FastaFile, get_sequences 3 | import numpy as np 4 | from pathlib import Path 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | import os 9 | 10 | import esm 11 | from dms_utils import deep_mutational_scan 12 | import warnings 13 | import pandas as pd 14 | import util 15 | import multichain_util 16 | 17 | import score_log_likelihoods 18 | 19 | def get_native_seq(pdbfile, chain): 20 | structure = util.load_structure(pdbfile, chain) 21 | _ , native_seq = util.extract_coords_from_structure(structure) 22 | return native_seq 23 | 24 | def write_dms_lib(args): 25 | '''Writes a deep mutational scanning library, including the native/wildtype (wt) of the 26 | indicated target chain in the structure to an output Fasta file''' 27 | 28 | sequence = get_native_seq(args.pdbfile, args.chain) 29 | Path(args.seqpath).parent.mkdir(parents=True, exist_ok=True) 30 | with open(args.seqpath, 'w') as f: 31 | f.write('>wt\n') 32 | f.write(sequence+'\n') 33 | for pos, wt, mt in deep_mutational_scan(sequence): 34 | assert(sequence[pos] == wt) 35 | mut_seq = sequence[:pos] + mt + sequence[(pos + 1):] 36 | f.write('>' + str(wt) + str(pos+1+args.offset) + str(mt) + '\n') 37 | f.write(mut_seq + '\n') 38 | 39 | def get_top_n(args): 40 | recs, rec_inds = [], [] 41 | scores_df = pd.read_csv(args.outpath).sort_values(by = 'log_likelihood', ascending = False) 42 | 43 | for seqid in scores_df['seqid']: 44 | res_ind = seqid[1:-1] 45 | if (rec_inds.count(res_ind) < args.maxrep): 46 | if args.upperbound == None or (int(res_ind) < int(args.upperbound)): 47 | recs.append(seqid) 48 | rec_inds.append(res_ind) 49 | if len(recs) == args.n: 50 | break 51 | 52 | print(f'\n Chain {args.chain}') 53 | print(*recs, sep='\n') 54 | 55 | def get_model_checkpoint_path(filename): 
56 | # Expanding the user's home directory 57 | return os.path.expanduser(f"~/.cache/torch/hub/checkpoints/{filename}") 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser( 61 | description='Score variant sequences and recommend mutations based on a given structure.' 62 | ) 63 | parser.add_argument( 64 | 'pdbfile', type=str, 65 | help='input filepath, either .pdb or .cif', 66 | ) 67 | parser.add_argument( 68 | '--seqpath', type=str, 69 | help='filepath where fasta of dms library should be saved', 70 | ) 71 | parser.add_argument( 72 | '--outpath', type=str, 73 | help='output filepath for scores of variant sequences', 74 | ) 75 | parser.add_argument( 76 | '--chain', type=str, 77 | help='chain id for the chain of interest', default='A', 78 | ) 79 | parser.set_defaults(multichain_backbone=True) 80 | parser.add_argument( 81 | '--multichain-backbone', action='store_true', 82 | help='use the backbones of all chains in the input for conditioning' 83 | ) 84 | parser.add_argument( 85 | '--singlechain-backbone', dest='multichain_backbone', 86 | action='store_false', 87 | help='use the backbone of only target chain in the input for conditioning' 88 | ) 89 | parser.add_argument( 90 | '--order', type=str, default=None, 91 | help='for multichain, option to specify order of chains' 92 | ) 93 | parser.add_argument( 94 | '--n', type=int, 95 | help='number of desired predictions to be output', 96 | default=10, 97 | ) 98 | parser.add_argument( 99 | '--maxrep', type=int, 100 | help='maximum representation of a single site in the top recommendations \ 101 | (eg: maxrep = 1 is a unique set where no wildtype residue is mutated more than once)', 102 | default=1, 103 | ) 104 | parser.add_argument( 105 | '--offset', type=int, 106 | help='integer offset for labeling of residue indices encoded in the structure', 107 | default=0, 108 | ) 109 | parser.add_argument( 110 | '--upperbound', type=int, 111 | help='only residue positions less than the user-defined upperbound are considered to be recommended for screening \ 112 | (but all positions are still conditioned for scoring)', 113 | default=None, 114 | ) 115 | parser.add_argument( 116 | "--nogpu", action="store_true", 117 | help="Do not use GPU even if available" 118 | ) 119 | 120 | args = parser.parse_args() 121 | 122 | if args.seqpath is None: 123 | args.seqpath = f'output/{args.pdbfile[:-4]}-chain{args.chain}_dms.fasta' 124 | 125 | if args.outpath is None: 126 | args.outpath = f'output/{args.pdbfile[:-4]}-chain{args.chain}_scores.csv' 127 | 128 | # write dms library for target chain 129 | write_dms_lib(args) 130 | 131 | model_checkpoint_path = get_model_checkpoint_path('esm_if1_20220410.pt') 132 | with warnings.catch_warnings(): 133 | warnings.simplefilter('ignore', UserWarning) 134 | model, alphabet = esm.pretrained.load_model_and_alphabet( \ 135 | model_checkpoint_path \ 136 | ) 137 | model = model.eval() 138 | 139 | if args.multichain_backbone: 140 | score_log_likelihoods.score_multichain_backbone(model, alphabet, args) 141 | else: 142 | score_log_likelihoods.score_singlechain_backbone(model, alphabet, args) 143 | 144 | get_top_n(args) 145 | 146 | 147 | if __name__ == '__main__': 148 | main()
-------------------------------------------------------------------------------- /bin/score_log_likelihoods.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from biotite.sequence.io.fasta import FastaFile, get_sequences 3 | import numpy as np 4 | from pathlib import Path 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | import warnings 9 | import os 10 | from dms_utils import deep_mutational_scan 11 | import esm 12 | import pandas as pd 13 | from multichain_util import extract_coords_from_complex, _concatenate_coords, _concatenate_seqs, score_sequence_in_complex 14 | from util import get_sequence_loss, load_structure, load_coords, score_sequence, extract_coords_from_structure 15 | 16 | def get_native_seq(pdbfile, chain): 17 | structure = load_structure(pdbfile, chain) 18 | _ , native_seq = extract_coords_from_structure(structure) 19 | return native_seq 20 | 21 | def score_singlechain_backbone(model, alphabet, args): 22 | if torch.cuda.is_available() and not args.nogpu: 23 | model = model.cuda() 24 | print("Transferred model to GPU") 25 | 26 | coords, native_seq = load_coords(args.pdbfile, args.chain) 27 | print('Native sequence loaded from structure file:') 28 | print(native_seq) 29 | print('\n') 30 | ll, _ = score_sequence( 31 | model, alphabet, coords, native_seq) 32 | print('Native sequence') 33 | print(f'Log likelihood: {ll:.2f}') 34 | print(f'Perplexity: {np.exp(-ll):.2f}') 35 | print('\nScoring variant sequences from sequence file...\n') 36 | infile = FastaFile() 37 | infile.read(args.seqpath) 38 | seqs = get_sequences(infile) 39 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True) 40 | with open(args.outpath, 'w') as fout: 41 | fout.write('seqid,log_likelihood\n') 42 | for header, seq in tqdm(seqs.items()): 43 | ll, _ = score_sequence( 44 | model, alphabet, coords, str(seq)) 45 | fout.write(header + ',' + str(ll) + '\n') 46 | print(f'Results saved to {args.outpath}') 47 | 48 | 49 | def score_multichain_backbone(model, alphabet, args): 50 | if torch.cuda.is_available() and not args.nogpu: 51 | model = model.cuda() 52 | print("Transferred model to GPU") 53 | 54 | structure = load_structure(args.pdbfile) 55 | coords, native_seqs = extract_coords_from_complex(structure) 56 | target_chain_id = args.chain 57 | native_seq = native_seqs[target_chain_id] 58 | order = args.order 59 | 60 | print('Native sequence loaded from structure file:') 61 | print(native_seq) 62 | print('\n') 63 | 64 | ll_complex, ll_targetchain = score_sequence_in_complex( 65 | model, 66 | alphabet, 67 | coords, 68 | native_seqs, 69 | target_chain_id, 70 | native_seq, 71 | order=order, 72 | ) 73 | print('Native sequence') 74 | print(f'Log likelihood of complex: {ll_complex:.2f}') 75 | print(f'Log likelihood of target chain: {ll_targetchain:.2f}') 76 | print(f'Perplexity: {np.exp(-ll_complex):.2f}') 77 | 78 | print('\nScoring variant sequences from sequence file...\n') 79 | infile = FastaFile() 80 | infile.read(args.seqpath) 81 | seqs = get_sequences(infile) 82 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True) 83 | with open(args.outpath, 'w') as fout: 84 | fout.write('seqid,log_likelihood,log_likelihood_target\n') 85 | for header, seq in tqdm(seqs.items()): 86 | ll_complex, ll_targetchain = score_sequence_in_complex( 87 | model, 88 | alphabet, 89 | coords, 90 | native_seqs, 91 | target_chain_id, 92 | str(seq), 93 | order=order, 94 | ) 95 | fout.write(header + ',' + str(ll_complex) + ',' + str(ll_targetchain) + '\n') 96 | print(f'Results saved to {args.outpath}') 97 | 98 | def get_model_checkpoint_path(filename): 99 | # Expanding the user's home directory 100 | return os.path.expanduser(f"~/.cache/torch/hub/checkpoints/{filename}") 101 | 102 | def main(): 103 | parser = argparse.ArgumentParser( 104 | description='Score sequences based on a given structure.'
105 | ) 106 | parser.add_argument( 107 | 'pdbfile', type=str, 108 | help='input filepath, either .pdb or .cif', 109 | ) 110 | parser.add_argument( 111 | '--seqpath', type=str, 112 | help='input filepath for variant sequences in a .fasta file', 113 | ) 114 | parser.add_argument( 115 | '--outpath', type=str, 116 | help='output filepath for scores of variant sequences', 117 | ) 118 | parser.add_argument( 119 | '--chain', type=str, 120 | help='chain id for the chain of interest', default='A', 121 | ) 122 | parser.set_defaults(multichain_backbone=True) 123 | parser.add_argument( 124 | '--multichain-backbone', action='store_true', 125 | help='use the backbones of all chains in the input for conditioning' 126 | ) 127 | parser.add_argument( 128 | '--order', type=str, default=None, 129 | help='for multichain, specify order' 130 | ) 131 | parser.add_argument( 132 | '--singlechain-backbone', dest='multichain_backbone', 133 | action='store_false', 134 | help='use the backbone of only target chain in the input for conditioning' 135 | ) 136 | 137 | parser.add_argument( 138 | "--nogpu", action="store_true", 139 | help="Do not use GPU even if available" 140 | ) 141 | args = parser.parse_args() 142 | 143 | if args.outpath is None: 144 | args.outpath = f'output/{args.pdbfile[:-4]}-chain{args.chain}_scores.csv' 145 | 146 | model_checkpoint_path = get_model_checkpoint_path('esm_if1_20220410.pt') 147 | with warnings.catch_warnings(): 148 | warnings.simplefilter('ignore', UserWarning) 149 | model, alphabet = esm.pretrained.load_model_and_alphabet( \ 150 | model_checkpoint_path \ 151 | ) 152 | model = model.eval() 153 | 154 | 155 | if args.multichain_backbone: 156 | score_multichain_backbone(model, alphabet, args) 157 | else: 158 | score_singlechain_backbone(model, alphabet, args) 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 163 | -------------------------------------------------------------------------------- /bin/esm1v_ab_benchmarking.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | run_prediction() { 4 | python ../esm/examples/variant-prediction/predict.py \ 5 | --model-location esm1v_t33_650M_UR90S_1 esm1v_t33_650M_UR90S_2 esm1v_t33_650M_UR90S_3 esm1v_t33_650M_UR90S_4 esm1v_t33_650M_UR90S_5 \ 6 | --sequence "$1" \ 7 | --dms-input "$2" \ 8 | --mutation-col mutant \ 9 | --dms-output "$3" \ 10 | --scoring-strategy 'masked-marginals' \ 11 | --offset-idx "$4" 12 | } 13 | 14 | # CR6261 15 | run_prediction "EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSS" \ 16 | "data/ab_mutagenesis_expts/cr6261/cr6261_singleMuts_exp_data.csv" \ 17 | "output/ab_mutagenesis_expts/cr6261/cr6261_exp_data_maskMargLabeled.csv" \ 18 | 1 19 | 20 | # CR9114 21 | run_prediction "QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSS" \ 22 | "data/ab_mutagenesis_expts/cr9114/cr9114_singleMuts_exp_data.csv" \ 23 | "output/ab_mutagenesis_expts/cr9114/cr9114_exp_data_maskMargLabeled.csv" \ 24 | 1 25 | 26 | # G6 HC 27 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTV" \ 28 | "data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv" \ 29 | "output/ab_mutagenesis_expts/g6/g6Hc_exp_data_maskMargLabeled.csv" \ 30 | 1 31 | 32 | # G6 LC 33 | run_prediction "DIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK" 
\ 34 | "data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv" \ 35 | "output/ab_mutagenesis_expts/g6/g6Lc_exp_data_maskMargLabeled.csv" \ 36 | 1 37 | 38 | # G6 HC (both chains) 39 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK" \ 40 | "data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv" \ 41 | "output/ab_mutagenesis_expts/g6/g6Hc_esm1vbothchains_exp_data_maskMargLabeled.csv" \ 42 | 1 43 | 44 | # G6 LC (both chains) 45 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK" \ 46 | "data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv" \ 47 | "output/ab_mutagenesis_expts/g6/g6Lc_esm1vbothchains_exp_data_maskMargLabeled.csv" \ 48 | -117 49 | 50 | # CR9114 (both chains) 51 | run_prediction "QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSSQSALTQPPAVSGTPGQRVTISCSGSDSNIGRRSVNWYQQFPGTAPKLLIYSNDQRPSVVPDRFSGSKSGTSASLAISGLQSEDEAEYYCAAWDDSLKGAVFGGGTQLTVL" \ 52 | "data/ab_mutagenesis_expts/cr9114/cr9114_singleMuts_exp_data.csv" \ 53 | "output/ab_mutagenesis_expts/cr9114/cr9114_esm1vbothchains_exp_data_maskMargLabeled.csv" \ 54 | 1 55 | 56 | # CR6261 (both chains) 57 | run_prediction "EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSSQSVLTQPPSVSAAPGQKVTISCSGSSSNIGNDYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEANYYCATWDRRPTAYVVFGGGTKLTVL" \ 58 | "data/ab_mutagenesis_expts/cr6261/cr6261_singleMuts_exp_data.csv" \ 59 | "output/ab_mutagenesis_expts/cr6261/cr6261_esm1vbothchains_exp_data_maskMargLabeled.csv" \ 60 | 1 61 | 62 | # CR9114 (antibody + antigen) 63 | run_prediction "QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSSQSALTQPPAVSGTPGQRVTISCSGSDSNIGRRSVNWYQQFPGTAPKLLIYSNDQRPSVVPDRFSGSKSGTSASLAISGLQSEDEAEYYCAAWDDSLKGAVFGGGTQLTVLADPGDQICIGYHANNSTEQVDTIMEKNVTVTHAQDILEKKHNGKLCDLDGVKPLILRDCSVAGWLLGNPMCDEFINVPEWSYIVEKANPVNDLCYPGDFNDYEELKHLLSRINHFEKIQIIPKSSWSSHEASLGVSSACPYQGKSSFFRNVVWLIKKNSTYPTIKRSYNNTNQEDLLVLWGIHHPNDAAEQTKLYQNPTTYISVGTSTLNQRLVPRIATRSKVNGQSGRMEFFWTILKPNDAINFESNGNFIAPEYAYKIVKKGDSTIMKSELEYGNCNTKCQTPMGAINSSMPFHNIHPLTIGECPKYVKSNRLVLATGLRNSPQRERRRKKRGLFGAIAGFIEGGWQGMVDGWYGYHHSNEQGSGYAADKESTQKAIDGVTNKVNSIIDKMNTQFEAVGREFNNLERRIENLNKKMEDGFLDVWTYNAELLVLMENERTLDFHDSNVKNLYDKVRLQLRDNAKELGNGCFEFYHKCDNECMESVRNGTYDYPQYSEEARLKREEISSGR" \ 64 | "data/ab_mutagenesis_expts/cr9114/cr9114_singleMuts_exp_data.csv" \ 65 | "output/ab_mutagenesis_expts/cr9114/cr9114_esm1vAbAg_exp_data_maskMargLabeled.csv" \ 66 | 1 67 | 68 | # CR6261 (antibody + antigen) 69 | run_prediction 
"EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSSQSVLTQPPSVSAAPGQKVTISCSGSSSNIGNDYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEANYYCATWDRRPTAYVVFGGGTKLTVLADPGDTICIGYHANNSTDTVDTVLEKNVTVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNIAGWLLGNPECDLLLTASSWSYIVETSNSENGTCYPGDFIDYEELREQLSSVSSFEKFEIFPKTSSWPNHETTKGVTAACSYAGASSFYRNLLWLTKKGSSYPKLSKSYVNNKGKEVLVLWGVHHPPTGTDQQSLYQNADAYVSVGSSKYNRRFTPEIAARPKVRDQAGRMNYYWTLLEPGDTITFEATGNLIAPWYAFALNRGSGSGIITSDAPVHDCNTKCQTPHGAINSSLPFQNIHPVTIGECPKYVRSTKLRMATGLRNIPSIQSRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAIDGITNKVNSVIEKMNTQFTAVGKEFNNLERRIENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDSNVRNLYEKVKSQLKNNAKEIGNGCFEFYHKCDDACMESVRNGTYDYPKYSEESKLNREEIDGVSGR" \ 70 | "data/ab_mutagenesis_expts/cr6261/cr6261_singleMuts_exp_data.csv" \ 71 | "output/ab_mutagenesis_expts/cr6261/cr6261_esm1vAbAg_exp_data_maskMargLabeled.csv" \ 72 | 1 73 | 74 | # G6 HC (antibody + antigen) 75 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIKGQNHHEVVKFMDVYQRSYCHPIETLVDIFQEYPDEIEYIFKPSCVPLMRCGGCCNDEGLECVPTEESNITMQIMRIKPHQGQHIGEMSFLQHNKCECRPKKD" \ 76 | "data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv" \ 77 | "output/ab_mutagenesis_expts/g6/g6Hc_esm1vAbAg_exp_data_maskMargLabeled.csv" \ 78 | 1 79 | 80 | # G6 LC (antibody + antigen) 81 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIKGQNHHEVVKFMDVYQRSYCHPIETLVDIFQEYPDEIEYIFKPSCVPLMRCGGCCNDEGLECVPTEESNITMQIMRIKPHQGQHIGEMSFLQHNKCECRPKKD" \ 82 | "data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv" \ 83 | "output/ab_mutagenesis_expts/g6/g6Lc_esm1vAbAg_exp_data_maskMargLabeled.csv" \ 84 | -117 -------------------------------------------------------------------------------- /bin/multichain_util.py: -------------------------------------------------------------------------------- 1 | import biotite.structure 2 | import numpy as np 3 | import torch 4 | from typing import Sequence, Tuple, List 5 | from util import ( 6 | load_structure, 7 | extract_coords_from_structure, 8 | load_coords, 9 | get_sequence_loss, 10 | get_encoder_output, 11 | ) 12 | import util 13 | 14 | 15 | def extract_coords_from_complex(structure: biotite.structure.AtomArray): 16 | """ 17 | Args: 18 | structure: biotite AtomArray 19 | Returns: 20 | Tuple (coords_list, seq_list) 21 | - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C 22 | coordinates representing the backbone of each chain 23 | - seqs: Dictionary mapping chain ids to native sequences of each chain 24 | """ 25 | coords = {} 26 | seqs = {} 27 | all_chains = biotite.structure.get_chains(structure) 28 | for chain_id in all_chains: 29 | chain = structure[structure.chain_id == chain_id] 30 | coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain) 31 | return coords, seqs 32 | 33 | 34 | def load_complex_coords(fpath, chains): 35 | """ 36 | Args: 37 | fpath: filepath to either pdb or cif file 38 | chains: the chain ids (the order matters for autoregressive model) 39 | Returns: 40 | Tuple (coords_list, seq_list) 41 | - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C 42 | coordinates representing the backbone of each chain 43 | - seqs: Dictionary 
mapping chain ids to native sequences of each chain 44 | """ 45 | structure = load_structure(fpath, chains) 46 | return extract_coords_from_complex(structure) 47 | 48 | #* 49 | def _concatenate_coords( 50 | coords, 51 | target_chain_id, 52 | padding_length=10, 53 | order=None 54 | ): 55 | """ 56 | Args: 57 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C 58 | coordinates representing the backbone of each chain 59 | target_chain_id: The chain id to sample sequences for 60 | padding_length: Length of padding between concatenated chains 61 | Returns: 62 | Tuple (coords_concatenated, coords_chains) 63 | - coords_concatenated is an L x 3 x 3 array for N, CA, C coordinates, a 64 | concatenation of the chains with padding in between 65 | AND target chain placed first 66 | - coords_chains is a length-L array giving the chain id of each 67 | position, with 'pad' marking the padding positions 68 | """ 69 | pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32) 70 | if order is None: 71 | order = ( 72 | [ target_chain_id ] + 73 | [ chain_id for chain_id in coords if chain_id != target_chain_id ] 74 | ) 75 | coords_list, coords_chains = [], [] 76 | for idx, chain_id in enumerate(order): 77 | if idx > 0: 78 | coords_list.append(pad_coords) 79 | coords_chains.append([ 'pad' ] * padding_length) 80 | coords_list.append(list(coords[chain_id])) 81 | coords_chains.append([ chain_id ] * coords[chain_id].shape[0]) 82 | coords_concatenated = np.concatenate(coords_list, axis=0) 83 | coords_chains = np.concatenate(coords_chains, axis=0).ravel() 84 | return coords_concatenated, coords_chains 85 | 86 | #* 87 | def _concatenate_seqs( 88 | native_seqs, 89 | target_seq, 90 | target_chain_id, 91 | padding_length=10, 92 | order=None, 93 | ): 94 | """ 95 | Args: 96 | native_seqs: Dictionary mapping chain ids to corresponding AA sequence 97 | target_seq: Sequence to place at the target chain's position for scoring 98 | padding_length: Length of padding between concatenated chains 99 | Returns: 100 | native_seqs_concatenated: String of length L, concatenation of the chain 101 | sequences with padding tokens in between 102 | """ 103 | if order is None: 104 | order = ( 105 | [ target_chain_id ] + 106 | [ chain_id for chain_id in native_seqs if chain_id != target_chain_id ] 107 | ) 108 | native_seqs_list = [] 109 | for idx, chain_id in enumerate(order): 110 | if idx > 0: 111 | native_seqs_list.append(['<mask>'] * (padding_length - 1) + ['<cath>']) # padding tokens between chains, mirroring the coordinate padding 112 | if chain_id == target_chain_id: 113 | native_seqs_list.append(list(target_seq)) 114 | else: 115 | native_seqs_list.append(list(native_seqs[chain_id])) 116 | native_seqs_concatenated = ''.join(np.concatenate(native_seqs_list, axis=0)) 117 | return native_seqs_concatenated 118 | 119 | 120 | #* 121 | def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1., 122 | padding_length=10): 123 | """ 124 | Samples sequence for one chain in a complex. 125 | Args: 126 | model: An instance of the GVPTransformer model 127 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C 128 | coordinates representing the backbone of each chain 129 | target_chain_id: The chain id to sample sequences for 130 | padding_length: padding length in between chains 131 | Returns: 132 | Sampled sequence for the target chain 133 | """ 134 | target_chain_len = coords[target_chain_id].shape[0] 135 | all_coords, coords_chains = _concatenate_coords(coords, target_chain_id) 136 | device = next(model.parameters()).device 137 | 138 | # Supply padding tokens for other chains to avoid unused sampling for speed 139 | padding_pattern = ['<pad>'] * all_coords.shape[0] 140 | for i in range(target_chain_len): 141 | padding_pattern[i] = '<mask>' 142 | sampled = model.sample(all_coords, partial_seq=padding_pattern, 143 | temperature=temperature, device=device) 144 | sampled = sampled[:target_chain_len] 145 | return sampled 146 | 147 | 148 | #* 149 | def score_sequence_in_complex( 150 | model, 151 | alphabet, 152 | coords, 153 | native_seqs, 154 | target_chain_id, 155 | target_seq, 156 | padding_length=10, 157 | order=None, 158 | ): 159 | """ 160 | Scores sequence for one chain in a complex. 161 | Args: 162 | model: An instance of the GVPTransformer model 163 | alphabet: Alphabet for the model 164 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C 165 | coordinates representing the backbone of each chain 166 | native_seqs: Dictionary mapping chain ids to sequence 167 | extracted from each chain 168 | target_chain_id: The chain id to sample sequences for 169 | target_seq: Target sequence for the target chain for scoring. 170 | padding_length: padding length in between chains 171 | Returns: 172 | Tuple (ll_fullseq, ll_targetseq) 173 | - ll_fullseq: Average log-likelihood over all non-padding 174 | positions in the complex 175 | - ll_targetseq: Average log-likelihood over the target chain only 176 | """ 177 | 178 | assert(len(target_seq) == len(native_seqs[target_chain_id])) 179 | 180 | all_coords, coords_chains = _concatenate_coords( 181 | coords, 182 | target_chain_id, 183 | order=order, 184 | ) 185 | all_seqs = _concatenate_seqs( 186 | native_seqs, 187 | target_seq, 188 | target_chain_id, 189 | order=order, 190 | ) 191 | 192 | loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords, 193 | all_seqs) 194 | assert(all_coords.shape[0] == coords_chains.shape[0] == loss.shape[0]) 195 | 196 | ll_fullseq = -np.mean(loss[coords_chains != 'pad']) 197 | ll_targetseq = -np.mean(loss[coords_chains == target_chain_id]) 198 | 199 | return ll_fullseq, ll_targetseq 200 | 201 | 202 | def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id): 203 | """ 204 | Args: 205 | model: An instance of the GVPTransformer model 206 | alphabet: Alphabet for the model 207 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C 208 | coordinates representing the backbone of each chain 209 | target_chain_id: The chain id to sample sequences for 210 | Returns: 211 | Encoder output for the target chain 212 | """ 213 | all_coords, _ = _concatenate_coords(coords, target_chain_id) 214 | all_rep = get_encoder_output(model, alphabet, all_coords) 215 | target_chain_len = coords[target_chain_id].shape[0] 216 | return all_rep[:target_chain_len]
-------------------------------------------------------------------------------- /bin/plot_mpnn_benchmarks.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import bokeh.io 4 | import bokeh.plotting 5 | import bokeh.palettes 6 | import bokeh.models 7 | import datashader 8 | import holoviews as hv 9 | import holoviews.operation.datashader 10 | import os 11 | from natsort import natsorted 12 | 13 | hv.extension("bokeh") 14 | 15 | import warnings 16 | # Suppress FutureWarning messages 17 | warnings.simplefilter(action='ignore', category=FutureWarning) 18 | 19 | cr9114_dict = { 20 | 'ab_name' : 'CR9114', 21 | 'files' : { 22 | 'ESM-IF1': 'output/ab_mutagenesis_expts/cr9114/4fqi_ablh_scores.csv', 23 | 'ProteinMPNN': 'output/ab_mutagenesis_expts/cr9114/mpnn/score_only/', 24 | 25 | }, 26 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr9114/cr9114_exp_data.csv', 27 | dtype = {'genotype': str}, 28 | ).rename(columns={'h1_mean': 'H1', 'h3_mean' : 'H3'}), 29 | 'ag_columns': ['H1', 'H3'], 30 | 'expt_type': 'Combinatorial Mutagenesis for Affinity', 31 | 'palette': bokeh.palettes.Spectral6 32 | } 33 | 34 | cr6261_dict = { 35 | 'ab_name' : 'CR6261', 36 | 'files' : { 37 | 'ESM-IF1': 'output/ab_mutagenesis_expts/cr6261/3gbn_ablh_scores.csv', 38 | 'ProteinMPNN': 'output/ab_mutagenesis_expts/cr6261/mpnn/score_only/' 39 | }, 40 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr6261/cr6261_exp_data.csv', 41 | dtype={'genotype': str}, 42 | ).rename(columns={'h1_mean': 'H1', 'h9_mean' : 'H9'}), 43 | 'ag_columns': ['H1', 'H9'], 44 | 'expt_type': 'Combinatorial Mutagenesis for Affinity', 45 | 'palette': bokeh.palettes.Dark2_6, 46 | } 47 | 48 | 49 | g6LC_dict = { 50 | 'ab_name' : 'g6', 51 | 'files' : { 52 | 'ESM-IF1':'output/ab_mutagenesis_expts/g6/2fjg_vlh_lc_scores.csv', 53 | 'ProteinMPNN':'output/ab_mutagenesis_expts/g6/proteinMpnnLC/score_only/', 54 | 55 | }, 56 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv'), 57 | 'ag_columns': ['norm_binding'], 58 | 'expt_type': 'Deep Mutational Scan for Binding', 59 | 'palette': bokeh.palettes.Pastel1_6, 60 | 'chain': 'VL' 61 | } 62 | 63 | g6HC_dict = { 64 | 'ab_name' : 'g6', 65 | 'files' : { 66 | 'ESM-IF1': 'output/ab_mutagenesis_expts/g6/2fjg_vlh_hc_scores.csv', 67 | 'ProteinMPNN':'output/ab_mutagenesis_expts/g6/proteinMpnnHC/score_only/', 68 | }, 69 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv'), 70 | 'ag_columns': ['norm_binding'], 71 | 'expt_type': 'Deep Mutational Scan for Binding', 72 | 'palette': bokeh.palettes.Pastel1_4, 73 | 'chain': 'VH' 74 | } 75 | def combine_mpnn_scores(score_dir): 76 | global_scores = [] 77 | npz_files = [f for f in os.listdir(score_dir) if f.endswith('.npz')] 78 | npz_files = natsorted(npz_files)[:-1] 79 | 80 | for fname in npz_files: 81 | file_path = os.path.join(score_dir, fname) 82 | data = np.load(file_path) 83 | global_scores.append(-1* data['global_score'][0]) 84 | data.close() 85 | 86 | 87 | return global_scores 88 | 89 | 90 | # get melted correlation matrix of structure input x target antigen, for a given antibody 91 | def get_corr(ab_name, files, dms_df, ag_columns,): 92 | 93 | # retrieve scores from the ESM-IF1 and ProteinMPNN outputs 94 | for key, filepath in files.items(): 95 | column_label = key 96 | if key == 'ProteinMPNN': 97 | scores = combine_mpnn_scores(filepath) 98 | dms_df[column_label] = scores 99 | elif key == 'ESM-IF1': 100 | df = pd.read_csv(files['ESM-IF1']) # Read the CSV file into a DataFrame 101 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column 102 | dms_df[column_label] = log_likelihood_column 103 | 104 | 105 | conditions = list(files.keys()) 106 | 107 | correlations = dms_df[ag_columns + conditions].corr(method='spearman') 108 | correlations = correlations.drop(conditions, axis= 1) 109 | correlations = correlations.drop(ag_columns, axis= 0) 110 | 111 | # Melt the correlations DataFrame and rename the index column 112 | melted_correlations = ( 113 | correlations 114 | .reset_index() 115 | .melt(id_vars='index', var_name='Target', value_name='Correlation') 116 | .rename(columns={'index': 'Input'}) 117 | ) 118 | 119 | print(melted_correlations) 120 | return melted_correlations 121 | 122 | 123 | 124 | # get melted correlation matrix of structure input x target antigen, for the g6 antibody 125 | def get_g6_corr(g6Hc_dict, g6LC_dict, dropLOQ = False ): 126 | 127 | ab_name = 'g6' 128 | ag_columns = g6LC_dict['ag_columns'] 129 | vh_and_vl_dms_df = pd.DataFrame({}) 130 | 131 | for d in [g6Hc_dict, g6LC_dict]: 132 | 133 | files = d['files'] 134 | dms_df = d['dms_df'] 135 | 136 | # retrieve scores from the ESM-IF1 and ProteinMPNN outputs 137 | for key, filepath in files.items(): 138 | column_label = key 139 | if key == 'ProteinMPNN': 140 | scores = combine_mpnn_scores(filepath) 141 | dms_df[column_label] = scores 142 | elif key == 'ESM-IF1': 143 | df = pd.read_csv(files['ESM-IF1']) # Read the CSV file into a DataFrame 144 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column 145 | dms_df[column_label] = log_likelihood_column 146 | 147 | 148 | vh_and_vl_dms_df = pd.concat([vh_and_vl_dms_df, dms_df], ignore_index= True) 149 | 150 | conditions = list(files.keys()) 151 | 152 | correlations = vh_and_vl_dms_df[ag_columns + conditions].corr(method='spearman') 153 | correlations = correlations.drop(conditions, axis= 1) 154 | correlations = correlations.drop(ag_columns, axis= 0) 155 | 156 | # Melt the correlations DataFrame and rename the index column 157 | melted_correlations = ( 158 | correlations 159 | .reset_index() 160 | .melt(id_vars='index', var_name='Target', value_name='Correlation') 161 | .rename(columns={'index': 'Input'}) 162 | ) 163 | melted_correlations['Target'] = ['VEGF-A'] * len(melted_correlations) 164 | 165 | print(melted_correlations) 166 | return melted_correlations 167 | 168 | 169 | # plot correlation bars for a given antibody benchmark 170 | def plot_hbar(title, melted_correlations, palette = bokeh.palettes.Spectral6, ): 171 | 172 | melted_correlations['cats'] = melted_correlations.apply(lambda x: (x["Target"], x["Input"]), axis = 1) 173 | factors = list(melted_correlations.cats)[::-1] 174 | 175 | p = bokeh.plotting.figure( 176 | height=340, 177 | width=440, 178 | x_axis_label="Spearman Correlation", 179 | x_range=[0, 1], 180 | y_range=bokeh.models.FactorRange(*factors), 181 | tools="save", 182 | title = title 183 | ) 184 | 185 | 186 | p.hbar( 187 | source=melted_correlations, 188 | y="cats", 189 | right="Correlation", 190 | height=0.6, 191 | line_color = 'black', 192 | fill_color=bokeh.palettes.Dark2_6[0] 193 | 194 | ) 195 | 196 | labels_df = melted_correlations 197 | labels_df['corr_str'] = labels_df['Correlation'].apply(lambda x: round(x, 2)).astype(str) 198 | labels_source = bokeh.models.ColumnDataSource(labels_df) 199 | 200 | labels = bokeh.models.LabelSet(x='Correlation', y='cats', text='corr_str',text_font_size = "10px", 201 | x_offset=12, y_offset=-5, source=labels_source, render_mode='canvas') 202 | 203 | p.ygrid.grid_line_color = None 204 | p.y_range.range_padding = 0.1 205 |
p.add_layout(labels) 206 | p.legend.visible = False 207 | 208 | p.output_backend = "svg" 209 | return p 210 | 211 | if __name__ == '__main__': 212 | datasets = [cr9114_dict, cr6261_dict, (g6HC_dict, g6LC_dict) ] 213 | 214 | all_corr_plots = [] 215 | 216 | for d in datasets: 217 | if type(d) is tuple: 218 | g6Hc, g6Lc = d 219 | title = g6Hc['ab_name'] + ', ' + g6Hc['expt_type'] 220 | g6_combined = get_g6_corr(g6Hc, g6Lc) 221 | 222 | all_corr_plots.append(plot_hbar(title, g6_combined, g6Lc['palette'])) 223 | 224 | else: 225 | corr_df = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns']) 226 | title = d['ab_name'] + ', ' + d['expt_type'] 227 | all_corr_plots.append(plot_hbar(title, corr_df, d['palette'])) 228 | 229 | 230 | all_corr_fname = f"output/ab_mutagenesis_expts/mpnn_benchmarks.html" 231 | bokeh.plotting.output_file(all_corr_fname) 232 | bokeh.io.show(bokeh.layouts.gridplot(all_corr_plots, ncols = len(all_corr_plots))) 233 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: struct-evo 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - abseil-cpp=20211102.0=hd4dd3e8_0 11 | - anyio=4.2.0=py39h06a4308_0 12 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 13 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0 14 | - arrow-cpp=14.0.2=h374c478_1 15 | - asttokens=2.0.5=pyhd3eb1b0_0 16 | - attrs=23.1.0=py39h06a4308_0 17 | - aws-c-auth=0.6.19=h5eee18b_0 18 | - aws-c-cal=0.5.20=hdbd6064_0 19 | - aws-c-common=0.8.5=h5eee18b_0 20 | - aws-c-compression=0.2.16=h5eee18b_0 21 | - aws-c-event-stream=0.2.15=h6a678d5_0 22 | - aws-c-http=0.6.25=h5eee18b_0 23 | - aws-c-io=0.13.10=h5eee18b_0 24 | - aws-c-mqtt=0.7.13=h5eee18b_0 25 | - aws-c-s3=0.1.51=hdbd6064_0 26 | - aws-c-sdkutils=0.1.6=h5eee18b_0 27 | - aws-checksums=0.1.13=h5eee18b_0 28 | - aws-crt-cpp=0.18.16=h6a678d5_0 29 | - aws-sdk-cpp=1.10.55=h721c034_0 30 | - backcall=0.2.0=pyhd3eb1b0_0 31 | - beautifulsoup4=4.12.3=py39h06a4308_0 32 | - blas=1.0=mkl 33 | - bleach=4.1.0=pyhd3eb1b0_0 34 | - bokeh=2.4.3=py39h06a4308_0 35 | - boost-cpp=1.82.0=hdb19cb5_2 36 | - bottleneck=1.3.7=py39ha9d4c09_0 37 | - brotli=1.0.9=h5eee18b_8 38 | - brotli-bin=1.0.9=h5eee18b_8 39 | - brotli-python=1.0.9=py39h5a03fae_7 40 | - bzip2=1.0.8=h5eee18b_6 41 | - c-ares=1.19.1=h5eee18b_0 42 | - ca-certificates=2024.3.11=h06a4308_0 43 | - certifi=2024.6.2=py39h06a4308_0 44 | - cffi=1.16.0=py39h5eee18b_1 45 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 46 | - click=8.1.7=py39h06a4308_0 47 | - cloudpickle=2.2.1=py39h06a4308_0 48 | - colorama=0.4.6=pyhd8ed1ab_0 49 | - colorcet=3.0.0=py39h06a4308_0 50 | - comm=0.2.1=py39h06a4308_0 51 | - contourpy=1.2.0=py39hdb19cb5_0 52 | - cudatoolkit=11.3.1=h2bc3f7f_2 53 | - cycler=0.11.0=pyhd3eb1b0_0 54 | - cyrus-sasl=2.1.28=h52b45da_1 55 | - cytoolz=0.12.2=py39h5eee18b_0 56 | - dask=2023.11.0=py39h06a4308_0 57 | - dask-core=2023.11.0=py39h06a4308_0 58 | - datashader=0.14.1=py39h06a4308_1 59 | - datashape=0.5.4=py39h06a4308_1 60 | - dbus=1.13.18=hb2f20db_0 61 | - debugpy=1.6.7=py39h6a678d5_0 62 | - decorator=5.1.1=pyhd3eb1b0_0 63 | - defusedxml=0.7.1=pyhd3eb1b0_0 64 | - distributed=2023.11.0=py39h06a4308_0 65 | - entrypoints=0.4=py39h06a4308_0 66 | - et_xmlfile=1.1.0=py39h06a4308_0 67 | - exceptiongroup=1.2.0=py39h06a4308_0 68 | - executing=0.8.3=pyhd3eb1b0_0 69 | - expat=2.6.2=h6a678d5_0 70 | - 
fontconfig=2.14.1=h4c34cd2_2 71 | - fonttools=4.51.0=py39h5eee18b_0 72 | - freetype=2.12.1=h4a9f257_0 73 | - fsspec=2024.3.1=py39h06a4308_0 74 | - gflags=2.2.2=h6a678d5_1 75 | - glib=2.78.4=h6a678d5_0 76 | - glib-tools=2.78.4=h6a678d5_0 77 | - glog=0.5.0=h6a678d5_1 78 | - grpc-cpp=1.48.2=he1ff14a_1 79 | - gst-plugins-base=1.14.1=h6a678d5_1 80 | - gstreamer=1.14.1=h5eee18b_1 81 | - heapdict=1.0.1=pyhd3eb1b0_0 82 | - holoviews=1.15.0=py39h06a4308_1 83 | - hvplot=0.8.0=py39h06a4308_0 84 | - icu=73.1=h6a678d5_0 85 | - idna=3.7=pyhd8ed1ab_0 86 | - importlib-metadata=7.0.1=py39h06a4308_0 87 | - importlib_resources=6.1.1=py39h06a4308_1 88 | - intel-openmp=2023.1.0=hdb19cb5_46306 89 | - ipykernel=6.28.0=py39h06a4308_0 90 | - ipython=8.15.0=py39h06a4308_0 91 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 92 | - jedi=0.18.1=py39h06a4308_1 93 | - jinja2=2.11.3=pyhd8ed1ab_2 94 | - joblib=1.4.2=pyhd8ed1ab_0 95 | - jpeg=9e=h5eee18b_1 96 | - jsonschema=4.19.2=py39h06a4308_0 97 | - jsonschema-specifications=2023.7.1=py39h06a4308_0 98 | - jupyter_client=7.4.9=py39h06a4308_0 99 | - jupyter_core=5.7.2=py39h06a4308_0 100 | - jupyter_events=0.10.0=py39h06a4308_0 101 | - jupyter_server=2.10.0=py39h06a4308_0 102 | - jupyter_server_terminals=0.4.4=py39h06a4308_1 103 | - jupyterlab_pygments=0.2.2=py39h06a4308_0 104 | - kiwisolver=1.4.4=py39h6a678d5_0 105 | - krb5=1.20.1=h143b758_1 106 | - lcms2=2.12=h3be6417_0 107 | - ld_impl_linux-64=2.38=h1181459_1 108 | - lerc=3.0=h295c915_0 109 | - libboost=1.82.0=h109eef0_2 110 | - libbrotlicommon=1.0.9=h5eee18b_8 111 | - libbrotlidec=1.0.9=h5eee18b_8 112 | - libbrotlienc=1.0.9=h5eee18b_8 113 | - libclang=14.0.6=default_hc6dbbc7_1 114 | - libclang13=14.0.6=default_he11475f_1 115 | - libcups=2.4.2=h2d74bed_1 116 | - libcurl=8.7.1=h251f7ec_0 117 | - libdeflate=1.17=h5eee18b_1 118 | - libedit=3.1.20230828=h5eee18b_0 119 | - libev=4.33=h7f8727e_1 120 | - libevent=2.1.12=hdbd6064_1 121 | - libffi=3.4.4=h6a678d5_1 122 | - libgcc-ng=11.2.0=h1234567_1 123 | - libgfortran-ng=11.2.0=h00389a5_1 124 | - libgfortran5=11.2.0=h1234567_1 125 | - libglib=2.78.4=hdc74915_0 126 | - libgomp=11.2.0=h1234567_1 127 | - libiconv=1.16=h5eee18b_3 128 | - libllvm14=14.0.6=hdb19cb5_3 129 | - libnghttp2=1.57.0=h2d74bed_0 130 | - libpng=1.6.39=h5eee18b_0 131 | - libpq=12.17=hdbd6064_0 132 | - libprotobuf=3.20.3=he621ea3_0 133 | - libsodium=1.0.18=h7b6447c_0 134 | - libssh2=1.11.0=h251f7ec_0 135 | - libstdcxx-ng=11.2.0=h1234567_1 136 | - libthrift=0.15.0=h1795dd8_2 137 | - libtiff=4.5.1=h6a678d5_0 138 | - libuuid=1.41.5=h5eee18b_0 139 | - libuv=1.44.2=h5eee18b_0 140 | - libwebp-base=1.3.2=h5eee18b_0 141 | - libxcb=1.15=h7f8727e_0 142 | - libxkbcommon=1.0.1=h5eee18b_1 143 | - libxml2=2.10.4=hfdd30dd_2 144 | - llvmlite=0.42.0=py39h6a678d5_0 145 | - locket=1.0.0=py39h06a4308_0 146 | - lz4=4.3.2=py39h5eee18b_0 147 | - lz4-c=1.9.4=h6a678d5_1 148 | - markdown=3.4.1=py39h06a4308_0 149 | - markupsafe=1.1.1=py39h3811e60_3 150 | - matplotlib=3.8.0=py39h06a4308_0 151 | - matplotlib-base=3.8.0=py39h1128e8f_0 152 | - matplotlib-inline=0.1.6=py39h06a4308_0 153 | - mistune=0.8.4=py39h27cfd23_1000 154 | - mkl=2023.1.0=h213fc3f_46344 155 | - mkl-service=2.4.0=py39h5eee18b_1 156 | - mkl_fft=1.3.8=py39h5eee18b_0 157 | - mkl_random=1.2.4=py39hdb19cb5_0 158 | - msgpack-python=1.0.3=py39hd09550d_0 159 | - multipledispatch=0.6.0=py39h06a4308_0 160 | - mysql=5.7.24=h721c034_2 161 | - nbclassic=1.1.0=py39h06a4308_0 162 | - nbclient=0.5.13=py39h06a4308_0 163 | - nbconvert=6.4.4=py39h06a4308_0 164 | - nbformat=5.9.2=py39h06a4308_0 165 | - 
ncurses=6.4=h6a678d5_0 166 | - nest-asyncio=1.6.0=py39h06a4308_0 167 | - notebook=6.5.7=py39h06a4308_0 168 | - notebook-shim=0.2.3=py39h06a4308_0 169 | - numba=0.59.1=py39h6a678d5_0 170 | - numexpr=2.8.7=py39h85018f9_0 171 | - numpy=1.22.3=py39hf6e8229_2 172 | - numpy-base=1.22.3=py39h060ed82_2 173 | - openjpeg=2.4.0=h9ca470c_1 174 | - openpyxl=3.0.10=py39h5eee18b_0 175 | - openssl=3.0.14=h5eee18b_0 176 | - orc=1.7.4=hb3bc3d3_1 177 | - overrides=7.4.0=py39h06a4308_0 178 | - packaging=23.2=py39h06a4308_0 179 | - pandocfilters=1.5.0=pyhd3eb1b0_0 180 | - panel=0.13.1=py39h06a4308_0 181 | - param=1.12.0=pyhd3eb1b0_0 182 | - parso=0.8.3=pyhd3eb1b0_0 183 | - partd=1.4.1=py39h06a4308_0 184 | - pcre2=10.42=hebb0a14_1 185 | - pexpect=4.8.0=pyhd3eb1b0_3 186 | - pickleshare=0.7.5=pyhd3eb1b0_1003 187 | - pillow=10.3.0=py39h5eee18b_0 188 | - pip=24.0=py39h06a4308_0 189 | - platformdirs=3.10.0=py39h06a4308_0 190 | - ply=3.11=py39h06a4308_0 191 | - prometheus_client=0.14.1=py39h06a4308_0 192 | - prompt-toolkit=3.0.43=py39h06a4308_0 193 | - psutil=5.9.0=py39h5eee18b_0 194 | - ptyprocess=0.7.0=pyhd3eb1b0_2 195 | - pure_eval=0.2.2=pyhd3eb1b0_0 196 | - pyarrow=14.0.2=py39h1eedbd7_0 197 | - pybind11-abi=4=hd8ed1ab_3 198 | - pycparser=2.21=pyhd3eb1b0_0 199 | - pyct=0.5.0=py39h06a4308_0 200 | - pyg=2.1.0=py39_torch_1.11.0_cu113 201 | - pygments=2.15.1=py39h06a4308_1 202 | - pyparsing=3.0.9=py39h06a4308_0 203 | - pyqt=5.15.10=py39h6a678d5_0 204 | - pyqt5-sip=12.13.0=py39h5eee18b_0 205 | - pysocks=1.7.1=pyha2e5f31_6 206 | - python=3.9.19=h955ad1f_1 207 | - python-dateutil=2.9.0post0=py39h06a4308_2 208 | - python-fastjsonschema=2.16.2=py39h06a4308_0 209 | - python-json-logger=2.0.7=py39h06a4308_0 210 | - python-lmdb=1.4.1=py39h6a678d5_0 211 | - python_abi=3.9=2_cp39 212 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 213 | - pytorch-cluster=1.6.0=py39_torch_1.11.0_cu113 214 | - pytorch-mutex=1.0=cuda 215 | - pytorch-scatter=2.0.9=py39_torch_1.11.0_cu113 216 | - pytorch-sparse=0.6.15=py39_torch_1.11.0_cu113 217 | - pytz=2024.1=py39h06a4308_0 218 | - pyviz_comms=3.0.2=py39h06a4308_0 219 | - pyyaml=6.0.1=py39h5eee18b_0 220 | - pyzmq=24.0.1=py39h5eee18b_0 221 | - qt-main=5.15.2=h53bd1ea_10 222 | - re2=2022.04.01=h295c915_0 223 | - readline=8.2=h5eee18b_0 224 | - referencing=0.30.2=py39h06a4308_0 225 | - requests=2.32.3=pyhd8ed1ab_0 226 | - rfc3339-validator=0.1.4=py39h06a4308_0 227 | - rfc3986-validator=0.1.1=py39h06a4308_0 228 | - rpds-py=0.10.6=py39hb02cf49_0 229 | - s2n=1.3.27=hdbd6064_0 230 | - scikit-learn=1.4.2=py39h1128e8f_1 231 | - scipy=1.11.4=py39h5f9d8c6_0 232 | - send2trash=1.8.2=py39h06a4308_0 233 | - setuptools=69.5.1=py39h06a4308_0 234 | - sip=6.7.12=py39h6a678d5_0 235 | - six=1.16.0=pyhd3eb1b0_1 236 | - snappy=1.1.10=h6a678d5_1 237 | - sniffio=1.3.0=py39h06a4308_0 238 | - sortedcontainers=2.4.0=pyhd3eb1b0_0 239 | - soupsieve=2.5=py39h06a4308_0 240 | - sqlite=3.45.3=h5eee18b_0 241 | - stack_data=0.2.0=pyhd3eb1b0_0 242 | - tbb=2021.8.0=hdb19cb5_0 243 | - tblib=1.7.0=pyhd3eb1b0_0 244 | - terminado=0.17.1=py39h06a4308_0 245 | - testpath=0.6.0=py39h06a4308_0 246 | - threadpoolctl=3.5.0=pyhc1e730c_0 247 | - tk=8.6.14=h39e8969_0 248 | - tomli=2.0.1=py39h06a4308_0 249 | - toolz=0.12.0=py39h06a4308_0 250 | - tornado=6.4.1=py39h5eee18b_0 251 | - tqdm=4.66.4=pyhd8ed1ab_0 252 | - traitlets=5.14.3=py39h06a4308_0 253 | - typing-extensions=4.11.0=py39h06a4308_0 254 | - typing_extensions=4.11.0=py39h06a4308_0 255 | - unicodedata2=15.1.0=py39h5eee18b_0 256 | - urllib3=2.2.2=pyhd8ed1ab_0 257 | - utf8proc=2.6.1=h5eee18b_1 
258 | - wcwidth=0.2.5=pyhd3eb1b0_0 259 | - webencodings=0.5.1=py39h06a4308_1 260 | - websocket-client=1.8.0=py39h06a4308_0 261 | - wheel=0.43.0=py39h06a4308_0 262 | - xarray=2023.6.0=py39h06a4308_0 263 | - xz=5.4.6=h5eee18b_1 264 | - yaml=0.2.5=h7b6447c_0 265 | - zeromq=4.3.5=h6a678d5_0 266 | - zict=3.0.0=py39h06a4308_0 267 | - zipp=3.17.0=py39h06a4308_0 268 | - zlib=1.2.13=h5eee18b_1 269 | - zstd=1.5.5=hc292b87_2 270 | - pip: 271 | - ablang==0.3.1 272 | - biopython==1.78 273 | - biotite==0.34.1 274 | - git+https://github.com/facebookresearch/esm.git 275 | - iqplot==0.3.3 276 | - msgpack==1.0.8 277 | - natsort==8.4.0 278 | - networkx==3.2.1 279 | - pandas==2.0.3 280 | - tzdata==2024.1 281 | prefix: /data/home/vrs/miniconda3/envs/struct-evo 282 | -------------------------------------------------------------------------------- /bin/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import biotite.structure 4 | from biotite.structure.io import pdbx, pdb 5 | from biotite.structure.residues import get_residues 6 | from biotite.structure import filter_backbone 7 | from biotite.structure import get_chains 8 | from biotite.sequence import ProteinSequence 9 | import numpy as np 10 | from scipy.spatial import transform 11 | from scipy.stats import special_ortho_group 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.data as data 16 | from typing import Sequence, Tuple, List 17 | from esm.data import BatchConverter 18 | 19 | 20 | def load_structure(fpath, chain=None): 21 | """ 22 | Args: 23 | fpath: filepath to either pdb or cif file 24 | chain: the chain id or list of chain ids to load 25 | Returns: 26 | biotite.structure.AtomArray 27 | """ 28 | if fpath.endswith('cif'): 29 | with open(fpath) as fin: 30 | pdbxf = pdbx.PDBxFile.read(fin) 31 | structure = pdbx.get_structure(pdbxf, model=1) 32 | elif fpath.endswith('pdb'): 33 | with open(fpath) as fin: 34 | pdbf = pdb.PDBFile.read(fin) 35 | structure = pdb.get_structure(pdbf, model=1) 36 | bbmask = filter_backbone(structure) 37 | structure = structure[bbmask] 38 | all_chains = get_chains(structure) 39 | if len(all_chains) == 0: 40 | raise ValueError('No chains found in the input file.') 41 | if chain is None: 42 | chain_ids = all_chains 43 | elif isinstance(chain, list): 44 | chain_ids = chain 45 | else: 46 | chain_ids = [chain] 47 | for chain in chain_ids: 48 | if chain not in all_chains: 49 | raise ValueError(f'Chain {chain} not found in input file') 50 | chain_filter = [a.chain_id in chain_ids for a in structure] 51 | structure = structure[chain_filter] 52 | return structure 53 | 54 | 55 | def extract_coords_from_structure(structure: biotite.structure.AtomArray): 56 | """ 57 | Args: 58 | structure: An instance of biotite AtomArray 59 | Returns: 60 | Tuple (coords, seq) 61 | - coords is an L x 3 x 3 array for N, CA, C coordinates 62 | - seq is the extracted sequence 63 | """ 64 | coords = get_atom_coords_residuewise(["N", "CA", "C"], structure) 65 | residue_identities = get_residues(structure)[1] 66 | seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities]) 67 | return coords, seq 68 | 69 | 70 | def load_coords(fpath, chain): 71 | """ 72 | Args: 73 | fpath: filepath to either pdb or cif file 74 | chain: the chain id 75 | Returns: 76 | Tuple (coords, seq) 77 | - coords is an L x 3 x 3 array for N, CA, C coordinates 78 | - seq is the extracted sequence 79 | """ 80 | structure = 
load_structure(fpath, chain) 81 | return extract_coords_from_structure(structure) 82 | 83 | 84 | def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray): 85 | """ 86 | Example for atoms argument: ["N", "CA", "C"] 87 | """ 88 | def filterfn(s, axis=None): 89 | filters = np.stack([s.atom_name == name for name in atoms], axis=1) 90 | sum = filters.sum(0) 91 | if not np.all(sum <= np.ones(filters.shape[1])): 92 | raise RuntimeError("structure has multiple atoms with same name") 93 | index = filters.argmax(0) 94 | coords = s[index].coord 95 | coords[sum == 0] = float("nan") 96 | return coords 97 | 98 | return biotite.structure.apply_residue_wise(struct, struct, filterfn) 99 | 100 | #* now consistent with updated gpu 101 | def get_sequence_loss(model, alphabet, coords, seq): 102 | device = next(model.parameters()).device 103 | batch_converter = CoordBatchConverter(alphabet) 104 | batch = [(coords, None, seq)] 105 | coords, confidence, strs, tokens, padding_mask = batch_converter( 106 | batch, device=device) 107 | 108 | prev_output_tokens = tokens[:, :-1].to(device) 109 | target = tokens[:, 1:] 110 | target_padding_mask = (target == alphabet.padding_idx) 111 | logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens) 112 | loss = F.cross_entropy(logits, target, reduction='none') 113 | loss = loss[0].cpu().detach().numpy() 114 | target_padding_mask = target_padding_mask[0].cpu().numpy() 115 | return loss, target_padding_mask 116 | 117 | 118 | def score_sequence(model, alphabet, coords, seq): 119 | loss, target_padding_mask = get_sequence_loss(model, alphabet, coords, seq) 120 | ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask) 121 | # Also calculate average when excluding masked portions 122 | coord_mask = np.all(np.isfinite(coords), axis=(-1, -2)) 123 | ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask) 124 | return ll_fullseq, ll_withcoord 125 | 126 | 127 | def get_encoder_output(model, alphabet, coords): 128 | device = next(model.parameters()).device 129 | batch_converter = CoordBatchConverter(alphabet) 130 | batch = [(coords, None, None)] 131 | coords, confidence, strs, tokens, padding_mask = batch_converter( 132 | batch, device=device) 133 | encoder_out = model.encoder.forward(coords, padding_mask, confidence, 134 | return_all_hiddens=False) 135 | # remove beginning and end (bos and eos tokens) 136 | return encoder_out['encoder_out'][0][1:-1, 0] 137 | 138 | 139 | def rotate(v, R): 140 | """ 141 | Rotates a vector by a rotation matrix. 142 | 143 | Args: 144 | v: 3D vector, tensor of shape (length x batch_size x channels x 3) 145 | R: rotation matrix, tensor of shape (length x batch_size x 3 x 3) 146 | 147 | Returns: 148 | Rotated version of v by rotation matrix R. 149 | """ 150 | R = R.unsqueeze(-3) 151 | v = v.unsqueeze(-1) 152 | return torch.sum(v * R, dim=-2) 153 | 154 | 155 | def get_rotation_frames(coords): 156 | """ 157 | Returns a local rotation frame defined by N, CA, C positions. 
158 | 159 |     Args: 160 |         coords: coordinates, tensor of shape (batch_size x length x 3 x 3) 161 |         where the third dimension is in order of N, CA, C 162 | 163 |     Returns: 164 |         Local relative rotation frames in shape (batch_size x length x 3 x 3) 165 |     """ 166 |     v1 = coords[:, :, 2] - coords[:, :, 1] 167 |     v2 = coords[:, :, 0] - coords[:, :, 1] 168 |     e1 = normalize(v1, dim=-1) 169 |     u2 = v2 - e1 * torch.sum(e1 * v2, dim=-1, keepdim=True) 170 |     e2 = normalize(u2, dim=-1) 171 |     e3 = torch.cross(e1, e2, dim=-1) 172 |     R = torch.stack([e1, e2, e3], dim=-2) 173 |     return R 174 | 175 | 176 | def nan_to_num(ts, val=0.0): 177 |     """ 178 |     Replaces nans in tensor with a fixed value. 179 |     """ 180 |     val = torch.tensor(val, dtype=ts.dtype, device=ts.device) 181 |     return torch.where(~torch.isfinite(ts), val, ts) 182 | 183 | 184 | def rbf(values, v_min, v_max, n_bins=16): 185 |     """ 186 |     Returns RBF encodings in a new dimension at the end. 187 |     """ 188 |     rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device) 189 |     rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1]) 190 |     rbf_std = (v_max - v_min) / n_bins 191 |     v_expand = torch.unsqueeze(values, -1) 192 |     z = (v_expand - rbf_centers) / rbf_std 193 |     return torch.exp(-z ** 2) 194 | 195 | 196 | def norm(tensor, dim, eps=1e-8, keepdim=False): 197 |     """ 198 |     Returns L2 norm along a dimension. 199 |     """ 200 |     return torch.sqrt( 201 |         torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim) + eps) 202 | 203 | 204 | def normalize(tensor, dim=-1): 205 |     """ 206 |     Normalizes a tensor along a dimension after removing nans. 207 |     """ 208 |     return nan_to_num( 209 |         torch.div(tensor, norm(tensor, dim=dim, keepdim=True)) 210 |     ) 211 | 212 | #* 213 | class CoordBatchConverter(BatchConverter): 214 |     def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None): 215 |         """ 216 |         Args: 217 |             raw_batch: List of tuples (coords, confidence, seq) 218 |             In each tuple, 219 |                 coords: list of floats, shape L x 3 x 3 220 |                 confidence: list of floats, shape L; or scalar float; or None 221 |                 seq: string of length L 222 |         Returns: 223 |             coords: Tensor of shape batch_size x L x 3 x 3 224 |             confidence: Tensor of shape batch_size x L 225 |             strs: list of strings 226 |             tokens: LongTensor of shape batch_size x L 227 |             padding_mask: ByteTensor of shape batch_size x L 228 |         """ 229 |         self.alphabet.cls_idx = self.alphabet.get_idx("<cath>") 230 |         batch = [] 231 |         for coords, confidence, seq in raw_batch: 232 |             if confidence is None: 233 |                 confidence = 1. 234 |             if isinstance(confidence, float) or isinstance(confidence, int): 235 |                 confidence = [float(confidence)] * len(coords) 236 |             if seq is None: 237 |                 seq = 'X' * len(coords) 238 |             batch.append(((coords, confidence), seq)) 239 | 240 |         coords_and_confidence, strs, tokens = super().__call__(batch) 241 | 242 |         # pad beginning and end of each protein due to legacy reasons 243 |         coords = [ 244 |             F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.nan) #<---pad set to nan from np.inf 245 |             for cd, _ in coords_and_confidence 246 |         ] 247 |         confidence = [ 248 |             F.pad(torch.tensor(cf), (1, 1), value=-1.) 249 |             for _, cf in coords_and_confidence 250 |         ] 251 |         coords = self.collate_dense_tensors(coords, pad_v=np.nan) 252 |         confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
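        # the NaN padding applied above (bos/eos slots and any length padding from
        # collate_dense_tensors) is recovered below: padding_mask is True wherever
        # the N-atom x-coordinate is NaN, and coord_mask marks residues whose full
        # backbone coordinates are finite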
253 | if device is not None: 254 | coords = coords.to(device) 255 | confidence = confidence.to(device) 256 | tokens = tokens.to(device) 257 | padding_mask = torch.isnan(coords[:,:,0,0]) 258 | coord_mask = torch.isfinite(coords.sum(-2).sum(-1)) 259 | confidence = confidence * coord_mask + (-1.) * padding_mask 260 | return coords, confidence, strs, tokens, padding_mask 261 | 262 | def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None): 263 | """ 264 | Args: 265 | coords_list: list of length batch_size, each item is a list of 266 | floats in shape L x 3 x 3 to describe a backbone 267 | confidence_list: one of 268 | - None, default to highest confidence 269 | - list of length batch_size, each item is a scalar 270 | - list of length batch_size, each item is a list of floats of 271 | length L to describe the confidence scores for the backbone 272 | with values between 0. and 1. 273 | seq_list: either None or a list of strings 274 | Returns: 275 | coords: Tensor of shape batch_size x L x 3 x 3 276 | confidence: Tensor of shape batch_size x L 277 | strs: list of strings 278 | tokens: LongTensor of shape batch_size x L 279 | padding_mask: ByteTensor of shape batch_size x L 280 | """ 281 | batch_size = len(coords_list) 282 | if confidence_list is None: 283 | confidence_list = [None] * batch_size 284 | if seq_list is None: 285 | seq_list = [None] * batch_size 286 | raw_batch = zip(coords_list, confidence_list, seq_list) 287 | return self.__call__(raw_batch, device) 288 | 289 | @staticmethod 290 | def collate_dense_tensors(samples, pad_v): 291 | """ 292 | Takes a list of tensors with the following dimensions: 293 | [(d_11, ..., d_1K), 294 | (d_21, ..., d_2K), 295 | ..., 296 | (d_N1, ..., d_NK)] 297 | and stack + pads them into a single tensor of: 298 | (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) 299 | """ 300 | if len(samples) == 0: 301 | return torch.Tensor() 302 | if len(set(x.dim() for x in samples)) != 1: 303 | raise RuntimeError( 304 | f"Samples has varying dimensions: {[x.dim() for x in samples]}" 305 | ) 306 | (device,) = tuple(set(x.device for x in samples)) # assumes all on same device 307 | max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] 308 | result = torch.empty( 309 | len(samples), *max_shape, dtype=samples[0].dtype, device=device 310 | ) 311 | result.fill_(pad_v) 312 | for i in range(len(samples)): 313 | result_i = result[i] 314 | t = samples[i] 315 | result_i[tuple(slice(0, k) for k in t.shape)] = t 316 | return result -------------------------------------------------------------------------------- /bin/dms_enrichment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from biotite.sequence.io.fasta import FastaFile, get_sequences 4 | from Bio import pairwise2 5 | import pandas as pd 6 | import shutil 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import bokeh.io 11 | import bokeh.plotting 12 | import bokeh.palettes 13 | from bokeh.transform import factor_cmap 14 | 15 | from bokeh.io import export_svg 16 | import iqplot 17 | from colorcet import glasbey_category10 18 | 19 | import subprocess 20 | 21 | bla_dict = { 22 | 'protein':'bla', 23 | 'pdbfile':'bla_1m40_a.pdb', 24 | 'chain':'A', 25 | 'dmsfile': 'dms_bla.csv', 26 | 'fitness_col': 'DMS_amp_2500_(b)', 27 | 'threshold' : 0.01 28 | } 29 | 30 | calm1_dict = { 31 | 'protein':'CALM1', 32 | 'pdbfile':'calm1_5v03_r.pdb', 33 | 'chain':'R', 34 | 'dmsfile': 'dms_calm1.csv', 35 | 
'fitness_col':'DMS', 36 | 'threshold': 1 37 | } 38 | 39 | haeiiim_dict = { 40 | 'protein':'haeIIIM', 41 | 'pdbfile':'haeiiim_3ubt_b.pdb', 42 | 'chain':'B', 43 | 'dmsfile': 'dms_haeiiim.csv', 44 | 'fitness_col':'DMS_G3', 45 | 'threshold' : 0.01 46 | } 47 | 48 | gal4_dict = { 49 | 'protein':'GAL4', 50 | 'pdbfile':'gal4_3coq_b.pdb', 51 | 'chain':'B', 52 | 'dmsfile': 'dms_gal4.csv', 53 | 'fitness_col':'DMS_nonsel_24', 54 | 'threshold': 1 55 | } 56 | 57 | hras_dict = { 58 | 'protein':'HRAS', 59 | 'pdbfile':'hras_2ce2_x.pdb', 60 | 'chain':'X', 61 | 'dmsfile': 'dms_hras.csv', 62 | 'fitness_col':'DMS_unregulated', 63 | 'threshold': 1 64 | } 65 | 66 | mapk1_dict = { 67 | 'protein':'MAPK1', 68 | 'pdbfile':'mapk1_4zzn_a.pdb', 69 | 'chain':'A', 70 | 'dmsfile': 'dms_mapk1.csv', 71 | 'fitness_col':'DMS_VRT', 72 | 'threshold': 1 73 | } 74 | 75 | tpk1_dict = { 76 | 'protein':'TPK1', 77 | 'pdbfile':'tpk1_3s4y_a.pdb', 78 | 'chain':'A', 79 | 'dmsfile': 'dms_tpk1.csv', 80 | 'fitness_col':'DMS', 81 | 'threshold': 1 82 | } 83 | 84 | tpmt_dict = { 85 | 'protein':'TPMT', 86 | 'pdbfile':'tpmt_2bzg_a.pdb', 87 | 'chain':'A', 88 | 'dmsfile': 'dms_tpmt.csv', 89 | 'fitness_col':'DMS', 90 | 'threshold': 1 91 | } 92 | 93 | ube2i_dict = { 94 | 'protein':'UBE2I', 95 | 'pdbfile':'ube2i_5f6e_a.pdb', 96 | 'chain':'A', 97 | 'dmsfile': 'dms_tpmt.csv', 98 | 'fitness_col':'DMS', 99 | 'threshold': 1 100 | } 101 | 102 | ubi4_dict = { 103 | 'protein':'UBI4', 104 | 'pdbfile':'ubi4_4q5e_b.pdb', 105 | 'chain':'B', 106 | 'dmsfile': 'dms_ubi4.csv', 107 | 'fitness_col':'DMS_limiting_(b)', 108 | 'threshold': 1 109 | } 110 | 111 | studies = [ 112 | bla_dict, 113 | calm1_dict, 114 | gal4_dict, 115 | haeiiim_dict, 116 | hras_dict, 117 | tpmt_dict, 118 | tpk1_dict, 119 | mapk1_dict, 120 | ube2i_dict, 121 | ubi4_dict, 122 | ] 123 | 124 | def score_if(prot_dict): 125 | prefix = 'data/dms/' 126 | fasta_outpath = f"dms_libs/{prot_dict['pdbfile'][:-4]}_dmsLib.fasta" 127 | score_if_cmd = [ 128 | 'python', 129 | 'bin/score_log_likelihoods.py', 130 | prefix + f"structures/{prot_dict['pdbfile']}", 131 | '--seqpath', prefix + fasta_outpath, 132 | '--outpath', f"output/dms/if_results/dms_{prot_dict['protein']}_if.csv" 133 | ,'--chain', prot_dict['chain'] 134 | ] 135 | subprocess.run(score_if_cmd, check=True) 136 | 137 | def get_stats(prot_dict, n, top_per): 138 | 139 | top_per = top_per/100 140 | prefix = 'output/dms/' 141 | esm1v_outpath = 'dms_esm1vScored/' + f"dms_{prot_dict['protein']}" + '_maskMarginals.csv' 142 | if_results_df = pd.read_csv( prefix + f"if_results/dms_{prot_dict['protein']}_if.csv") 143 | dms_df = pd.read_csv(prefix + esm1v_outpath) 144 | 145 | unprofiled_df = dms_df[dms_df[prot_dict['fitness_col']].isna()] 146 | unprofiled_variants = unprofiled_df['variant'].to_list() 147 | 148 | dms_df = dms_df.dropna(subset = [prot_dict['fitness_col']]) 149 | pop_n = len(dms_df) 150 | if_results_df = if_results_df[~if_results_df['seqid'].isin(unprofiled_variants)] 151 | 152 | assert (len(if_results_df) == len(dms_df)) 153 | 154 | threshold = dms_df[prot_dict['fitness_col']].quantile(1-top_per) 155 | highFit_subset_variants = dms_df[dms_df[prot_dict['fitness_col']] >= threshold]['variant'] 156 | n_highFit = len(highFit_subset_variants) 157 | 158 | 159 | dms_df['esm1v'] = dms_df.loc[:, dms_df.columns.str.startswith('esm1v')].mean(axis=1) 160 | top_esm1v_df = dms_df.sort_values(by = 'esm1v', ascending = False)[:n] 161 | top_esm1v_variants = top_esm1v_df['variant'].to_list() 162 | 163 | top_if_df = if_results_df.sort_values(by = 'log_likelihood', 
ascending = False)[:n] 164 | top_if_variants = top_if_df['seqid'].to_list() 165 | 166 | esm1v_hits_df = top_esm1v_df[top_esm1v_df['variant'].isin(highFit_subset_variants)] 167 | if_hits_df = top_if_df[top_if_df['seqid'].isin(highFit_subset_variants)] 168 | 169 | n_esm1v_hits = len(esm1v_hits_df) 170 | n_if_hits = len(if_hits_df) 171 | 172 | if_hit_enrich = (n_if_hits / n) / top_per 173 | esm1v_hit_enrich = (n_esm1v_hits / n) / top_per 174 | 175 | if_hit_rate = (n_if_hits / n) 176 | esm1v_hit_rate = (n_esm1v_hits / n) 177 | 178 | return (prot_dict['protein'], pop_n, n_highFit, n_if_hits, n_esm1v_hits, if_hit_enrich, esm1v_hit_enrich, if_hit_rate, esm1v_hit_rate ) 179 | 180 | def plot_bars(melted_df): 181 | p = bokeh.plotting.figure( 182 | height=350, 183 | width=1100, 184 | y_axis_label="High Fitness \n Prediction Precision", 185 | x_axis_label = "Functional Percentile Threshold \n for High Fitness Classification", 186 | x_range=bokeh.models.FactorRange(*factors, group_padding = 1.2), 187 | tools="save", 188 | title='' 189 | ) 190 | p.output_backend = "svg" 191 | 192 | p.vbar( 193 | source=melted_df, 194 | x="cats", 195 | top="Hit Rate", 196 | width = 1, 197 | line_color='black', 198 | alpha = 'alpha', 199 | legend_field='legend_label', 200 | fill_color=bokeh.transform.factor_cmap( 201 | 'Method', 202 | palette=['#999999', '#43a2ca'], 203 | factors=list(melted_df['Method'].unique()), 204 | start=1, 205 | end=2 206 | ) 207 | ) 208 | 209 | p.xgrid.grid_line_color = None 210 | p.ygrid.grid_line_color = None 211 | p.x_range.range_padding = 0.03 212 | p.legend.location = "top_right" 213 | p.legend.spacing = 20 214 | p.legend.label_text_font_size = "11pt" 215 | p.legend.label_text_color = "black" 216 | p.legend.orientation = "horizontal" 217 | p.xaxis.major_label_orientation = 1.2 218 | p.xaxis.separator_line_alpha = 0 219 | p.xaxis.group_text_font_size = '11pt' 220 | p.xaxis.subgroup_text_font_size = '0pt' 221 | p.xaxis.axis_label_text_font_size = '11pt' 222 | p.xaxis.major_label_text_font_size = '10pt' 223 | p.xaxis.axis_label_text_font_style = 'normal' 224 | p.xaxis.major_label_text_color = 'black' 225 | p.xaxis.axis_label_text_color = 'black' 226 | 227 | p.yaxis.axis_label_text_font_size = '11pt' 228 | p.yaxis.axis_label_text_font_style = 'normal' 229 | p.yaxis.axis_label_text_color = 'black' 230 | p.yaxis.major_label_text_font_size = '10pt' 231 | p.xaxis.group_text_color = 'dimgrey' 232 | 233 | return p 234 | 235 | def plot_ecdf_hits(prot_dict, n, top_per_lst): 236 | 237 | #use least stringent threshold to plot all hits 238 | top_per = max(top_per_lst)/100 239 | prefix = 'output/dms/' 240 | 241 | esm1v_outpath = 'dms_esm1vScored/' + f"dms_{prot_dict['protein']}" + '_maskMarginals.csv' 242 | if_results_df = pd.read_csv( prefix + f"if_results/dms_{prot_dict['protein']}_if.csv") 243 | dms_df = pd.read_csv(prefix + esm1v_outpath) 244 | 245 | 246 | unprofiled_df = dms_df[dms_df[prot_dict['fitness_col']].isna()] 247 | unprofiled_variants = unprofiled_df['variant'].to_list() 248 | 249 | dms_df = dms_df.dropna(subset = [prot_dict['fitness_col']]) 250 | dms_df['percentile'] = dms_df[prot_dict['fitness_col']].rank(pct=True, method='average') 251 | dms_df['zscore'] = (dms_df[prot_dict['fitness_col']] - dms_df[prot_dict['fitness_col']].mean())/dms_df[prot_dict['fitness_col']].std() 252 | if_results_df = if_results_df[~if_results_df['seqid'].isin(unprofiled_variants)] 253 | 254 | assert (len(if_results_df) == len(dms_df)) 255 | 256 | threshold = dms_df[prot_dict['fitness_col']].quantile(1-top_per) 257 | 
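    # e.g. top_per = 0.20 places the threshold at the 80th percentile of the
    # measured fitness values; variants at or above it define the ground-truth
    # "high fitness" set against which the predicted top-n lists are scored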
highFit_subset_variants = dms_df[dms_df[prot_dict['fitness_col']] >= threshold]['variant'] 258 | n_highFit = len(highFit_subset_variants) 259 | 260 | dms_df['esm1v'] = dms_df.loc[:, dms_df.columns.str.startswith('esm1v')].mean(axis=1) 261 | top_esm1v_df = dms_df.sort_values(by = 'esm1v', ascending = False)[:n] 262 | top_esm1v_variants = top_esm1v_df['variant'].to_list() 263 | 264 | top_if_df = if_results_df.sort_values(by = 'log_likelihood', ascending = False)[:n] 265 | top_if_variants = top_if_df['seqid'].to_list() 266 | 267 | esm1v_hits_df = top_esm1v_df[top_esm1v_df['variant'].isin(highFit_subset_variants)] 268 | if_hits_df = top_if_df[top_if_df['seqid'].isin(highFit_subset_variants)] 269 | if_hits_df = pd.merge(if_hits_df, dms_df[['variant', 'zscore', 'percentile', prot_dict['fitness_col']]], left_on='seqid', right_on='variant', how='left') 270 | 271 | p_1v = iqplot.ecdf( 272 | data=dms_df, 273 | q=prot_dict['fitness_col'], 274 | style="staircase", 275 | palette = 'lightgrey', 276 | line_kwargs= {'line_width':6}, 277 | x_axis_label = prot_dict['protein']+' Fitness Measure', 278 | #title = prot_dict['protein'] 279 | ) 280 | p_if = iqplot.ecdf( 281 | data=dms_df, 282 | q=prot_dict['fitness_col'], 283 | style="staircase", 284 | palette = 'lightgrey', 285 | line_kwargs= {'line_width':6}, 286 | x_axis_label = None, #prot_dict['protein']+'Fitness Measure', 287 | title = prot_dict['protein'] 288 | ) 289 | colors = ['#1f78b4', '#33a02c', '#ff7f00',] 290 | for i, top_per in enumerate(top_per_lst): 291 | top_per = top_per/100 292 | hline = bokeh.models.Span(location=1-top_per, dimension='width', line_color=colors[i], line_dash='dashed', line_width=4, ) 293 | p_1v.add_layout(hline) 294 | p_if.add_layout(hline) 295 | 296 | 297 | p_1v.circle( 298 | source = esm1v_hits_df, 299 | x = prot_dict['fitness_col'], 300 | y = 'percentile', 301 | color = 'dimgrey', 302 | size = 11, 303 | line_color = 'lightgrey', 304 | alpha = 0.8, 305 | ) 306 | 307 | p_if.circle( 308 | source = if_hits_df, 309 | x = prot_dict['fitness_col'], 310 | y = 'percentile', 311 | color = '#43a2ca', 312 | size = 11, 313 | line_color = 'white', 314 | alpha = 0.8, 315 | ) 316 | 317 | p_if.title.text_font_style = 'normal' 318 | p_if.title.text_color = 'black' 319 | p_if.title.align = 'center' 320 | p_if.title.text_font_size = '16pt' 321 | p_1v.xaxis.axis_label_text_font_size = '14pt' 322 | 323 | 324 | for p in [p_if, p_1v]: 325 | p.xaxis.axis_label_text_font_style = 'normal' 326 | p.yaxis.axis_label_text_font_style = 'normal' 327 | p.xaxis.axis_label_text_font_size = '14pt' 328 | p.yaxis.axis_label_text_font_size = '14pt' 329 | p.xaxis.axis_label_text_color = 'black' 330 | p.yaxis.axis_label_text_color = 'black' 331 | p.xaxis.major_label_text_font_size = '14pt' 332 | p.yaxis.major_label_text_font_size = '14pt' 333 | p.xaxis.major_label_text_color = 'black' 334 | p.yaxis.major_label_text_color = 'black' 335 | p.output_backend = "svg" 336 | 337 | return [p_if, p_1v] 338 | 339 | if __name__ == '__main__': 340 | for s in studies: 341 | score_if(s) 342 | shutil.copytree("data/dms/dms_esm1vScored", "output/dms/dms_esm1vScored", dirs_exist_ok=True) 343 | 344 | name, n_pop, n_popHits, n_if_hits, n_esm1v_hits, if_hit_enrich, esm1v_hit_enrich, if_hit_rate, esm1v_hit_rate, p_threshold = ([] for _ in range(10)) 345 | data_lists = [name, n_pop, n_popHits, n_if_hits, n_esm1v_hits, if_hit_enrich, esm1v_hit_enrich, if_hit_rate, esm1v_hit_rate, p_threshold] 346 | for s in studies: 347 | for p in [5, 10, 20 ]: 348 | outputs = get_stats(s, 10, p) 349 | 
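            # get_stats returns a 9-tuple that maps positionally onto the first
            # nine accumulator lists in data_lists; the percentile threshold p
            # is appended separately as the tenth column below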
for i in range(len(outputs)): 350 |                 data_lists[i].append(outputs[i]) 351 |             data_lists[-1].append(p) 352 |     col_names = ['Protein', 'Total Library Size', 'Library Hits (Variants Above Fitness Percentile Threshold)', 'Inverse Folding Hits (in Top 10)', 'ESM1v Hits (in Top 10)', 'Inverse Folding Hit Enrichment', 'ESM1v Hit Enrichment','Inverse Folding Hit Rate', 'ESM1v Hit Rate','Percentile Threshold' ] 353 |     results_df = pd.DataFrame({col_names[i]: data_lists[i] for i in range(len(data_lists))}) 354 |     # Melt the dataframe 355 |     melted_df = results_df.melt(id_vars=['Protein', 'Percentile Threshold'], value_vars=[ 'ESM1v Hit Rate', 'Inverse Folding Hit Rate',], 356 |                                 var_name='Method', value_name='Hit Rate') 357 |     melted_df['Method'] = melted_df['Method'].replace({ 358 |         'Inverse Folding Hit Rate': 'Inverse Folding', 359 |         'ESM1v Hit Rate': 'Language Model'}) 360 |     melted_df['alpha'] = (1/melted_df['Percentile Threshold'].to_numpy())*3 + 0.4 361 |     melted_df['Percentile Threshold'] = melted_df['Percentile Threshold'].astype(str) 362 |     melted_df['legend_label'] = ['Structure-Informed Language Model' if x=='Inverse Folding' else 'Language Model' for x in melted_df['Method']] 363 | 364 |     melted_df['cats'] = melted_df.apply(lambda x: (x["Protein"], x["Method"], x['Percentile Threshold'],), axis = 1) 365 |     factors = list(melted_df.cats) 366 | 367 |     p = plot_bars(melted_df) 368 |     comparePrecision_fname = 'output/dms/precision_comparison.html' 369 |     bokeh.plotting.output_file(comparePrecision_fname) 370 |     bokeh.io.show(p) 371 | 372 |     ecdf_plots = [] 373 |     for s in studies: 374 |         ecdf_plots.extend(plot_ecdf_hits(s, 10, [20,10,5])) 375 |     mid = int(len(ecdf_plots)/2) 376 |     ecdf_fname = 'output/dms/fitness_ecdfs.html' 377 |     bokeh.plotting.output_file(ecdf_fname) 378 |     bokeh.io.show(bokeh.layouts.gridplot([ecdf_plots[:mid:2], 379 |                                           ecdf_plots[1:mid:2], 380 |                                           ecdf_plots[mid::2], 381 |                                           ecdf_plots[mid+1::2]]) 382 |                   ) 383 | -------------------------------------------------------------------------------- /bin/plot_esm1v_benchmarks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import bokeh.io 4 | import bokeh.plotting 5 | import bokeh.palettes 6 | from bokeh.transform import factor_cmap 7 | import datashader 8 | import holoviews as hv 9 | import holoviews.operation.datashader 10 | hv.extension("bokeh") 11 | 12 | import warnings 13 | # Suppress FutureWarning messages 14 | warnings.simplefilter(action='ignore',) 15 | 16 | cr9114_dict = { 17 |     'ab_name' : 'CR9114', 18 |     'files' : { 19 |         'ESM-1v Ab-Ag': 'output/ab_mutagenesis_expts/cr9114/cr9114_esm1vAbAg_exp_data_maskMargLabeled.csv', 20 |         'ESM-1v Ab only': 'output/ab_mutagenesis_expts/cr9114/cr9114_esm1vbothchains_exp_data_maskMargLabeled.csv', 21 |         'ESM-1v Ab VH only': 'output/ab_mutagenesis_expts/cr9114/cr9114_exp_data_maskMargLabeled.csv', 22 |     }, 23 |     'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr9114/cr9114_exp_data.csv', 24 |                            dtype = {'genotype': str}, 25 |                            ).rename(columns={'h1_mean': 'H1', 'h3_mean' : 'H3'}), 26 |     'ag_columns': ['H1', 'H3'], 27 |     'expt_type': 'Combinatorial Mutagenesis for Affinity', 28 |     'palette': bokeh.palettes.Spectral6 29 | } 30 | 31 | cr6261_dict = { 32 |     'ab_name' : 'CR6261', 33 |     'files' : { 34 |         'ESM-1v Ab-Ag': 'output/ab_mutagenesis_expts/cr6261/cr6261_esm1vAbAg_exp_data_maskMargLabeled.csv', 35 |         'ESM-1v Ab only': 'output/ab_mutagenesis_expts/cr6261/cr6261_esm1vbothchains_exp_data_maskMargLabeled.csv', 36 |         'ESM-1v Ab VH only': 
'output/ab_mutagenesis_expts/cr6261/cr6261_exp_data_maskMargLabeled.csv', 37 | }, 38 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr6261/cr6261_exp_data.csv', 39 | dtype={'genotype': str}, 40 | ).rename(columns={'h1_mean': 'H1', 'h9_mean' : 'H9'}), 41 | 'ag_columns': ['H1', 'H9'], 42 | 'expt_type': 'Combinatorial Mutagenesis for Affinity', 43 | 'palette': bokeh.palettes.Dark2_6, 44 | } 45 | 46 | g6LC_dict = { 47 | 'ab_name' : 'g6', 48 | 'files' : { 49 | 'LM Ab-Ag': 'output/ab_mutagenesis_expts/g6/g6Lc_esm1vAbAg_exp_data_maskMargLabeled.csv', 50 | 'LM Ab only': 'output/ab_mutagenesis_expts/g6/g6Lc_esm1vbothchains_exp_data_maskMargLabeled.csv', 51 | 'LM Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/g6Lc_exp_data_maskMargLabeled.csv', 52 | }, 53 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv'), 54 | 'ag_columns': ['norm_binding'], 55 | 'expt_type': 'Deep Mutational Scan for Binding', 56 | 'palette': bokeh.palettes.Pastel1_6, 57 | 'chain': 'VL' 58 | } 59 | 60 | g6HC_dict = { 61 | 'ab_name' : 'g6', 62 | 'files' : { 63 | 'LM Ab-Ag': 'output/ab_mutagenesis_expts/g6/g6Hc_esm1vAbAg_exp_data_maskMargLabeled.csv', 64 | 'LM Ab only': 'output/ab_mutagenesis_expts/g6/g6Hc_esm1vbothchains_exp_data_maskMargLabeled.csv', 65 | 'LM Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/g6Hc_exp_data_maskMargLabeled.csv', 66 | }, 67 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv'), 68 | 'ag_columns': ['norm_binding'], 69 | 'expt_type': 'Deep Mutational Scan for Binding', 70 | 'palette': bokeh.palettes.Pastel1_4, 71 | 'chain': 'VH' 72 | } 73 | 74 | def apply_mask_and_average(scores, genotype): 75 | #scores here should already be a in log space 76 | 77 | if '1' in genotype: 78 | masked_list = [] 79 | for s, mask_char in zip(scores, genotype): 80 | if mask_char == '1': 81 | masked_list.append(s) 82 | multi_avg = np.mean(masked_list) 83 | else: 84 | #if genotype is all zeros = wt 85 | multi_avg = 0 86 | 87 | return multi_avg 88 | 89 | 90 | def transform_single_to_multi( dms_df, singleMuts_df, condition, sort_col, ascending ): 91 | multi_scores = [] 92 | #use input parameter as method bc key is kwarg in pd.sort_values 93 | #method should be either esm1v or abysis or AbLang 94 | 95 | #for cr sorts the single mutation dataframe in residue order, ie binary '10000' before '00010'. ascending = False 96 | #for g6 sorts the single mutation dataframe in position order, ie residue index. 
ascending = True 97 | 98 | singleMuts_sorted = singleMuts_df.sort_values(sort_col, key=lambda x: x.astype(int), ascending=ascending) 99 | singles_scores = singleMuts_sorted[condition].to_list() 100 | 101 | for g in dms_df['genotype']: 102 | multi_scores.append(apply_mask_and_average(singles_scores, g )) 103 | 104 | return multi_scores 105 | 106 | #get melted correlations matrix of structure input x target antigen, for a given antibody 107 | def get_corr(ab_name, files, dms_df, ag_columns, dropLOQ = False ): 108 | 109 | #retreive correlations from abysis and esm1v scores 110 | for key, filepath in files.items(): 111 | singleMuts_df = pd.read_csv(filepath) 112 | # Get the columns for all esm1v models 113 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')] 114 | # Calculate the average for each esm column 115 | esm_avg_values = singleMuts_df[esm_columns].mean(axis=1) 116 | singleMuts_df[key] = esm_avg_values 117 | if ab_name == 'g6': 118 | #g6 data set is only single mutations 119 | dms_df[key] = esm_avg_values 120 | else: 121 | #cr datasets are multiple mutations 122 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'genotype', ascending = False) 123 | 124 | 125 | conditions = list(files.keys()) 126 | 127 | if not dropLOQ: 128 | correlations = dms_df[ag_columns + conditions].corr(method='spearman') 129 | correlations = correlations.drop(conditions, axis= 1) 130 | correlations = correlations.drop(ag_columns, axis= 0) 131 | 132 | # Melt the correlations DataFrame and rename the index column 133 | melted_correlations = ( 134 | correlations 135 | .reset_index() 136 | .melt(id_vars='index', var_name='Target', value_name='Correlation') 137 | .rename(columns={'index': 'Input'}) 138 | ) 139 | melted_correlations['Method'] = ['LM' for i in melted_correlations['Input']] 140 | 141 | return melted_correlations 142 | #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like 143 | else: 144 | all_ag_melted = pd.DataFrame({}) 145 | for ag in ag_columns: 146 | filt_df = dms_df[dms_df[ag] > min(dms_df[ag])] 147 | 148 | correlations = filt_df[[ag] + conditions].corr(method='spearman') 149 | correlations = correlations.drop(conditions, axis= 1) 150 | correlations = correlations.drop(ag, axis= 0) 151 | 152 | # Melt the correlations DataFrame and rename the index column 153 | melted_correlations = ( 154 | correlations 155 | .reset_index() 156 | .melt(id_vars='index', var_name='Target', value_name='Correlation') 157 | ).rename(columns={'index': 'Input'}) 158 | 159 | all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True) 160 | all_ag_melted['Method'] = ['LM' for i in all_ag_melted['Input']] 161 | 162 | return all_ag_melted 163 | 164 | #get melted correlations matrix of structure input x target antigen, for a given antibody 165 | def get_g6_corr(g6Hc_dict, g6LC_dict, dropLOQ = False ): 166 | 167 | ab_name = 'g6' 168 | ag_columns = g6LC_dict['ag_columns'] 169 | vh_and_vl_dms_df = pd.DataFrame({}) 170 | 171 | for d in [g6HC_dict, g6LC_dict]: 172 | 173 | files = d['files'] 174 | dms_df = d['dms_df'] 175 | 176 | #retreive correlations from abysis, ablang, and esm1v scores 177 | for key, filepath in files.items(): 178 | singleMuts_df = pd.read_csv(filepath) 179 | # Get the columns for all esm1v models 180 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')] 181 | # Calculate the average for each esm column 182 | 
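            # (the esm-prefixed columns are expected to hold per-variant scores
            # from the individual ESM-1v ensemble members, which are averaged
            # into a single prediction per variant)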
esm_avg_values = singleMuts_df[esm_columns].mean(axis=1) 183 |                 singleMuts_df[key] = esm_avg_values 184 |                 if ab_name == 'g6': 185 |                     #g6 dataset contains only single mutations 186 |                     dms_df[key] = esm_avg_values 187 | 188 | 189 |         vh_and_vl_dms_df = pd.concat([vh_and_vl_dms_df, dms_df], ignore_index= True) 190 | 191 |     conditions = list(files.keys()) 192 | 193 |     # 194 |     if not dropLOQ: 195 |         correlations = vh_and_vl_dms_df[ag_columns + conditions].corr(method='spearman') 196 |         correlations = correlations.drop(conditions, axis= 1) 197 |         correlations = correlations.drop(ag_columns, axis= 0) 198 | 199 |         # Melt the correlations DataFrame and rename the index column 200 |         melted_correlations = ( 201 |             correlations 202 |             .reset_index() 203 |             .melt(id_vars='index', var_name='Target', value_name='Correlation') 204 |             .rename(columns={'index': 'Input'}) 205 |         ) 206 |         melted_correlations['Method'] = ['LM' for i in melted_correlations['Input']] 207 |         melted_correlations['Target'] = ['VEGF-A'] * len(melted_correlations) 208 | 209 |         return melted_correlations 210 |     #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like 211 |     else: 212 |         all_ag_melted = pd.DataFrame({}) 213 |         for ag in ag_columns: 214 |             filt_df = vh_and_vl_dms_df[vh_and_vl_dms_df[ag] > min(vh_and_vl_dms_df[ag])] 215 | 216 |             correlations = filt_df[[ag] + conditions].corr(method='spearman') 217 |             correlations = correlations.drop(conditions, axis= 1) 218 |             correlations = correlations.drop(ag, axis= 0) 219 | 220 |             # Melt the correlations DataFrame and rename the index column 221 |             melted_correlations = ( 222 |                 correlations 223 |                 .reset_index() 224 |                 .melt(id_vars='index', var_name='Target', value_name='Correlation') 225 |             ).rename(columns={'index': 'Input'}) 226 | 227 |             all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True) 228 |         all_ag_melted['Method'] = ['LM' for i in all_ag_melted['Input']] 229 |         all_ag_melted['Target'] = ['VEGF-A'] * len(all_ag_melted) 230 | 231 |         return all_ag_melted 232 | 233 | #plot correlation bars for each antibody benchmark 234 | def plot_hbar(title, melted_correlations, palette = bokeh.palettes.Spectral6, compare = 'model'): 235 | 236 |     if compare == 'model': 237 |         melted_correlations.replace('Ab-Ag','InverseFolding', inplace = True) 238 |         color_by = 'Input' 239 |     elif compare == 'IF': 240 |         color_by = 'Method' 241 |     else: 242 |         color_by = 'Target' 243 | 244 | 245 |     melted_correlations['cats'] = melted_correlations.apply(lambda x: (x["Target"], x["Input"]), axis = 1) 246 |     factors = list(melted_correlations.cats)[::-1] 247 | 248 |     p = bokeh.plotting.figure( 249 |         height=340, 250 |         width=440, 251 |         x_axis_label="Spearman Correlation", 252 |         x_range=[0, 1], 253 |         y_range=bokeh.models.FactorRange(*factors), 254 |         tools="save", 255 |         title = title 256 |     ) 257 | 258 | 259 |     p.hbar( 260 |         source=melted_correlations, 261 |         y="cats", 262 |         right="Correlation", 263 |         height=0.6, 264 |         line_color = 'black', 265 |         legend_field = color_by, 266 |         # use the palette to colormap based on the x[1:2] values 267 |         # fill_color=factor_cmap( 268 |         #     color_by, 269 |         #     palette= palette, 270 |         #     factors = list(melted_correlations[color_by].unique()), 271 |         #     start=1, 272 |         #     end=3) 273 |         fill_color=bokeh.palettes.Dark2_6[1] 274 | 275 |     ) 276 | 277 |     #labels_df = melted_correlations[melted_correlations['Input'].isin(['Ab-Ag','ESM-1v', 'abYsis', 'InverseFold'])] 278 |     labels_df = 
melted_correlations 279 | labels_df['corr_str'] = labels_df['Correlation'].apply(lambda x: round(x, 2)).astype(str) 280 | labels_source = bokeh.models.ColumnDataSource(labels_df) 281 | 282 | labels = bokeh.models.LabelSet(x='Correlation', y='cats', text='corr_str',text_font_size = "10px", 283 | x_offset=12, y_offset=-5, source=labels_source, render_mode='canvas') 284 | 285 | p.ygrid.grid_line_color = None 286 | p.y_range.range_padding = 0.1 287 | p.add_layout(labels) 288 | p.legend.visible = False 289 | 290 | p.output_backend = "svg" 291 | return p 292 | 293 | if __name__ == '__main__': 294 | datasets = [cr9114_dict, cr6261_dict, (g6HC_dict, g6LC_dict) ] 295 | 296 | #compute with and without points on lower limit of quantitation ignored 297 | for dropLOQ in [False]: 298 | 299 | all_corr_plots = [] 300 | 301 | for d in datasets: 302 | if type(d) is tuple: 303 | g6Hc, g6Lc = d 304 | title = g6Hc['ab_name'] + ', ' + g6Hc['expt_type'] 305 | g6_combined = get_g6_corr(g6Hc, g6Lc, dropLOQ= dropLOQ) 306 | 307 | all_corr_plots.append(plot_hbar(title, g6_combined, g6Lc['palette'], 'IF')) 308 | 309 | else: 310 | corr_df = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ) 311 | title = d['ab_name'] + ', ' + d['expt_type'] 312 | all_corr_plots.append(plot_hbar(title, corr_df, d['palette'], 'Target')) 313 | 314 | 315 | all_corr_fname = f"output/ab_mutagenesis_expts/esm1v_benchmarks{'_dropLOQ' if dropLOQ else ''}.html" 316 | bokeh.plotting.output_file(all_corr_fname) 317 | bokeh.io.show(bokeh.layouts.gridplot(all_corr_plots, ncols = len(all_corr_plots))) 318 | -------------------------------------------------------------------------------- /bin/plot_ab-binding_benchmarks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import bokeh.io 4 | import bokeh.plotting 5 | import bokeh.palettes 6 | from bokeh.transform import factor_cmap 7 | import datashader 8 | import holoviews as hv 9 | import holoviews.operation.datashader 10 | hv.extension("bokeh") 11 | 12 | import warnings 13 | # Suppress FutureWarning messages 14 | warnings.simplefilter(action='ignore',) 15 | 16 | # Your Pandas code here 17 | 18 | if_conds = ['ag+ab', 'ab only', 'ab'] 19 | 20 | cr9114_dict = { 21 | 'ab_name' : 'CR9114', 22 | 'files' : { 23 | 'Ab-Ag': 'output/ab_mutagenesis_expts/cr9114/4fqi_ablh_scores.csv', 24 | 'Ab only': 'output/ab_mutagenesis_expts/cr9114/4fqi_lh_scores.csv', 25 | 'Ab VH only': 'output/ab_mutagenesis_expts/cr9114/4fqi_h_scores.csv', 26 | 'ESM-1v': 'output/ab_mutagenesis_expts/cr9114/cr9114_exp_data_maskMargLabeled.csv', 27 | 'AbLang': 'output/ab_mutagenesis_expts/cr9114/cr9114_hc_ablangScores.csv', 28 | 'abYsis': 'output/ab_mutagenesis_expts/cr9114/abysis_counts_cr9114_vh.txt', 29 | }, 30 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr9114/cr9114_exp_data.csv', 31 | dtype = {'genotype': str}, 32 | ).rename(columns={'h1_mean': 'H1', 'h3_mean' : 'H3'}), 33 | 'ag_columns': ['H1', 'H3'], 34 | 'expt_type': 'Combinatorial Mutagenesis for Affinity', 35 | 'palette': bokeh.palettes.Spectral6 36 | } 37 | 38 | cr6261_dict = { 39 | 'ab_name' : 'CR6261', 40 | 'files' : { 41 | 'Ab-Ag': 'output/ab_mutagenesis_expts/cr6261/3gbn_ablh_scores.csv', 42 | 'Ab only': 'output/ab_mutagenesis_expts/cr6261/3gbn_lh_scores.csv', 43 | 'Ab VH only': 'output/ab_mutagenesis_expts/cr6261/3gbn_h_scores.csv', 44 | 'ESM-1v': 'output/ab_mutagenesis_expts/cr6261/cr6261_exp_data_maskMargLabeled.csv', 45 | 'AbLang': 
'output/ab_mutagenesis_expts/cr6261/cr6261_hc_ablangScores.csv', 46 | 'abYsis': 'output/ab_mutagenesis_expts/cr6261/abysis_counts_cr6261_vh.txt', 47 | }, 48 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr6261/cr6261_exp_data.csv', 49 | dtype={'genotype': str}, 50 | ).rename(columns={'h1_mean': 'H1', 'h9_mean' : 'H9'}), 51 | 'ag_columns': ['H1', 'H9'], 52 | 'expt_type': 'Combinatorial Mutagenesis for Affinity', 53 | 'palette': bokeh.palettes.Dark2_6, 54 | } 55 | 56 | 57 | g6LC_dict = { 58 | 'ab_name' : 'g6', 59 | 'files' : { 60 | 'Ab-Ag': 'output/ab_mutagenesis_expts/g6/2fjg_vlh_lc_scores.csv', 61 | 'Ab only': 'output/ab_mutagenesis_expts/g6/2fjg_lh_lc_scores.csv', 62 | 'Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/2fjg_l_lc_scores.csv', 63 | 'ESM-1v': 'output/ab_mutagenesis_expts/g6/g6Lc_exp_data_maskMargLabeled.csv', 64 | 'AbLang': 'output/ab_mutagenesis_expts/g6/g6_lc_ablangScores.csv', 65 | 'abYsis': 'output/ab_mutagenesis_expts/g6/abysis_counts_g6_vl.txt', 66 | }, 67 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv'), 68 | 'ag_columns': ['norm_binding'], 69 | 'expt_type': 'Deep Mutational Scan for Binding', 70 | 'palette': bokeh.palettes.Pastel1_6, 71 | 'chain': 'VL' 72 | } 73 | 74 | g6HC_dict = { 75 | 'ab_name' : 'g6', 76 | 'files' : { 77 | 'Ab-Ag': 'output/ab_mutagenesis_expts/g6/2fjg_vlh_hc_scores.csv', 78 | 'Ab only': 'output/ab_mutagenesis_expts/g6/2fjg_lh_hc_scores.csv', 79 | 'Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/2fjg_h_hc_scores.csv', 80 | 'ESM-1v': 'output/ab_mutagenesis_expts/g6/g6Hc_exp_data_maskMargLabeled.csv', 81 | 'AbLang': 'output/ab_mutagenesis_expts/g6/g6_hc_ablangScores.csv', 82 | 'abYsis': 'output/ab_mutagenesis_expts/g6/abysis_counts_g6_vh.txt', 83 | }, 84 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv'), 85 | 'ag_columns': ['norm_binding'], 86 | 'expt_type': 'Deep Mutational Scan for Binding', 87 | 'palette': bokeh.palettes.Pastel1_4, 88 | 'chain': 'VH' 89 | } 90 | def apply_mask_and_average(scores, genotype): 91 | #scores here should already be a in log space 92 | 93 | if '1' in genotype: 94 | masked_list = [] 95 | for s, mask_char in zip(scores, genotype): 96 | if mask_char == '1': 97 | masked_list.append(s) 98 | multi_avg = np.mean(masked_list) 99 | else: 100 | #if genotype is all zeros = wt 101 | multi_avg = 0 102 | 103 | return multi_avg 104 | 105 | 106 | def transform_single_to_multi( dms_df, singleMuts_df, condition, sort_col, ascending ): 107 | multi_scores = [] 108 | #use input parameter as method bc key is kwarg in pd.sort_values 109 | #method should be either esm1v or abysis or AbLang 110 | 111 | #for cr sorts the single mutation dataframe in residue order, ie binary '10000' before '00010'. ascending = False 112 | #for g6 sorts the single mutation dataframe in position order, ie residue index. 
ascending = True 113 | 114 | singleMuts_sorted = singleMuts_df.sort_values(sort_col, key=lambda x: x.astype(int), ascending=ascending) 115 | singles_scores = singleMuts_sorted[condition].to_list() 116 | 117 | for g in dms_df['genotype']: 118 | multi_scores.append(apply_mask_and_average(singles_scores, g )) 119 | 120 | return multi_scores 121 | 122 | def get_abysis_single_scores_df(singleMut_ids, abysis_df): 123 | single_scores = [] 124 | all_pos = [] 125 | 126 | for id in singleMut_ids: 127 | wt, pos, mt = id[0], int(id[1:-1]), id[-1] 128 | likelihood_ratio = abysis_df.loc[(abysis_df['pos'] == pos) & (abysis_df['wt'] == wt) & (abysis_df['mt'] == mt)]['likelihood_ratio'].to_list()[0] 129 | log_like_ratio = np.log10(likelihood_ratio) 130 | single_scores.append(log_like_ratio) 131 | all_pos.append(pos) 132 | 133 | singleMuts_df = pd.DataFrame({'pos': all_pos, 'abYsis': single_scores}) 134 | 135 | 136 | return singleMuts_df 137 | 138 | def get_ablang_single_scores_df(singleMut_ids, ablang_df): 139 | single_scores = [] 140 | all_pos = [] 141 | 142 | for id in singleMut_ids: 143 | wt, pos, mt = id[0], int(id[1:-1]), id[-1] 144 | log_likelihood_ratio = ablang_df.loc[(ablang_df['pos'] == pos) & (ablang_df['wt'] == wt) & (ablang_df['mt'] == mt)]['log_likelihood_ratio'].to_list()[0] 145 | single_scores.append(log_likelihood_ratio) 146 | all_pos.append(pos) 147 | 148 | singleMuts_df = pd.DataFrame({'pos': all_pos, 'AbLang': single_scores}) 149 | 150 | 151 | return singleMuts_df 152 | 153 | 154 | #get melted correlations matrix of structure input x target antigen, for a given antibody 155 | def get_corr(ab_name, files, dms_df, ag_columns, dropLOQ = False ): 156 | 157 | inv_fold_files = {key: value for key, value in files.items() if key not in ['ESM-1v', 'AbLang', 'abYsis']} 158 | other_files = {key: value for key, value in files.items() if key in ['ESM-1v', 'AbLang', 'abYsis']} 159 | 160 | #retreive correlations from inverse fold scores 161 | for key, filepath in inv_fold_files.items(): 162 | df = pd.read_csv(filepath) # Read the CSV file into a DataFrame 163 | column_label = key 164 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column 165 | dms_df[column_label] = log_likelihood_column 166 | 167 | #retreive correlations from abysis and esm1v scores 168 | for key, filepath in other_files.items(): 169 | if key == 'ESM-1v': 170 | singleMuts_df = pd.read_csv(filepath) 171 | # Get the columns for all esm1v models 172 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')] 173 | # Calculate the average for each esm column 174 | esm_avg_values = singleMuts_df[esm_columns].mean(axis=1) 175 | singleMuts_df[key] = esm_avg_values 176 | if ab_name == 'g6': 177 | #g6 data set is only single mutations 178 | dms_df[key] = esm_avg_values 179 | else: 180 | #cr datasets are multiple mutations 181 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'genotype', ascending = False) 182 | elif key == 'abYsis': 183 | abysis_df = pd.read_csv(filepath, sep = '\t', dtype = {'pos': int}) 184 | if ab_name == 'g6': 185 | dms_df[key] = get_abysis_single_scores_df(dms_df['mutant'], abysis_df)[key] 186 | else: 187 | #cr data set are multiple mutations. 
use same scoring strategy as esm1v paper 188 | singleMut_ids = pd.read_csv(files['ESM-1v'])['mutant'].to_list() 189 | singleMuts_df = get_abysis_single_scores_df(singleMut_ids, abysis_df) 190 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'pos', ascending = True) 191 | elif key == 'AbLang': 192 | ablang_df = pd.read_csv(filepath, dtype = {'pos': int}) 193 | if ab_name == 'g6': 194 | dms_df[key] = get_ablang_single_scores_df(dms_df['mutant'], ablang_df)[key] 195 | else: 196 | #cr data set are multiple mutations. use same scoring strategy as esm1v paper 197 | singleMut_ids = pd.read_csv(files['ESM-1v'])['mutant'].to_list() 198 | singleMuts_df = get_ablang_single_scores_df(singleMut_ids, ablang_df) 199 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'pos', ascending = True) 200 | 201 | conditions = list(files.keys()) 202 | 203 | # 204 | if not dropLOQ: 205 | correlations = dms_df[ag_columns + conditions].corr(method='spearman') 206 | correlations = correlations.drop(conditions, axis= 1) 207 | correlations = correlations.drop(ag_columns, axis= 0) 208 | 209 | # Melt the correlations DataFrame and rename the index column 210 | melted_correlations = ( 211 | correlations 212 | .reset_index() 213 | .melt(id_vars='index', var_name='Target', value_name='Correlation') 214 | .rename(columns={'index': 'Input'}) 215 | ) 216 | melted_correlations['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i)) 217 | else 'IF' if 'Ab' in i 218 | else 'MSA' 219 | for i in melted_correlations['Input'] 220 | ] 221 | 222 | return melted_correlations 223 | #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like 224 | else: 225 | all_ag_melted = pd.DataFrame({}) 226 | for ag in ag_columns: 227 | filt_df = dms_df[dms_df[ag] > min(dms_df[ag])] 228 | 229 | correlations = filt_df[[ag] + conditions].corr(method='spearman') 230 | correlations = correlations.drop(conditions, axis= 1) 231 | correlations = correlations.drop(ag, axis= 0) 232 | 233 | # Melt the correlations DataFrame and rename the index column 234 | melted_correlations = ( 235 | correlations 236 | .reset_index() 237 | .melt(id_vars='index', var_name='Target', value_name='Correlation') 238 | ).rename(columns={'index': 'Input'}) 239 | 240 | all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True) 241 | all_ag_melted['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i)) 242 | else 'IF' if 'Ab' in i 243 | else 'MSA' 244 | for i in all_ag_melted['Input'] 245 | ] 246 | 247 | return all_ag_melted 248 | 249 | 250 | #get melted correlations matrix of structure input x target antigen, for a given antibody 251 | def get_g6_corr(g6Hc_dict, g6LC_dict, dropLOQ = False ): 252 | 253 | ab_name = 'g6' 254 | ag_columns = g6LC_dict['ag_columns'] 255 | vh_and_vl_dms_df = pd.DataFrame({}) 256 | 257 | for d in [g6HC_dict, g6LC_dict]: 258 | 259 | files = d['files'] 260 | dms_df = d['dms_df'] 261 | 262 | inv_fold_files = {key: value for key, value in files.items() if key not in ['ESM-1v', 'AbLang', 'abYsis']} 263 | other_files = {key: value for key, value in files.items() if key in ['ESM-1v', 'AbLang', 'abYsis']} 264 | 265 | #retreive correlations from inverse fold scores 266 | for key, filepath in inv_fold_files.items(): 267 | df = pd.read_csv(filepath) # Read the CSV file into a DataFrame 268 | column_label = key 269 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column 270 | 
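            # note: the assignment below relies on each scores CSV having its rows
            # in the same order as dms_df; the column is aligned by row index
            # (effectively positional for freshly loaded CSVs) rather than joined
            # on a variant identifier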
250 | #get melted correlations matrix of structure input x target antigen, for a given antibody
251 | def get_g6_corr(g6HC_dict, g6LC_dict, dropLOQ = False):
252 | 
253 |     ab_name = 'g6'
254 |     ag_columns = g6LC_dict['ag_columns']
255 |     vh_and_vl_dms_df = pd.DataFrame({})
256 | 
257 |     for d in [g6HC_dict, g6LC_dict]:
258 | 
259 |         files = d['files']
260 |         dms_df = d['dms_df']
261 | 
262 |         inv_fold_files = {key: value for key, value in files.items() if key not in ['ESM-1v', 'AbLang', 'abYsis']}
263 |         other_files = {key: value for key, value in files.items() if key in ['ESM-1v', 'AbLang', 'abYsis']}
264 | 
265 |         #retrieve correlations from inverse fold scores
266 |         for key, filepath in inv_fold_files.items():
267 |             df = pd.read_csv(filepath)  # Read the CSV file into a DataFrame
268 |             column_label = key
269 |             log_likelihood_column = df['log_likelihood']  # Extract the 'log_likelihood' column
270 |             dms_df[column_label] = log_likelihood_column
271 | 
272 |         #retrieve correlations from abysis, ablang, and esm1v scores
273 |         for key, filepath in other_files.items():
274 |             if key == 'ESM-1v':
275 |                 singleMuts_df = pd.read_csv(filepath)
276 |                 # Get the columns for all esm1v models
277 |                 esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')]
278 |                 # Calculate the average for each esm column
279 |                 esm_avg_values = singleMuts_df[esm_columns].mean(axis=1)
280 |                 singleMuts_df[key] = esm_avg_values
281 |                 if ab_name == 'g6':
282 |                     #g6 dataset contains only single mutations
283 |                     dms_df[key] = esm_avg_values
284 |             elif key == 'abYsis':
285 |                 abysis_df = pd.read_csv(filepath, sep = '\t', dtype = {'pos': int})
286 |                 if ab_name == 'g6':
287 |                     dms_df[key] = get_abysis_single_scores_df(dms_df['mutant'], abysis_df)[key]
288 |             elif key == 'AbLang':
289 |                 ablang_df = pd.read_csv(filepath, dtype = {'pos': int})
290 |                 if ab_name == 'g6':
291 |                     dms_df[key] = get_ablang_single_scores_df(dms_df['mutant'], ablang_df)[key]
292 | 
293 | 
294 |         vh_and_vl_dms_df = pd.concat([vh_and_vl_dms_df, dms_df], ignore_index= True)
295 | 
296 |     conditions = list(files.keys())
297 | 
298 |     #keep all samples and correlate every condition against the antigen column
299 |     if not dropLOQ:
300 |         correlations = vh_and_vl_dms_df[ag_columns + conditions].corr(method='spearman')
301 |         correlations = correlations.drop(conditions, axis= 1)
302 |         correlations = correlations.drop(ag_columns, axis= 0)
303 | 
304 |         # Melt the correlations DataFrame and rename the index column
305 |         melted_correlations = (
306 |             correlations
307 |             .reset_index()
308 |             .melt(id_vars='index', var_name='Target', value_name='Correlation')
309 |             .rename(columns={'index': 'Input'})
310 |         )
311 |         melted_correlations['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i))
312 |                                          else 'IF' if 'Ab' in i
313 |                                          else 'MSA'
314 |                                          for i in melted_correlations['Input']
315 |                                          ]
316 |         melted_correlations['Target'] = ['VEGF-A'] * len(melted_correlations)
317 | 
318 |         return melted_correlations
319 |     #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like
320 |     else:
321 |         all_ag_melted = pd.DataFrame({})
322 |         for ag in ag_columns:
323 |             filt_df = vh_and_vl_dms_df[vh_and_vl_dms_df[ag] > min(vh_and_vl_dms_df[ag])]
324 | 
325 |             correlations = filt_df[[ag] + conditions].corr(method='spearman')
326 |             correlations = correlations.drop(conditions, axis= 1)
327 |             correlations = correlations.drop(ag, axis= 0)
328 | 
329 |             # Melt the correlations DataFrame and rename the index column
330 |             melted_correlations = (
331 |                 correlations
332 |                 .reset_index()
333 |                 .melt(id_vars='index', var_name='Target', value_name='Correlation')
334 |             ).rename(columns={'index': 'Input'})
335 | 
336 |             all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True)
337 |         all_ag_melted['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i))
338 |                                    else 'IF' if 'Ab' in i
339 |                                    else 'MSA'
340 |                                    for i in all_ag_melted['Input']
341 |                                    ]
342 |         all_ag_melted['Target'] = ['VEGF-A'] * len(all_ag_melted)
343 | 
344 |         return all_ag_melted
345 | 
346 | 
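# plot_hbar below leans on Bokeh's nested categorical axes: each bar's y-coordinate
# is a (Target, Input) tuple, FactorRange(*factors) builds the grouped axis, and
# factor_cmap colors bars by a categorical column. A toy sketch of that pattern
# (figure is illustrative only and never written to disk; data are invented):
def _sketch_nested_hbar():
    toy = pd.DataFrame({'cats': [('H1', 'Ab-Ag'), ('H1', 'ESM-1v'), ('H5', 'Ab-Ag')],
                        'Correlation': [0.62, 0.41, 0.55],
                        'Input': ['Ab-Ag', 'ESM-1v', 'Ab-Ag']})
    fig = bokeh.plotting.figure(y_range=bokeh.models.FactorRange(*toy['cats']))
    fig.hbar(source=bokeh.models.ColumnDataSource(toy), y='cats', right='Correlation',
             height=0.6, line_color='black',
             fill_color=factor_cmap('Input', palette=bokeh.palettes.Spectral6[:2],
                                    factors=['Ab-Ag', 'ESM-1v']))
    return fig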
347 | #plot correlation bars of model scores vs. binding measurements for each antibody benchmark
348 | def plot_hbar(title, melted_correlations, palette = bokeh.palettes.Spectral6, compare = 'model'):
349 | 
350 |     if compare == 'model':
351 |         melted_correlations.replace('Ab-Ag','InverseFolding', inplace = True)
352 |         color_by = 'Input'
353 |     elif compare == 'IF':
354 |         color_by = 'Method'
355 |     else:
356 |         color_by = 'Target'
357 | 
358 | 
359 |     melted_correlations['cats'] = melted_correlations.apply(lambda x: (x["Target"], x["Input"]), axis = 1)
360 |     factors = list(melted_correlations.cats)[::-1]
361 | 
362 |     if 'CR6261' in title:
363 |         x_min = -0.15
364 |     else:
365 |         x_min = 0
366 | 
367 |     p = bokeh.plotting.figure(
368 |         height=340,
369 |         width=440,
370 |         x_axis_label="Spearman Correlation",
371 |         x_range=[x_min, 1],
372 |         y_range=bokeh.models.FactorRange(*factors),
373 |         tools="save",
374 |         title = title
375 |     )
376 | 
377 | 
378 |     p.hbar(
379 |         source=melted_correlations,
380 |         y="cats",
381 |         right="Correlation",
382 |         height=0.6,
383 |         line_color = 'black',
384 |         legend_field = color_by,
385 |         # use the palette to colormap based on the x[1:2] values
386 |         fill_color=factor_cmap(
387 |             color_by,
388 |             palette= palette,
389 |             factors = list(melted_correlations[color_by].unique()),
390 |             start=1,
391 |             end=3)
392 |     )
393 | 
394 |     #labels_df = melted_correlations[melted_correlations['Input'].isin(['Ab-Ag','ESM-1v', 'abYsis', 'InverseFold'])]
395 |     labels_df = melted_correlations
396 |     labels_df['corr_str'] = labels_df['Correlation'].apply(lambda x: round(x, 2)).astype(str)
397 |     labels_source = bokeh.models.ColumnDataSource(labels_df)
398 | 
399 |     labels = bokeh.models.LabelSet(x='Correlation', y='cats', text='corr_str', text_font_size = "10px",
400 |                                    x_offset=12, y_offset=-5, source=labels_source, render_mode='canvas')
401 | 
402 |     p.ygrid.grid_line_color = None
403 |     p.y_range.range_padding = 0.1
404 |     p.add_layout(labels)
405 | 
406 |     if compare == 'IF':
407 |         p.legend.orientation = "vertical"
408 |         p.legend.location = "bottom_right"
409 |         p.legend.label_text_font_size = "8pt"
410 |     else:
411 |         p.legend.visible = False
412 | 
413 |     p.output_backend = "svg"
414 |     return p
415 | 
416 | def plot_g6_scatter(g6Hc_df, g6Lc_df, ag_columns, dropLOQ):
417 |     ag = ag_columns[0]
418 |     if dropLOQ:
419 |         g6Lc_df = g6Lc_df[g6Lc_df[ag] > np.min(g6Lc_df[ag])]
420 |         g6Hc_df = g6Hc_df[g6Hc_df[ag] > np.min(g6Hc_df[ag])]
421 | 
422 | 
423 |     hv.extension("bokeh")
424 | 
425 |     plots = []
426 |     for (plot_df, title) in [(g6Hc_df, 'g6.31 VH'), (g6Lc_df, 'g6.31 VL')]:
427 | 
428 |         # Generate HoloViews Points Element
429 |         points = hv.Points(
430 |             data=plot_df,
431 |             kdims=[ag, 'Ab-Ag'],
432 |         )
433 | 
434 |         # Datashade with spreading of points
435 |         p = hv.operation.datashader.dynspread(
436 |             hv.operation.datashader.rasterize(
437 |                 points
438 |             ).opts(
439 |                 cmap='Magma',
440 |                 cnorm='linear',
441 |             )
442 |         ).opts(title = title,
443 |             frame_width=350,
444 |             frame_height=300,
445 |             padding=0.05,
446 |             show_grid=False,
447 |             colorbar = True
448 |         )
449 | 
450 |         p = hv.render(p)
451 |         p.output_backend = "svg"
452 | 
453 |         plots.append(p)
454 | 
455 |     return plots
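# The rasterize/dynspread pipeline above is the same rendering pattern reused by the
# CR scatter plots below: rasterize aggregates points into a density image so large
# DMS libraries stay legible, and dynspread re-expands isolated pixels so sparse
# regions remain visible. A toy sketch with random data (illustrative only; assumes
# the holoviews/datashader imports already used in this module):
def _sketch_datashaded_scatter(n=10000):
    rng = np.random.default_rng(0)
    toy = pd.DataFrame({'x': rng.normal(size=n), 'y': rng.normal(size=n)})
    points = hv.Points(data=toy, kdims=['x', 'y'])
    shaded = hv.operation.datashader.dynspread(
        hv.operation.datashader.rasterize(points).opts(cmap='Magma', cnorm='linear')
    ).opts(frame_width=350, frame_height=300, colorbar=True)
    return hv.render(shaded)  # convert back to a Bokeh figure for styling/output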
456 | 
457 | #plot scatter plot colored by n_muts for CR antibodies
458 | def single_cr_scatterMut(plot_df, ag_columns, x, y, title):
459 |     x_axis_label = x + '-logKd' if x in ag_columns else 'log likelihood'
460 |     y_axis_label = y + '-logKd' if y in ag_columns else 'log likelihood'
461 | 
462 | 
463 |     #set up figure
464 |     p = bokeh.plotting.figure(
465 |         width=500,
466 |         height=400,
467 |         x_axis_label= x_axis_label,
468 |         y_axis_label= y_axis_label,
469 |         title = title
470 |     )
471 | 
472 |     #Set up color mapper to color by number of mutations from germline
473 |     mapper = bokeh.transform.linear_cmap(field_name='som_mut', palette=bokeh.palettes.inferno(16), low=min(plot_df.som_mut), high=max(plot_df.som_mut))
474 |     p.circle(
475 |         source=plot_df,
476 |         x= x,
477 |         y= y,
478 |         alpha = 0.2,
479 |         line_color=mapper,
480 |         color=mapper,
481 |         size = 4
482 |     )
483 | 
484 |     #mark the germline and mature endpoints of the affinity-maturation trajectory
485 |     p.diamond(
486 |         #germline genotype is all 0's
487 |         source = plot_df[plot_df['genotype'].str.contains('1') == False],
488 |         x= x,
489 |         y= y,
490 |         color = 'dodgerblue',
491 |         size = 12,
492 |         legend_label = 'germline'
493 |     )
494 |     p.star(
495 |         #mature genotype is all 1's
496 |         source = plot_df[plot_df['genotype'].str.contains('0') == False],
497 |         x= x,
498 |         y= y,
499 |         color = 'green',
500 |         size = 12,
501 |         legend_label = 'mature'
502 |     )
503 |     p.legend.location = "top_left"
504 |     p.output_backend = "svg"
505 | 
506 |     return p, mapper
507 | 
508 | #plot single scatter plot for CR antibodies - with dynamic spreading and datashading
509 | def single_cr_scatter(plot_df, ag_columns, x, y, title):
510 |     x_axis_label = x + '-logKd' if x in ag_columns else 'log likelihood'
511 |     y_axis_label = y + '-logKd' if y in ag_columns else 'log likelihood'
512 | 
513 |     hv.extension("bokeh")
514 | 
515 |     # Generate HoloViews Points Element
516 |     points = hv.Points(
517 |         data=plot_df,
518 |         kdims=[x, y],
519 |     )
520 | 
521 |     # Datashade with spreading of points
522 |     p = hv.operation.datashader.dynspread(
523 |         hv.operation.datashader.rasterize(
524 |             points
525 |         ).opts(
526 |             cmap='Magma',
527 |             cnorm='linear',
528 |         )
529 |     ).opts(xlabel = x_axis_label, ylabel = y_axis_label,
530 |         frame_width=350,
531 |         frame_height=300,
532 |         padding=0.05,
533 |         show_grid=False,
534 |         colorbar = True
535 |     )
536 | 
537 |     p = hv.render(p)
538 | 
539 |     p.grid.visible = False
540 |     #overlay the germline and mature endpoints on the datashaded point cloud
541 |     p.diamond(
542 |         #germline genotype is all 0's
543 |         source = plot_df[plot_df['genotype'].str.contains('1') == False],
544 |         x= x,
545 |         y= y,
546 |         color = 'dodgerblue',
547 |         size = 12,
548 |         #legend_label = 'germline'
549 |     )
550 |     p.star(
551 |         #mature genotype is all 1's
552 |         source = plot_df[plot_df['genotype'].str.contains('0') == False],
553 |         x= x,
554 |         y= y,
555 |         color = 'green',
556 |         size = 12,
557 |         #legend_label = 'mature'
558 |     )
559 |     #p.legend.location = "top_left"
560 |     p.output_backend = "svg"
561 | 
562 |     return p
563 | 
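# The genotype bitstrings filtered above encode each variant's somatic-mutation
# pattern: position i is '1' if the mature residue is present and '0' if germline, so
# the all-'0' string is the inferred germline and the all-'1' string is the mature
# antibody. A toy sketch of those endpoint filters (hypothetical data; column names
# match dms_df):
def _sketch_endpoint_masks():
    toy = pd.DataFrame({'genotype': ['0000', '0101', '1111'], 'som_mut': [0, 2, 4]})
    germline = toy[toy['genotype'].str.contains('1') == False]  # no '1' bits anywhere
    mature = toy[toy['genotype'].str.contains('0') == False]    # no '0' bits anywhere
    return germline, mature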
564 | def plot_CR_scatter_colorMuts(title, files, dms_df, ag_columns, dropLOQ = False):
565 | 
566 |     agAb_files = {key: value for key, value in files.items() if key in ['Ab-Ag']}
567 | 
568 |     for key, filepath in agAb_files.items():
569 |         df = pd.read_csv(filepath)  # Read the CSV file into a DataFrame
570 |         column_label = key
571 |         log_likelihood_column = df['log_likelihood']  # Extract the 'log_likelihood' column
572 |         dms_df[column_label] = log_likelihood_column
573 | 
574 |     plots = []
575 |     #predictions vs experiment
576 |     for ag in ag_columns:
577 | 
578 |         if dropLOQ:
579 |             plot_df = dms_df[dms_df[ag] > min(dms_df[ag])]
580 |             #add back wt which is on LOQ for visualization
581 |             plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
582 |         else:
583 |             plot_df = dms_df
584 | 
585 |         p, _ = single_cr_scatterMut(plot_df, ag_columns, x = ag, y = 'Ab-Ag', title = f'{title} against {ag}')
586 |         plots.append(p)
587 | 
588 |     #experimental vs experimental
589 |     plot_df = dms_df
590 |     if dropLOQ:
591 |         for ag in ag_columns:
592 |             plot_df = plot_df[plot_df[ag] > min(plot_df[ag])]
593 |         #add back wt which is on LOQ for visualization
594 |         plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
595 | 
596 |     p_exp, mapper = single_cr_scatterMut(plot_df, ag_columns, x= ag_columns[0], y = ag_columns[1], title = 'Experimental Cross-Reactive Binding Landscape')
597 |     p_exp.legend.location = "top_left"
598 |     color_bar = bokeh.models.ColorBar(color_mapper=mapper['transform'], width=8, title = 'Amino Acid Mutations')
599 |     #p_exp.add_layout(color_bar, 'right')
600 | 
601 |     plots.append(p_exp)
602 |     p_exp.output_backend = "svg"
603 | 
604 |     return plots
605 | 
606 | def plot_CR_scatterplots(title, files, dms_df, ag_columns, dropLOQ = False):
607 | 
608 |     agAb_files = {key: value for key, value in files.items() if key in ['Ab-Ag']}
609 | 
610 |     for key, filepath in agAb_files.items():
611 |         df = pd.read_csv(filepath)  # Read the CSV file into a DataFrame
612 |         column_label = key
613 |         log_likelihood_column = df['log_likelihood']  # Extract the 'log_likelihood' column
614 |         dms_df[column_label] = log_likelihood_column
615 | 
616 |     plots = []
617 |     #predictions vs experiment
618 |     for ag in ag_columns:
619 | 
620 |         if dropLOQ:
621 |             plot_df = dms_df[dms_df[ag] > min(dms_df[ag])]
622 |             #add back wt which is on LOQ for visualization
623 |             plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
624 |         else:
625 |             plot_df = dms_df
626 | 
627 |         p = single_cr_scatter(plot_df, ag_columns, x = ag, y = 'Ab-Ag', title = f'{title} against {ag}')
628 |         plots.append(p)
629 | 
630 |     #experimental vs experimental
631 |     plot_df = dms_df
632 |     if dropLOQ:
633 |         for ag in ag_columns:
634 |             plot_df = plot_df[plot_df[ag] > min(plot_df[ag])]
635 |         #add back wt which is on LOQ for visualization
636 |         plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
637 | 
638 |     p_exp = single_cr_scatter(plot_df, ag_columns, x= ag_columns[0], y = ag_columns[1], title = 'Experimental Cross-Reactive Binding Landscape')
639 | 
640 |     plots.append(p_exp)
641 |     p_exp.output_backend = "svg"
642 | 
643 |     return plots
644 | 
645 | #plot comparison of correlations computed with vs. without samples at the lower limit of quantitation, for cr antibodies
646 | def plot_LOQ_comparison(ab_name, unfilt_corrs, filt_corrs, palette = bokeh.palettes.Spectral6):
647 |     unfilt_corrs['LOQ'] = ['Included'] * len(unfilt_corrs)
648 |     filt_corrs['LOQ'] = ['Excluded'] * len(filt_corrs)
649 |     combined_melt = pd.concat([unfilt_corrs, filt_corrs], ignore_index= True)
650 | 
651 |     combined_melt['cats'] = combined_melt.apply(lambda x: (x["Target"], x["Input"], x["LOQ"]), axis = 1)
652 |     factors = list(combined_melt.cats)[::-1]
653 | 
654 |     p = bokeh.plotting.figure(
655 |         height=850,
656 |         width=400,
657 |         x_axis_label="Spearman Correlation",
658 |         x_range=[-.25, 1],
659 |         y_range=bokeh.models.FactorRange(*factors),
660 |         tools="save",
661 |         title = f'{ab_name}, Impact of Ignoring Lower LOQ'
662 |     )
663 | 
664 |     p.hbar(
665 |         source=combined_melt,
666 |         y="cats",
667 |         right="Correlation",
668 |         height=0.7,
669 |         line_color = 'black',
670 |         legend_field = 'LOQ',
671 |         # use the palette to colormap based on the LOQ status
672 |         fill_color=factor_cmap(
673 |             'LOQ',
674 |             palette= palette,
675 |             factors = list(combined_melt['LOQ'].unique()),
676 |             start=1,
677 |             end=2)
678 |     )
679 | 
680 |     p.ygrid.grid_line_color = None
681 |     p.y_range.range_padding = 0.1
682 |     p.legend.location = "bottom_right"
683 |     p.legend.orientation = "vertical"
684 |     p.yaxis.axis_label = "Model Input"
685 |     p.output_backend = "svg"
686 | 
687 |     return p
688 | 
689 | 
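# Samples at the assay's lower limit of quantitation (LOQ) are censored, so the
# dropLOQ branches above recompute correlations after excluding every variant whose
# measured -logKd sits exactly at the minimum observed value, then add wild type
# (som_mut == 0, which itself sits on the LOQ) back for visualization. A minimal,
# hypothetical sketch of that filter:
def _sketch_drop_loq(dms_df, ag):
    filt = dms_df[dms_df[ag] > min(dms_df[ag])]   # discard floor-censored points
    wt = dms_df[dms_df['som_mut'] == 0]           # wild type rests on the LOQ
    return pd.concat([filt, wt], ignore_index=True, axis=0)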
690 | 
691 | if __name__ == '__main__':
692 | 
693 |     datasets = [cr9114_dict, cr6261_dict, (g6HC_dict, g6LC_dict)]
694 |     results = []
695 | 
696 |     #compute with and without points on lower limit of quantitation ignored
697 |     for dropLOQ in [False, True]:
698 | 
699 |         g6_corrs = []
700 |         all_corr_plots = []
701 | 
702 |         for d in datasets:
703 |             if type(d) is tuple:
704 |                 g6Hc, g6Lc = d
705 |                 title = g6Hc['ab_name'] + ', ' + g6Hc['expt_type']
706 |                 g6_combined = get_g6_corr(g6Hc, g6Lc, dropLOQ= dropLOQ)
707 | 
708 |                 all_corr_plots.append(plot_hbar(title, g6_combined, g6Lc['palette'], 'IF'))
709 | 
710 |                 #also plot scatter plots
711 |                 scatter_fname = f"output/ab_mutagenesis_expts/g6_scatterplot_colorMuts{'_dropLOQ' if dropLOQ else ''}.html"
712 |                 bokeh.plotting.output_file(scatter_fname)
713 |                 g6_scatterMuts_plots = plot_g6_scatter(g6Hc['dms_df'], g6Lc['dms_df'], g6Hc['ag_columns'], dropLOQ= dropLOQ)
714 |                 bokeh.io.show(bokeh.layouts.gridplot(g6_scatterMuts_plots, ncols = 2))
715 | 
716 |                 if (dropLOQ == False):
717 |                     results.append(g6_combined)
718 | 
719 |             else:
720 |                 corr_df = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
721 |                 title = d['ab_name'] + ', ' + d['expt_type']
722 | 
723 |                 all_corr_plots.append(plot_hbar(title, corr_df, d['palette'], 'IF'))
724 | 
725 |                 #also plot scatter plots for CR9114, CR6261
726 |                 scatter_fname = f"output/ab_mutagenesis_expts/{d['ab_name']}_scatter_colorMuts{'_dropLOQ' if dropLOQ else ''}.html"
727 |                 bokeh.plotting.output_file(scatter_fname)
728 |                 cr_scatterMuts_plots = plot_CR_scatter_colorMuts(title, d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
729 |                 bokeh.io.show(bokeh.layouts.gridplot(cr_scatterMuts_plots, ncols = 3))
730 | 
731 |                 #also scatter plots with datashading/dynamic spreading for CR9114, CR6261
732 |                 scatter_fname = f"output/ab_mutagenesis_expts/{d['ab_name']}_scatter{'_dropLOQ' if dropLOQ else ''}.html"
733 |                 bokeh.plotting.output_file(scatter_fname)
734 |                 cr_scatter_plots = plot_CR_scatterplots(title, d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
735 |                 bokeh.io.show(bokeh.layouts.gridplot(cr_scatter_plots, ncols = 3))
736 | 
737 |                 if (dropLOQ == False):
738 |                     results.append(corr_df)
739 | 
740 |         all_corr_fname = f"output/ab_mutagenesis_expts/benchmarks{'_dropLOQ' if dropLOQ else ''}.html"
741 |         bokeh.plotting.output_file(all_corr_fname)
742 |         bokeh.io.show(bokeh.layouts.gridplot(all_corr_plots, ncols = len(all_corr_plots)))
743 | 
744 |     compareLOQ_fname = 'output/ab_mutagenesis_expts/LOQ_comparison.html'
745 |     bokeh.plotting.output_file(compareLOQ_fname)
746 |     compareLOQ_plots = []
747 |     for d in datasets:
748 |         if (type(d) != tuple) and (d['ab_name'].startswith('CR')):
749 |             corr_withLOQ = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'])
750 |             corr_withoutLOQ = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'], dropLOQ= True)
751 | 
752 |             compareLOQ_plots.append(plot_LOQ_comparison(d['ab_name'], corr_withLOQ, corr_withoutLOQ, d['palette']))
753 |     bokeh.io.show(bokeh.layouts.gridplot(compareLOQ_plots, ncols = len(compareLOQ_plots)))
754 | 
755 |     pd.concat(results, ignore_index= True).to_csv('output/ab_mutagenesis_expts/results.csv', index = False)
756 | 
--------------------------------------------------------------------------------