├── .gitignore
├── structural-evolution-overview.png
├── LICENSE
├── bin
│   ├── mpnn_ab_benchmarking.sh
│   ├── if_ab_benchmarking.sh
│   ├── generate_dms.py
│   ├── dms_utils.py
│   ├── parse_abysis.py
│   ├── eval_ablang.py
│   ├── recommend.py
│   ├── score_log_likelihoods.py
│   ├── esm1v_ab_benchmarking.sh
│   ├── multichain_util.py
│   ├── plot_mpnn_benchmarks.py
│   ├── util.py
│   ├── dms_enrichment.py
│   ├── plot_esm1v_benchmarks.py
│   └── plot_ab-binding_benchmarks.py
├── README.md
└── environment.yml
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | output
3 | *.pyc
4 | *.png
5 | *.log
6 | .tar
--------------------------------------------------------------------------------
/structural-evolution-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/varun-shanker/structural-evolution/HEAD/structural-evolution-overview.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 varun-shanker
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/bin/mpnn_ab_benchmarking.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | run_protein_mpnn() {
4 | python ../ProteinMPNN/protein_mpnn_run.py \
5 | --path_to_fasta "$1" \
6 | --score_only 1 \
7 | --save_score 1 \
8 | --pdb_path_chains "$2" \
9 | --out_folder "$3" \
10 | --pdb_path "$4"
11 | }
12 |
13 | # G6 Light Chain
14 | run_protein_mpnn \
15 | "data/ab_mutagenesis_expts/g6/g6_2fjg_lc_lib.fasta" \
16 | "L" \
17 | "output/ab_mutagenesis_expts/g6/proteinMpnnLC" \
18 | "data/ab_mutagenesis_expts/g6/2fjg_vlh_fvar.pdb"
19 |
20 | # G6 Heavy Chain
21 | run_protein_mpnn \
22 | "data/ab_mutagenesis_expts/g6/g6_2fjg_hc_lib.fasta" \
23 | "H" \
24 | "output/ab_mutagenesis_expts/g6/proteinMpnnHC" \
25 | "data/ab_mutagenesis_expts/g6/2fjg_vlh_fvar.pdb"
26 |
27 | # CR6261 Heavy Chain
28 | run_protein_mpnn \
29 | "data/ab_mutagenesis_expts/cr6261/cr6261_3gbn_hc_lib.fasta" \
30 | "H" \
31 | "output/ab_mutagenesis_expts/cr6261/mpnn" \
32 | "data/ab_mutagenesis_expts/cr6261/3gbn_ablh_fvar.pdb"
33 |
34 | # CR9114 Heavy Chain
35 | run_protein_mpnn \
36 | "data/ab_mutagenesis_expts/cr9114/cr9114_4fqi_hc_lib.fasta" \
37 | "H" \
38 | "output/ab_mutagenesis_expts/cr9114/mpnn" \
39 | "data/ab_mutagenesis_expts/cr9114/4fqi_ablh_fvar.pdb"
--------------------------------------------------------------------------------
/bin/if_ab_benchmarking.sh:
--------------------------------------------------------------------------------
1 | abs=("cr6261" "cr9114" "g6" "g6")
2 | ab_fastas=("cr6261_3gbn_hc_lib.fasta" "cr9114_4fqi_hc_lib.fasta" "g6_2fjg_hc_lib.fasta" "g6_2fjg_lc_lib.fasta")
3 | data_path="data/ab_mutagenesis_expts/"
4 | out_prefix="output/ab_mutagenesis_expts/"
5 |
6 | for ((i=0; i<${#abs[@]}; i++)); do
7 | ab="${abs[i]}"
8 | ab_fasta="${ab_fastas[i]}"
9 | ab_dir_path="${data_path}${ab}/"
10 | struc_list=("${ab_dir_path}"*.pdb)
11 | ab_out_dir="${out_prefix}${ab}/"
12 |
13 | # Set the default chain value
14 | chain="H"
15 | # Special handling for 'g6' antibody
16 | if [[ "$ab" == "g6" ]]; then
17 | [[ "$ab_fasta" == *"lc"* ]] && chain="L"
18 | fi
19 |
20 |     # exclude the structure containing only the other chain, so it is not scored with this chain's library
21 | if [[ "$ab" == "g6" && "$chain" == "L" ]]; then
22 | struc_list=($(echo "${struc_list[@]}" | tr ' ' '\n' | grep -v '_h_' | tr '\n' ' '))
23 | elif [[ "$ab" == "g6" && "$chain" == "H" ]]; then
24 | struc_list=($(echo "${struc_list[@]}" | tr ' ' '\n' | grep -v '_l_' | tr '\n' ' '))
25 | fi
26 |
27 | mkdir -p "$ab_out_dir"
28 |
29 | for struc in "${struc_list[@]}"; do
30 | out_file="${ab_out_dir}${struc##*/}"
31 | if [[ "$ab" == "g6" ]]; then
32 | chain_modeled="$([ "$chain" == "H" ] && echo "hc" || echo "lc")"
33 | out_file="${out_file%_fvar.pdb}_${chain_modeled}_scores.csv"
34 | else
35 | out_file="${out_file%_fvar.pdb}_scores.csv"
36 | fi
37 |
38 | if [[ ! -f "$out_file" ]]; then
39 | python bin/score_log_likelihoods.py "$struc" --chain "$chain" --seqpath "${ab_dir_path}${ab_fasta}" --outpath "$out_file"
40 | else
41 | echo "$out_file already exists. Skipping..."
42 | fi
43 | done
44 | done
--------------------------------------------------------------------------------
/bin/generate_dms.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dms_utils import deep_mutational_scan
3 | from pathlib import Path
4 | import numpy as np
5 |
6 | import esm
7 | from util import load_structure, extract_coords_from_structure
8 | import biotite.structure
9 | from collections import defaultdict
10 |
11 |
12 | def get_native_seq(pdbfile, chain):
13 | structure = load_structure(pdbfile, chain)
14 | _ , native_seq = extract_coords_from_structure(structure)
15 | return native_seq
16 |
17 | def write_dms_lib(args):
18 | '''Writes a deep mutational scanning library, including the native/wildtype (wt) of the
19 | indicated target chain in the structure to an output Fasta file'''
20 |
21 | sequence = get_native_seq(args.pdbfile, args.chain)
22 |     Path(args.dmspath).parent.mkdir(parents=True, exist_ok=True)
23 | with open(args.dmspath, 'w') as f:
24 | f.write('>wt\n')
25 | f.write(sequence+'\n')
26 | for pos, wt, mt in deep_mutational_scan(sequence):
27 | assert(sequence[pos] == wt)
28 | mut_seq = sequence[:pos] + mt + sequence[(pos + 1):]
29 |
30 | f.write('>' + str(wt) + str(pos+1) + str(mt) + '\n')
31 | f.write(mut_seq + '\n')
32 |
33 |
34 | def main():
35 | parser = argparse.ArgumentParser(
36 | description='Create a DMS library based on target chain in the structure.'
37 | )
38 | parser.add_argument(
39 | 'pdbfile', type=str,
40 | help='input filepath, either .pdb or .cif',
41 | )
42 | parser.add_argument(
43 | '--dmspath', type=str,
44 | help='output filepath for dms library',
45 | )
46 | parser.add_argument(
47 | '--chain', type=str,
48 | help='chain id for the chain of interest', default='A',
49 | )
50 |
51 | args = parser.parse_args()
52 |
53 | if args.dmspath is None:
54 | args.dmspath = f'predictions/{args.pdbfile[:-4]}-{args.chain}_dms.fasta'
55 |
56 | write_dms_lib(args)
57 |
58 | if __name__ == '__main__':
59 |
60 | main()
61 |
--------------------------------------------------------------------------------
/bin/dms_utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import datetime
3 | from dateutil.parser import parse as dparse
4 | import errno
5 | import numpy as np
6 | import os
7 | import random
8 | import sys
9 | import time
10 | import warnings
11 | from Bio import pairwise2
12 | from Bio import BiopythonWarning
13 | warnings.simplefilter('ignore', BiopythonWarning)
14 | from Bio import Seq, SeqIO
15 |
16 |
17 | np.random.seed(1)
18 | random.seed(1)
19 |
20 | AAs = [
21 | 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
22 | 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
23 | ]
24 |
25 | def tprint(string):
26 | string = str(string)
27 | sys.stdout.write(str(datetime.datetime.now()) + ' | ')
28 | sys.stdout.write(string + '\n')
29 | sys.stdout.flush()
30 |
31 | def mkdir_p(path):
32 | try:
33 | os.makedirs(path)
34 | except OSError as exc: # Python >2.5
35 | if exc.errno == errno.EEXIST and os.path.isdir(path):
36 | pass
37 | else:
38 | raise
39 |
40 | def deep_mutational_scan(sequence, exclude_noop=True):
41 | for pos, wt in enumerate(sequence):
42 | for mt in AAs:
43 | if exclude_noop and wt == mt:
44 | continue
45 | yield (pos, wt, mt)
46 |
47 |
48 | def make_mutations(seq, mutations):
49 | mut_seq = [ char for char in seq ]
50 | for mutation in mutations:
51 | wt, pos, mt = mutation[0], int(mutation[1:-1]) - 1, mutation[-1]
52 | assert(seq[pos] == wt)
53 | mut_seq[pos] = mt
54 | mut_seq = ''.join(mut_seq).replace('-', '')
55 | return mut_seq
56 |
57 | def find_mutations(seq1, seq2):
58 | alignment = pairwise2.align.globalms(
59 | seq1, seq2, 5, -4, -6, -.1, one_alignment_only=True,
60 | )[0]
61 |
62 | mutation_set = []
63 | pos1, pos2, pos_map = 0, 0, {}
64 | for wt, mt in zip(alignment[0], alignment[1]):
65 | if wt != '-':
66 | pos1 += 1
67 | if mt != '-':
68 | pos2 += 1
69 | if wt != mt and wt != '-' and mt != '-':
70 | mut_str = f'{wt}{pos1}{mt}'
71 | mutation_set.append(mut_str)
72 |
73 | return mutation_set
--------------------------------------------------------------------------------
/bin/parse_abysis.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import os
4 |
5 | AAs = set('ACDEFGHIKLMNPQRSTVWY')
6 |
7 | seqs_abs = {
8 | 'cr6261_vh': 'EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSS',
9 | 'cr9114_vh': 'QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSS',
10 | 'g6_vh': 'EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTV',
11 | 'g6_vl': 'DIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK',
12 | 'beb_vh': 'QITLKESGPTLVKPTQTLTLTCTFSGFSLSISGVGVGWLRQPPGKALEWLALIYWDDDKRYSPSLKSRLTISKDTSKNQVVLKMTNIDPVDTATYYCAHHSISTIFDHWGQGTLVTVSS',
13 | 'beb_vl' : 'QSALTQPASVSGSPGQSITISCTATSSDVGDYNYVSWYQQHPGKAPKLMIFEVSDRPSGISNRFSGSKSGNTASLTISGLQAEDEADYYCSSYTTSSAVFGGGTKLTVL',
14 | 'sa58_vh': 'QVQLAQSGSELRKPGASVKVSCDTSGHSFTSNAIHWVRQAPGQGLEWMGWINTDTGTPTYAQGFTGRFVFSLDTSARTAYLQISSLKADDTAVFYCARERDYSDYFFDYWGQGTLVTVSS',
15 | 'sa58_vl': 'EVVMTQSPASLSVSPGERATLSCRARASLGISTDLAWYQQRPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDSAVYYCQQYSNWPLTFGGGTKVEIK',
16 | }
17 |
18 |
19 | def parse_abysis(seq, seq_name):
20 | fname = f'data/abysis/abysis_counts_{seq_name}.json'
21 | with open(fname) as f:
22 | json_data = json.load(f)
23 |
24 | data = []
25 | for idx, freq_table in enumerate(json_data['frequencies']):
26 | wt = seq[idx]
27 | total = freq_table['total']
28 | for entry in freq_table['counts']:
29 | if entry['aa'] == seq[idx]:
30 | wt_frac = entry['c'] / total
31 | aa_to_freq, seen = {}, set()
32 | for entry in freq_table['counts']:
33 | mt = entry['aa']
34 | seen.add(mt)
35 | counts = entry['c']
36 | frac = entry['c'] / total
37 | ratio = frac / wt_frac
38 | data.append([ idx + 1, wt, mt, counts, frac, ratio ])
39 | for mt in AAs - seen:
40 | data.append([ idx + 1, wt, mt, 0, 0., 0. ])
41 |
42 | df = pd.DataFrame(data, columns=[
43 | 'pos',
44 | 'wt',
45 | 'mt',
46 | 'counts',
47 | 'fraction',
48 | 'likelihood_ratio',
49 | ])
50 |
51 | if any(prefix in seq_name for prefix in ['cr6261', 'cr9114', 'g6']):
52 | subdirectory = seq_name.split('_')[0]
53 | else:
54 | subdirectory = ''
55 |
56 | output_dir = os.path.join('output', 'ab_mutagenesis_expts', subdirectory)
57 | os.makedirs(output_dir, exist_ok=True)
58 |
59 | output_path = os.path.join(output_dir, f'abysis_counts_{seq_name}.txt')
60 | df.to_csv(output_path, sep='\t')
61 |
62 | if __name__ == '__main__':
63 | for seq_name in seqs_abs:
64 | seq = seqs_abs[seq_name]
65 | parse_abysis(seq, seq_name)
--------------------------------------------------------------------------------
/bin/eval_ablang.py:
--------------------------------------------------------------------------------
1 | import ablang
2 | import numpy as np
3 | import pandas as pd
4 | import scipy.special
5 |
6 | heavy_ablang = ablang.pretrained("heavy")
7 | heavy_ablang.freeze()
8 |
9 | light_ablang = ablang.pretrained("light")
10 | light_ablang.freeze()
11 |
12 | ab_dict = dict(
13 | cr6261= ('EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSS',
14 | 'QSVLTQPPSVSAAPGQKVTISCSGSSSNIGNDYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEANYYCATWDRRPTAYVVFGGGTKLTVL'),
15 | cr9114= ('QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSS',
16 | 'QSALTQPPAVSGTPGQRVTISCSGSDSNIGRRSVNWYQQFPGTAPKLLIYSNDQRPSVVPDRFSGSKSGTSASLAISGLQSEDEAEYYCAAWDDSLKGAVFGGGTQLTVL'),
17 | g6 = ('EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTV',
18 | 'DIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK')
19 |
20 | )
21 |
22 | def eval_ablang(s, ab, chain):
23 | if chain == 'hc':
24 | fname = 'output/ab_mutagenesis_expts/'+ab+'/' + ab +'_hc_ablangScores.csv'
25 | log_likelihoods = heavy_ablang(s, mode = 'likelihood')[0][1:-1]
26 | alphabet = heavy_ablang.tokenizer.vocab_to_aa
27 | elif chain == 'lc':
28 | fname = 'output/ab_mutagenesis_expts/'+ab+'/' + ab +'_lc_ablangScores.csv'
29 | log_likelihoods = light_ablang(s, mode = 'likelihood')[0][1:-1]
30 | alphabet = light_ablang.tokenizer.vocab_to_aa
31 |
32 |     assert log_likelihoods.shape[0] == len(s)
33 |
34 | filt_alphabet = {key: value for key, value in alphabet.items() if value.isalpha()}
35 | log_likelihood_ratio = []
36 | for i,res_log_likelihoods in enumerate(log_likelihoods):
37 | wt_res = s[i]
38 | wt_index = list(filt_alphabet.values()).index(wt_res)
39 | wt_log_likelihood = res_log_likelihoods[wt_index]
40 | log_likelihood_ratio.extend(res_log_likelihoods-wt_log_likelihood)
41 |
42 | res_order = [alphabet[key] for key in range(1, 21)] #extract order of residues in likelihood
43 | mt = res_order * len(s)
44 | wt = [char for char in s for i in range(len(res_order))]
45 | pos = [i+1 for i in range(len(s)) for j in range(len(res_order))]
46 | data = {
47 | 'pos': pos,
48 | 'wt': wt,
49 | 'mt': mt,
50 | 'log_likelihood' : log_likelihoods.flatten(),
51 | 'log_likelihood_ratio' : log_likelihood_ratio,
52 | }
53 | df = pd.DataFrame(data)
54 | df.to_csv(fname, index = False)
55 |
56 |
57 | def main():
58 |
59 | for ab in ab_dict:
60 | for s,chain in zip(ab_dict[ab], ('hc','lc')):
61 | eval_ablang(s, ab, chain)
62 |
63 |
64 | if __name__ == '__main__':
65 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # structural-evolution
2 |
3 | This repository contains scripts for running the analysis described in the paper ["Unsupervised evolution of protein and antibody complexes with a structure-informed language model"](https://www.science.org/stoken/author-tokens/ST-1968/full).
4 |
5 |
6 |
7 |
8 |
9 | ## Setup/Installation
10 | 1. Clone this repository
11 | ```
12 | git clone https://github.com/varun-shanker/structural-evolution.git
13 | ```
14 | 2. Navigate into the repository
15 | ```
16 | cd structural-evolution
17 | ```
18 | 3. Install and activate the conda environment with required dependencies
19 | ```
20 | conda env create -f environment.yml
21 | conda activate struct-evo
22 | ```
23 | 4. Download and unzip the model weights from [here](https://zenodo.org/records/12631662), placing them in the torch hub checkpoints directory.
24 | ```
25 | wget -P ~/.cache/torch/hub/checkpoints https://zenodo.org/records/12631662/files/esm_if1_20220410.zip
26 | unzip ~/.cache/torch/hub/checkpoints/esm_if1_20220410.zip -d ~/.cache/torch/hub/checkpoints
27 | ```
28 |
29 | ## Generating Predictions
30 |
31 | To evaluate this model on a new protein or protein complex structure, run
32 | ```bash
33 | python bin/recommend.py [pdb/cif file] --chain [X]
34 | ```
35 | where `[pdb/cif file]` is the file path to the structure file of the protein or protein complex and `[X]` is the id of the chain you wish to evolve. By default, the script outputs the top `n`=10 predicted substitutions at unique residue positions (`maxrep=1`); both `n` and `maxrep` can be modified using the arguments below.
36 |
37 | To recommend mutations to antibody variable domain sequences, simply run the above script separately on the heavy and light chains, for example:
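(A minimal sketch, assuming a hypothetical structure file `my_ab.pdb` in which the heavy and light chains have chain ids H and L; substitute your own file and chain ids.)

```bash
python bin/recommend.py my_ab.pdb --chain H
python bin/recommend.py my_ab.pdb --chain L
```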
38 |
39 | Additional arguments (an example invocation follows the list):
40 |
41 | ```
42 | --seqpath: filepath where the fasta with the dms library should be saved (defaults to a new subdirectory in the output directory)
43 | --outpath: output filepath for scores of variant sequences (defaults to a new subdirectory in the output directory)
44 | --chain: chain id for the chain of interest
45 | --n: number of top recommendations to be output (default: n=10)
46 | --maxrep: maximum representation of a single site in the output recommendations (default: maxrep = 1 is a unique set of recommendations where each mutation of a given wildtype residue is recommended at most once)
47 | --upperbound: only residue positions less than the user-defined upperbound are considered for recommendation in the final output (but all positions are still conditioned for scoring)
48 | --order: for multichain conditioning, provides option to specify the order of chains
49 | --offset: integer offset or adjustment for labeling of residue indices encoded in the structure file
50 | --multichain-backbone: use the backbones of all chains in the input for conditioning (default is True)
51 | --singlechain-backbone: use the backbone of only the target chain in the input for conditioning
52 | --nogpu: Do not use GPU even if available
53 | ```
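For instance, to condition on only the target chain's own backbone and allow up to two recommendations per site, one could run something like the following (a sketch reusing the example structure shipped in `examples/`):

```bash
python bin/recommend.py examples/7mmo_abc_fvar.pdb \
    --chain A --singlechain-backbone --n 20 --maxrep 2
```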
54 |
55 | For example, to generate recommended mutations for the heavy chain of LY-CoV1404, we would run the following:
56 |
57 | ```
58 | python bin/recommend.py examples/7mmo_abc_fvar.pdb \
59 | --chain A --seqpath examples/7mmo_chainA_lib.fasta \
60 | --outpath examples/7mmo_chainA_scores.csv \
61 | --upperbound 109 --offset 1
62 | ```
63 | In this example, we use a pdb structure file containing the variable regions of both antibody chains in complex with the antigen, the SARS-CoV-2 receptor binding domain (RBD). To obtain recommendations specifically for the heavy chain, we specify chain A. The fasta file containing the library screened *in silico* and the corresponding output scores file are saved at the indicated paths.
64 | To limit the output recommendations and exclude mutations predicted in the final framework region, we set the upper bound to 109 (this value will vary for each antibody). Since the first residue is not resolved in the structure, we specify an offset of 1 to ensure the returned mutations are correctly indexed.
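The scoring step that `recommend.py` performs can also be run on its own with `bin/score_log_likelihoods.py`, mirroring the calls in `bin/if_ab_benchmarking.sh`. A sketch that re-scores the library generated above (the output filename is arbitrary):

```bash
python bin/score_log_likelihoods.py examples/7mmo_abc_fvar.pdb \
    --chain A --seqpath examples/7mmo_chainA_lib.fasta \
    --outpath examples/7mmo_chainA_rescored.csv
```

By default this conditions on all chains in the structure and writes one row per variant with columns `seqid,log_likelihood,log_likelihood_target`.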
65 |
66 | ## Paper analysis scripts
67 |
68 | To reproduce the analysis in the paper, first download and extract data with the commands:
69 | ```bash
70 | wget https://zenodo.org/record/11260318/files/data.tar.gz
71 | tar xvf data.tar.gz
72 | ```
73 | To evaluate the alternate sequence-only and structure-based scoring methods, follow the installation instructions for ESM-1v [here](https://github.com/facebookresearch/esm?tab=readme-ov-file#zs_variant) and for ProteinMPNN [here](https://github.com/dauparas/ProteinMPNN).
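Once those are installed, the benchmarking scripts under `bin/` reproduce the comparisons. A sketch, assuming the data archive above has been extracted into `data/` and that the `esm` and `ProteinMPNN` repositories are cloned alongside this one (the relative paths in the scripts assume this layout):

```bash
bash bin/if_ab_benchmarking.sh
bash bin/mpnn_ab_benchmarking.sh
bash bin/esm1v_ab_benchmarking.sh
python bin/plot_mpnn_benchmarks.py
```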
74 |
75 | ## Citation
76 |
77 | Please cite the following publication when referencing this work.
78 |
79 | ```
80 | @article {Shanker-struct-evo,
81 | author = {Shanker, Varun and Bruun, Theodora and Hie, Brian and Kim, Peter},
82 | title = {Unsupervised evolution of protein and antibody complexes with a structure-informed language model},
83 | year = {2024},
84 | doi = {10.1126/science.adk8946},
85 | publisher = {American Association for the Advancement of Science},
86 | URL = {https://www.science.org/doi/10.1126/science.adk8946},
87 | journal = {Science}
88 | }
89 | ```
90 |
91 | ## License
92 | This project is licensed under the MIT License - see the LICENSE file for details.
93 |
--------------------------------------------------------------------------------
/bin/recommend.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from biotite.sequence.io.fasta import FastaFile, get_sequences
3 | import numpy as np
4 | from pathlib import Path
5 | import torch
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 | import os
9 |
10 | import esm
11 | from dms_utils import deep_mutational_scan
12 | import warnings
13 | import pandas as pd
14 | import util
15 | import multichain_util
16 |
17 | import score_log_likelihoods
18 |
19 | def get_native_seq(pdbfile, chain):
20 | structure = util.load_structure(pdbfile, chain)
21 | _ , native_seq = util.extract_coords_from_structure(structure)
22 | return native_seq
23 |
24 | def write_dms_lib(args):
25 | '''Writes a deep mutational scanning library, including the native/wildtype (wt) of the
26 | indicated target chain in the structure to an output Fasta file'''
27 |
28 | sequence = get_native_seq(args.pdbfile, args.chain)
29 | Path(args.seqpath).parent.mkdir(parents=True, exist_ok=True)
30 | with open(args.seqpath, 'w') as f:
31 | f.write('>wt\n')
32 | f.write(sequence+'\n')
33 | for pos, wt, mt in deep_mutational_scan(sequence):
34 | assert(sequence[pos] == wt)
35 | mut_seq = sequence[:pos] + mt + sequence[(pos + 1):]
36 | f.write('>' + str(wt) + str(pos+1+args.offset) + str(mt) + '\n')
37 | f.write(mut_seq + '\n')
38 |
39 | def get_top_n(args):
40 | recs, rec_inds = [], []
41 | scores_df = pd.read_csv(args.outpath).sort_values(by = 'log_likelihood', ascending = False)
42 |
43 | for seqid in scores_df['seqid']:
44 | res_ind = seqid[1:-1]
45 |         if rec_inds.count(res_ind) < args.maxrep:
46 |             if args.upperbound is None or int(res_ind) < int(args.upperbound):
47 | recs.append(seqid)
48 | rec_inds.append(res_ind)
49 | if len(recs) == args.n:
50 | break
51 |
52 | print(f'\n Chain {args.chain}')
53 | print(*recs, sep='\n')
54 |
55 | def get_model_checkpoint_path(filename):
56 | # Expanding the user's home directory
57 | return os.path.expanduser(f"~/.cache/torch/hub/checkpoints/{filename}")
58 |
59 | def main():
60 | parser = argparse.ArgumentParser(
61 | description='Score sequences based on a given structure.'
62 | )
63 | parser.add_argument(
64 | 'pdbfile', type=str,
65 | help='input filepath, either .pdb or .cif',
66 | )
67 | parser.add_argument(
68 | '--seqpath', type=str,
69 |         help='filepath where fasta of dms library should be saved',
70 | )
71 | parser.add_argument(
72 | '--outpath', type=str,
73 | help='output filepath for scores of variant sequences',
74 | )
75 | parser.add_argument(
76 | '--chain', type=str,
77 | help='chain id for the chain of interest', default='A',
78 | )
79 | parser.set_defaults(multichain_backbone=True)
80 | parser.add_argument(
81 | '--multichain-backbone', action='store_true',
82 | help='use the backbones of all chains in the input for conditioning'
83 | )
84 | parser.add_argument(
85 | '--singlechain-backbone', dest='multichain_backbone',
86 | action='store_false',
87 | help='use the backbone of only target chain in the input for conditioning'
88 | )
89 | parser.add_argument(
90 | '--order', type=str, default=None,
91 | help='for multichain, option to specify order of chains'
92 | )
93 | parser.add_argument(
94 | '--n', type=int,
95 | help='number of desired predictions to be output',
96 | default=10,
97 | )
98 | parser.add_argument(
99 | '--maxrep', type=int,
100 | help='maximum representation of a single site in the top recommendations \
101 | (eg: maxrep = 1 is a unique set where no wildtype residue is mutated more than once)',
102 | default=1,
103 | )
104 | parser.add_argument(
105 | '--offset', type=int,
106 | help='integer offset for labeling of residue indices encoded in the structure',
107 | default=0,
108 | )
109 | parser.add_argument(
110 | '--upperbound', type=int,
111 | help='only residue positions less than the user-defined upperbound are considered to be recommended for screening \
112 | (but all positions are still conditioned for scoring)',
113 | default=None,
114 | )
115 | parser.add_argument(
116 | "--nogpu", action="store_true",
117 | help="Do not use GPU even if available"
118 | )
119 |
120 | args = parser.parse_args()
121 |
122 | if args.seqpath is None:
123 | args.seqpath = f'output/{args.pdbfile[:-4]}-chain{args.chain}_dms.fasta'
124 |
125 | if args.outpath is None:
126 | args.outpath = f'output/{args.pdbfile[:-4]}-chain{args.chain}_scores.csv'
127 |
128 | #write dms library for target chain
129 | write_dms_lib(args)
130 |
131 | model_checkpoint_path = get_model_checkpoint_path('esm_if1_20220410.pt')
132 | with warnings.catch_warnings():
133 | warnings.simplefilter('ignore', UserWarning)
134 | model, alphabet = esm.pretrained.load_model_and_alphabet( \
135 | model_checkpoint_path \
136 | )
137 | model = model.eval()
138 |
139 | if args.multichain_backbone:
140 | score_log_likelihoods.score_multichain_backbone(model, alphabet, args)
141 | else:
142 | score_log_likelihoods.score_singlechain_backbone(model, alphabet, args)
143 |
144 | get_top_n(args)
145 |
146 |
147 | if __name__ == '__main__':
148 | main()
--------------------------------------------------------------------------------
/bin/score_log_likelihoods.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from biotite.sequence.io.fasta import FastaFile, get_sequences
3 | import numpy as np
4 | from pathlib import Path
5 | import torch
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 | import warnings
9 | import os
10 | from dms_utils import deep_mutational_scan
11 | import esm
12 | import pandas as pd
13 | from multichain_util import extract_coords_from_complex, _concatenate_coords, _concatenate_seqs, score_sequence_in_complex
14 | from util import get_sequence_loss, load_structure, load_coords, score_sequence, extract_coords_from_structure
15 |
16 | def get_native_seq(pdbfile, chain):
17 | structure = load_structure(pdbfile, chain)
18 | _ , native_seq = extract_coords_from_structure(structure)
19 | return native_seq
20 |
21 | def score_singlechain_backbone(model, alphabet, args):
22 | if torch.cuda.is_available() and not args.nogpu:
23 | model = model.cuda()
24 | print("Transferred model to GPU")
25 |
26 | coords, native_seq = load_coords(args.pdbfile, args.chain)
27 | print('Native sequence loaded from structure file:')
28 | print(native_seq)
29 | print('\n')
30 | ll, _ = score_sequence(
31 | model, alphabet, coords, native_seq)
32 | print('Native sequence')
33 | print(f'Log likelihood: {ll:.2f}')
34 | print(f'Perplexity: {np.exp(-ll):.2f}')
35 | print('\nScoring variant sequences from sequence file..\n')
36 | infile = FastaFile()
37 | infile.read(args.seqpath)
38 | seqs = get_sequences(infile)
39 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
40 | with open(args.outpath, 'w') as fout:
41 | fout.write('seqid,log_likelihood\n')
42 | for header, seq in tqdm(seqs.items()):
43 | ll, _ = score_sequence(
44 | model, alphabet, coords, str(seq))
45 | fout.write(header + ',' + str(ll) + '\n')
46 | print(f'Results saved to {args.outpath}')
47 |
48 |
49 | def score_multichain_backbone(model, alphabet, args):
50 | if torch.cuda.is_available() and not args.nogpu:
51 | model = model.cuda()
52 | print("Transferred model to GPU")
53 |
54 | structure = load_structure(args.pdbfile)
55 | coords, native_seqs = extract_coords_from_complex(structure)
56 | target_chain_id = args.chain
57 | native_seq = native_seqs[target_chain_id]
58 | order = args.order
59 |
60 | print('Native sequence loaded from structure file:')
61 | print(native_seq)
62 | print('\n')
63 |
64 | ll_complex, ll_targetchain = score_sequence_in_complex(
65 | model,
66 | alphabet,
67 | coords,
68 | native_seqs,
69 | target_chain_id,
70 | native_seq,
71 | order=order,
72 | )
73 | print('Native sequence')
74 | print(f'Log likelihood of complex: {ll_complex:.2f}')
75 | print(f'Log likelihood of target chain: {ll_targetchain:.2f}')
76 |     print(f'Perplexity: {np.exp(-ll_complex):.2f}')
77 |
78 | print('\nScoring variant sequences from sequence file..\n')
79 | infile = FastaFile()
80 | infile.read(args.seqpath)
81 | seqs = get_sequences(infile)
82 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
83 | with open(args.outpath, 'w') as fout:
84 |         fout.write('seqid,log_likelihood,log_likelihood_target\n')
85 | for header, seq in tqdm(seqs.items()):
86 | ll_complex, ll_targetchain = score_sequence_in_complex(
87 | model,
88 | alphabet,
89 | coords,
90 | native_seqs,
91 | target_chain_id,
92 | str(seq),
93 | order=order,
94 | )
95 | fout.write(header + ',' + str(ll_complex) + ',' + str(ll_targetchain) + '\n')
96 | print(f'Results saved to {args.outpath}')
97 |
98 | def get_model_checkpoint_path(filename):
99 | # Expanding the user's home directory
100 | return os.path.expanduser(f"~/.cache/torch/hub/checkpoints/{filename}")
101 |
102 | def main():
103 | parser = argparse.ArgumentParser(
104 | description='Score sequences based on a given structure.'
105 | )
106 | parser.add_argument(
107 | 'pdbfile', type=str,
108 | help='input filepath, either .pdb or .cif',
109 | )
110 | parser.add_argument(
111 | '--seqpath', type=str,
112 | help='input filepath for variant sequences in a .fasta file',
113 | )
114 | parser.add_argument(
115 | '--outpath', type=str,
116 | help='output filepath for scores of variant sequences',
117 | )
118 | parser.add_argument(
119 | '--chain', type=str,
120 | help='chain id for the chain of interest', default='A',
121 | )
122 | parser.set_defaults(multichain_backbone=True)
123 | parser.add_argument(
124 | '--multichain-backbone', action='store_true',
125 | help='use the backbones of all chains in the input for conditioning'
126 | )
127 | parser.add_argument(
128 | '--order', type=str, default=None,
129 | help='for multichain, specify order'
130 | )
131 | parser.add_argument(
132 | '--singlechain-backbone', dest='multichain_backbone',
133 | action='store_false',
134 | help='use the backbone of only target chain in the input for conditioning'
135 | )
136 |
137 | parser.add_argument(
138 | "--nogpu", action="store_true",
139 | help="Do not use GPU even if available"
140 | )
141 | args = parser.parse_args()
142 |
143 | if args.outpath is None:
144 | args.outpath = f'output/{args.pdbfile[:-4]}-chain{args.chain}_scores.csv'
145 |
146 | model_checkpoint_path = get_model_checkpoint_path('esm_if1_20220410.pt')
147 | with warnings.catch_warnings():
148 | warnings.simplefilter('ignore', UserWarning)
149 | model, alphabet = esm.pretrained.load_model_and_alphabet( \
150 | model_checkpoint_path \
151 | )
152 | model = model.eval()
153 |
154 |
155 | if args.multichain_backbone:
156 | score_multichain_backbone(model, alphabet, args)
157 | else:
158 | score_singlechain_backbone(model, alphabet, args)
159 |
160 |
161 | if __name__ == '__main__':
162 | main()
163 |
--------------------------------------------------------------------------------
/bin/esm1v_ab_benchmarking.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | run_prediction() {
4 | python ../esm/examples/variant-prediction/predict.py \
5 | --model-location esm1v_t33_650M_UR90S_1 esm1v_t33_650M_UR90S_2 esm1v_t33_650M_UR90S_3 esm1v_t33_650M_UR90S_4 esm1v_t33_650M_UR90S_5 \
6 | --sequence "$1" \
7 | --dms-input "$2" \
8 | --mutation-col mutant \
9 | --dms-output "$3" \
10 | --scoring-strategy 'masked-marginals' \
11 | --offset-idx "$4"
12 | }
13 |
14 | # CR6261
15 | run_prediction "EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSS" \
16 | "data/ab_mutagenesis_expts/cr6261/cr6261_singleMuts_exp_data.csv" \
17 | "output/ab_mutagenesis_expts/cr6261/cr6261_exp_data_maskMargLabeled.csv" \
18 | 1
19 |
20 | # CR9114
21 | run_prediction "QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSS" \
22 | "data/ab_mutagenesis_expts/cr9114/cr9114_singleMuts_exp_data.csv" \
23 | "output/ab_mutagenesis_expts/cr9114/cr9114_exp_data_maskMargLabeled.csv" \
24 | 1
25 |
26 | # G6 HC
27 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTV" \
28 | "data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv" \
29 | "output/ab_mutagenesis_expts/g6/g6Hc_exp_data_maskMargLabeled.csv" \
30 | 1
31 |
32 | # G6 LC
33 | run_prediction "DIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK" \
34 | "data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv" \
35 | "output/ab_mutagenesis_expts/g6/g6Lc_exp_data_maskMargLabeled.csv" \
36 | 1
37 |
38 | # G6 HC (both chains)
39 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK" \
40 | "data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv" \
41 | "output/ab_mutagenesis_expts/g6/g6Hc_esm1vbothchains_exp_data_maskMargLabeled.csv" \
42 | 1
43 |
44 | # G6 LC (both chains)
45 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIK" \
46 | "data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv" \
47 | "output/ab_mutagenesis_expts/g6/g6Lc_esm1vbothchains_exp_data_maskMargLabeled.csv" \
48 | -117
49 |
50 | # CR9114 (both chains)
51 | run_prediction "QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSSQSALTQPPAVSGTPGQRVTISCSGSDSNIGRRSVNWYQQFPGTAPKLLIYSNDQRPSVVPDRFSGSKSGTSASLAISGLQSEDEAEYYCAAWDDSLKGAVFGGGTQLTVL" \
52 | "data/ab_mutagenesis_expts/cr9114/cr9114_singleMuts_exp_data.csv" \
53 | "output/ab_mutagenesis_expts/cr9114/cr9114_esm1vbothchains_exp_data_maskMargLabeled.csv" \
54 | 1
55 |
56 | # CR6261 (both chains)
57 | run_prediction "EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSSQSVLTQPPSVSAAPGQKVTISCSGSSSNIGNDYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEANYYCATWDRRPTAYVVFGGGTKLTVL" \
58 | "data/ab_mutagenesis_expts/cr6261/cr6261_singleMuts_exp_data.csv" \
59 | "output/ab_mutagenesis_expts/cr6261/cr6261_esm1vbothchains_exp_data_maskMargLabeled.csv" \
60 | 1
61 |
62 | # CR9114 (antibody + antigen)
63 | run_prediction "QVQLVQSGAEVKKPGSSVKVSCKSSGGTSNNYAISWVRQAPGQGLDWMGGISPIFGSTAYAQKFQGRVTISADIFSNTAYMELNSLTSEDTAVYFCARHGNYYYYSGMDVWGQGTTVTVSSQSALTQPPAVSGTPGQRVTISCSGSDSNIGRRSVNWYQQFPGTAPKLLIYSNDQRPSVVPDRFSGSKSGTSASLAISGLQSEDEAEYYCAAWDDSLKGAVFGGGTQLTVLADPGDQICIGYHANNSTEQVDTIMEKNVTVTHAQDILEKKHNGKLCDLDGVKPLILRDCSVAGWLLGNPMCDEFINVPEWSYIVEKANPVNDLCYPGDFNDYEELKHLLSRINHFEKIQIIPKSSWSSHEASLGVSSACPYQGKSSFFRNVVWLIKKNSTYPTIKRSYNNTNQEDLLVLWGIHHPNDAAEQTKLYQNPTTYISVGTSTLNQRLVPRIATRSKVNGQSGRMEFFWTILKPNDAINFESNGNFIAPEYAYKIVKKGDSTIMKSELEYGNCNTKCQTPMGAINSSMPFHNIHPLTIGECPKYVKSNRLVLATGLRNSPQRERRRKKRGLFGAIAGFIEGGWQGMVDGWYGYHHSNEQGSGYAADKESTQKAIDGVTNKVNSIIDKMNTQFEAVGREFNNLERRIENLNKKMEDGFLDVWTYNAELLVLMENERTLDFHDSNVKNLYDKVRLQLRDNAKELGNGCFEFYHKCDNECMESVRNGTYDYPQYSEEARLKREEISSGR" \
64 | "data/ab_mutagenesis_expts/cr9114/cr9114_singleMuts_exp_data.csv" \
65 | "output/ab_mutagenesis_expts/cr9114/cr9114_esm1vAbAg_exp_data_maskMargLabeled.csv" \
66 | 1
67 |
68 | # CR6261 (antibody + antigen)
69 | run_prediction "EVQLVESGAEVKKPGSSVKVSCKASGGPFRSYAISWVRQAPGQGPEWMGGIIPIFGTTKYAPKFQGRVTITADDFAGTVYMELSSLRSEDTAMYYCAKHMGYQVRETMDVWGKGTTVTVSSQSVLTQPPSVSAAPGQKVTISCSGSSSNIGNDYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEANYYCATWDRRPTAYVVFGGGTKLTVLADPGDTICIGYHANNSTDTVDTVLEKNVTVTHSVNLLEDSHNGKLCKLKGIAPLQLGKCNIAGWLLGNPECDLLLTASSWSYIVETSNSENGTCYPGDFIDYEELREQLSSVSSFEKFEIFPKTSSWPNHETTKGVTAACSYAGASSFYRNLLWLTKKGSSYPKLSKSYVNNKGKEVLVLWGVHHPPTGTDQQSLYQNADAYVSVGSSKYNRRFTPEIAARPKVRDQAGRMNYYWTLLEPGDTITFEATGNLIAPWYAFALNRGSGSGIITSDAPVHDCNTKCQTPHGAINSSLPFQNIHPVTIGECPKYVRSTKLRMATGLRNIPSIQSRGLFGAIAGFIEGGWTGMIDGWYGYHHQNEQGSGYAADQKSTQNAIDGITNKVNSVIEKMNTQFTAVGKEFNNLERRIENLNKKVDDGFLDIWTYNAELLVLLENERTLDFHDSNVRNLYEKVKSQLKNNAKEIGNGCFEFYHKCDDACMESVRNGTYDYPKYSEESKLNREEIDGVSGR" \
70 | "data/ab_mutagenesis_expts/cr6261/cr6261_singleMuts_exp_data.csv" \
71 | "output/ab_mutagenesis_expts/cr6261/cr6261_esm1vAbAg_exp_data_maskMargLabeled.csv" \
72 | 1
73 |
74 | # G6 HC (antibody + antigen)
75 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIKGQNHHEVVKFMDVYQRSYCHPIETLVDIFQEYPDEIEYIFKPSCVPLMRCGGCCNDEGLECVPTEESNITMQIMRIKPHQGQHIGEMSFLQHNKCECRPKKD" \
76 | "data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv" \
77 | "output/ab_mutagenesis_expts/g6/g6Hc_esm1vAbAg_exp_data_maskMargLabeled.csv" \
78 | 1
79 |
80 | # G6 LC (antibody + antigen)
81 | run_prediction "EVQLVESGGGLVQPGGSLRLSCAASGFTISDYWIHWVRQAPGKGLEWVAGITPAGGYTYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARFVFFLPYAMDYWGQGTLVTVDIQMTQSPSSLSASVGDRVTITCRASQDVSTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYTTPPTFGQGTKVEIKGQNHHEVVKFMDVYQRSYCHPIETLVDIFQEYPDEIEYIFKPSCVPLMRCGGCCNDEGLECVPTEESNITMQIMRIKPHQGQHIGEMSFLQHNKCECRPKKD" \
82 | "data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv" \
83 | "output/ab_mutagenesis_expts/g6/g6Lc_esm1vAbAg_exp_data_maskMargLabeled.csv" \
84 | -117
--------------------------------------------------------------------------------
/bin/multichain_util.py:
--------------------------------------------------------------------------------
1 | import biotite.structure
2 | import numpy as np
3 | import torch
4 | from typing import Sequence, Tuple, List
5 | from util import (
6 | load_structure,
7 | extract_coords_from_structure,
8 | load_coords,
9 | get_sequence_loss,
10 | get_encoder_output,
11 | )
12 | import util
13 |
14 |
15 | def extract_coords_from_complex(structure: biotite.structure.AtomArray):
16 | """
17 | Args:
18 | structure: biotite AtomArray
19 | Returns:
20 | Tuple (coords_list, seq_list)
21 | - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
22 | coordinates representing the backbone of each chain
23 | - seqs: Dictionary mapping chain ids to native sequences of each chain
24 | """
25 | coords = {}
26 | seqs = {}
27 | all_chains = biotite.structure.get_chains(structure)
28 | for chain_id in all_chains:
29 | chain = structure[structure.chain_id == chain_id]
30 | coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain)
31 | return coords, seqs
32 |
33 |
34 | def load_complex_coords(fpath, chains):
35 | """
36 | Args:
37 | fpath: filepath to either pdb or cif file
38 | chains: the chain ids (the order matters for autoregressive model)
39 | Returns:
40 | Tuple (coords_list, seq_list)
41 | - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
42 | coordinates representing the backbone of each chain
43 | - seqs: Dictionary mapping chain ids to native sequences of each chain
44 | """
45 | structure = load_structure(fpath, chains)
46 | return extract_coords_from_complex(structure)
47 |
48 | #*
49 | def _concatenate_coords(
50 | coords,
51 | target_chain_id,
52 | padding_length=10,
53 | order=None
54 | ):
55 | """
56 | Args:
57 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
58 | coordinates representing the backbone of each chain
59 | target_chain_id: The chain id to sample sequences for
60 | padding_length: Length of padding between concatenated chains
61 | Returns:
62 |         Tuple (coords_concatenated, coords_chains)
63 |         - coords_concatenated is an L x 3 x 3 array for N, CA, C coordinates, a
64 |           concatenation of the chains with padding in between
65 |           AND target chain placed first
66 |         - coords_chains is an array of length L labeling each position with its
67 |           chain id ('pad' for padding positions)
68 | """
69 | pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
70 | if order is None:
71 | order = (
72 | [ target_chain_id ] +
73 | [ chain_id for chain_id in coords if chain_id != target_chain_id ]
74 | )
75 | coords_list, coords_chains = [], []
76 | for idx, chain_id in enumerate(order):
77 | if idx > 0:
78 | coords_list.append(pad_coords)
79 | coords_chains.append([ 'pad' ] * padding_length)
80 | coords_list.append(list(coords[chain_id]))
81 | coords_chains.append([ chain_id ] * coords[chain_id].shape[0])
82 | coords_concatenated = np.concatenate(coords_list, axis=0)
83 | coords_chains = np.concatenate(coords_chains, axis=0).ravel()
84 | return coords_concatenated, coords_chains
85 |
86 | #*
87 | def _concatenate_seqs(
88 | native_seqs,
89 | target_seq,
90 | target_chain_id,
91 | padding_length=10,
92 | order=None,
93 | ):
94 | """
95 | Args:
96 | native_seqs: Dictionary mapping chain ids to corresponding AA sequence
97 |         target_seq: The sequence to use for the target chain
98 | padding_length: Length of padding between concatenated chains
99 | Returns:
100 | native_seqs_concatenated: Array of length L, concatenation of the chain
101 | sequences with padding in between
102 | """
103 | if order is None:
104 | order = (
105 | [ target_chain_id ] +
106 | [ chain_id for chain_id in native_seqs if chain_id != target_chain_id ]
107 | )
108 | native_seqs_list = []
109 | for idx, chain_id in enumerate(order):
110 | if idx > 0:
111 |             native_seqs_list.append(['<pad>'] * (padding_length - 1) + ['<cath>'])  # pad tokens between chains
112 | if chain_id == target_chain_id:
113 | native_seqs_list.append(list(target_seq))
114 | else:
115 | native_seqs_list.append(list(native_seqs[chain_id]))
116 | native_seqs_concatenated = ''.join(np.concatenate(native_seqs_list, axis=0))
117 | return native_seqs_concatenated
118 |
119 |
120 | #*
121 | def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
122 | padding_length=10):
123 | """
124 | Samples sequence for one chain in a complex.
125 | Args:
126 | model: An instance of the GVPTransformer model
127 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
128 | coordinates representing the backbone of each chain
129 | target_chain_id: The chain id to sample sequences for
130 | padding_length: padding length in between chains
131 | Returns:
132 | Sampled sequence for the target chain
133 | """
134 | target_chain_len = coords[target_chain_id].shape[0]
135 | all_coords, coords_chains = _concatenate_coords(coords, target_chain_id)
136 | device = next(model.parameters()).device
137 |
138 | # Supply padding tokens for other chains to avoid unused sampling for speed
139 |     padding_pattern = ['<pad>'] * all_coords.shape[0]
140 |     for i in range(target_chain_len):
141 |         padding_pattern[i] = '<mask>'
142 | sampled = model.sample(all_coords, partial_seq=padding_pattern,
143 | temperature=temperature, device=device)
144 | sampled = sampled[:target_chain_len]
145 | return sampled
146 |
147 |
148 | #*
149 | def score_sequence_in_complex(
150 | model,
151 | alphabet,
152 | coords,
153 | native_seqs,
154 | target_chain_id,
155 | target_seq,
156 | padding_length=10,
157 | order=None,
158 | ):
159 | """
160 | Scores sequence for one chain in a complex.
161 | Args:
162 | model: An instance of the GVPTransformer model
163 | alphabet: Alphabet for the model
164 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
165 | coordinates representing the backbone of each chain
166 | native_seqs: Dictionary mapping chain ids to sequence
167 | extracted from each chain
168 | target_chain_id: The chain id to sample sequences for
169 | target_seq: Target sequence for the target chain for scoring.
170 | padding_length: padding length in between chains
171 | Returns:
172 |         Tuple (ll_fullseq, ll_targetseq)
173 |         - ll_fullseq: Average log-likelihood over all non-padding positions
174 |           in the complex
175 |         - ll_targetseq: Average log-likelihood over the target chain only
176 | """
177 |
178 | assert(len(target_seq) == len(native_seqs[target_chain_id]))
179 |
180 | all_coords, coords_chains = _concatenate_coords(
181 | coords,
182 | target_chain_id,
183 | order=order,
184 | )
185 | all_seqs = _concatenate_seqs(
186 | native_seqs,
187 | target_seq,
188 | target_chain_id,
189 | order=order,
190 | )
191 |
192 | loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
193 | all_seqs)
194 | assert(all_coords.shape[0] == coords_chains.shape[0] == loss.shape[0])
195 |
196 | ll_fullseq = -np.mean(loss[coords_chains != 'pad'])
197 | ll_targetseq = -np.mean(loss[coords_chains == target_chain_id])
198 |
199 | return ll_fullseq, ll_targetseq
200 |
201 |
202 | def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
203 | """
204 | Args:
205 | model: An instance of the GVPTransformer model
206 | alphabet: Alphabet for the model
207 | coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
208 | coordinates representing the backbone of each chain
209 | target_chain_id: The chain id to sample sequences for
210 | Returns:
211 | Dictionary mapping chain id to encoder output for each chain
212 | """
213 |     all_coords, _ = _concatenate_coords(coords, target_chain_id)
214 | all_rep = get_encoder_output(model, alphabet, all_coords)
215 | target_chain_len = coords[target_chain_id].shape[0]
216 | return all_rep[:target_chain_len]
--------------------------------------------------------------------------------
/bin/plot_mpnn_benchmarks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import bokeh.io
4 | import bokeh.plotting
5 | import bokeh.palettes
6 | from bokeh.transform import factor_cmap
7 | import datashader
8 | import holoviews as hv
9 | import holoviews.operation.datashader
10 | import os
11 | from natsort import natsorted
12 |
13 | hv.extension("bokeh")
14 |
15 | import warnings
16 | # Suppress FutureWarning messages
17 | warnings.simplefilter(action='ignore', category=FutureWarning)
18 |
19 | cr9114_dict = {
20 | 'ab_name' : 'CR9114',
21 | 'files' : {
22 | 'ESM-IF1': 'output/ab_mutagenesis_expts/cr9114/4fqi_ablh_scores.csv',
23 | 'ProteinMPNN': 'output/ab_mutagenesis_expts/cr9114/mpnn/score_only/',
24 |
25 | },
26 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr9114/cr9114_exp_data.csv',
27 | dtype = {'genotype': str},
28 | ).rename(columns={'h1_mean': 'H1', 'h3_mean' : 'H3'}),
29 | 'ag_columns': ['H1', 'H3'],
30 | 'expt_type': 'Combinatorial Mutagenesis for Affinity',
31 | 'palette': bokeh.palettes.Spectral6
32 | }
33 |
34 | cr6261_dict = {
35 | 'ab_name' : 'CR6261',
36 | 'files' : {
37 | 'ESM-IF1': 'output/ab_mutagenesis_expts/cr6261/3gbn_ablh_scores.csv',
38 | 'ProteinMPNN': 'output/ab_mutagenesis_expts/cr6261/mpnn/score_only/'
39 | },
40 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr6261/cr6261_exp_data.csv',
41 | dtype={'genotype': str},
42 | ).rename(columns={'h1_mean': 'H1', 'h9_mean' : 'H9'}),
43 | 'ag_columns': ['H1', 'H9'],
44 | 'expt_type': 'Combinatorial Mutagenesis for Affinity',
45 | 'palette': bokeh.palettes.Dark2_6,
46 | }
47 |
48 |
49 | g6LC_dict = {
50 | 'ab_name' : 'g6',
51 | 'files' : {
52 | 'ESM-IF1':'output/ab_mutagenesis_expts/g6/2fjg_vlh_lc_scores.csv',
53 | 'ProteinMPNN':'output/ab_mutagenesis_expts/g6/proteinMpnnLC/score_only/',
54 |
55 | },
56 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv'),
57 | 'ag_columns': ['norm_binding'],
58 | 'expt_type': 'Deep Mutational Scan for Binding',
59 | 'palette': bokeh.palettes.Pastel1_6,
60 | 'chain': 'VL'
61 | }
62 |
63 | g6HC_dict = {
64 | 'ab_name' : 'g6',
65 | 'files' : {
66 | 'ESM-IF1': 'output/ab_mutagenesis_expts/g6/2fjg_vlh_hc_scores.csv',
67 | 'ProteinMPNN':'output/ab_mutagenesis_expts/g6/proteinMpnnHC/score_only/',
68 | },
69 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv'),
70 | 'ag_columns': ['norm_binding'],
71 | 'expt_type': 'Deep Mutational Scan for Binding',
72 | 'palette': bokeh.palettes.Pastel1_4,
73 | 'chain': 'VH'
74 | }
75 | def combine_mpnn_scores(score_dir):
76 |     global_scores = []
77 |     npz_files = [f for f in os.listdir(score_dir) if f.endswith('.npz')]
78 |     npz_files = natsorted(npz_files)[:-1]
79 |
80 |     for fname in npz_files:
81 |         file_path = os.path.join(score_dir, fname)
82 |         data = np.load(file_path)
83 |         global_scores.append(-1 * data['global_score'][0])  # negate so higher = better, matching ESM-IF log-likelihoods
84 |         data.close()
85 |
86 |
87 |     return global_scores
88 |
89 |
90 | #get melted correlations matrix of structure input x target antigen, for a given antibody
91 | def get_corr(ab_name, files, dms_df, ag_columns,):
92 |
93 |     #retrieve correlations from inverse folding scores
94 | for key, filepath in files.items():
95 | column_label = key
96 | if key == 'ProteinMPNN':
97 | scores = combine_mpnn_scores(filepath)
98 | dms_df[column_label] = scores
99 | elif key == 'ESM-IF1':
100 | df = pd.read_csv(files['ESM-IF1']) # Read the CSV file into a DataFrame
101 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column
102 | dms_df[column_label] = log_likelihood_column
103 |
104 |
105 | conditions = list(files.keys())
106 |
107 | correlations = dms_df[ag_columns + conditions].corr(method='spearman')
108 | correlations = correlations.drop(conditions, axis= 1)
109 | correlations = correlations.drop(ag_columns, axis= 0)
110 |
111 | # Melt the correlations DataFrame and rename the index column
112 | melted_correlations = (
113 | correlations
114 | .reset_index()
115 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
116 | .rename(columns={'index': 'Input'})
117 | )
118 |
119 | print(melted_correlations)
120 | return melted_correlations
121 |
123 |
124 | #get melted correlations matrix of structure input x target antigen, for g6 antibody
125 | def get_g6_corr(g6HC_dict, g6LC_dict):
126 |
127 | ab_name = 'g6'
128 | ag_columns = g6LC_dict['ag_columns']
129 | vh_and_vl_dms_df = pd.DataFrame({})
130 |
131 | for d in [g6HC_dict, g6LC_dict]:
132 |
133 | files = d['files']
134 | dms_df = d['dms_df']
135 |
136 |         #retrieve correlations from inverse folding scores
137 | for key, filepath in files.items():
138 | column_label = key
139 | if key == 'ProteinMPNN':
140 | scores = combine_mpnn_scores(filepath)
141 | dms_df[column_label] = scores
142 | elif key == 'ESM-IF1':
143 | df = pd.read_csv(files['ESM-IF1']) # Read the CSV file into a DataFrame
144 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column
145 | dms_df[column_label] = log_likelihood_column
146 |
147 |
148 | vh_and_vl_dms_df = pd.concat([vh_and_vl_dms_df, dms_df], ignore_index= True)
149 |
150 | conditions = list(files.keys())
151 |
152 | correlations = vh_and_vl_dms_df[ag_columns + conditions].corr(method='spearman')
153 | correlations = correlations.drop(conditions, axis= 1)
154 | correlations = correlations.drop(ag_columns, axis= 0)
155 |
156 | # Melt the correlations DataFrame and rename the index column
157 | melted_correlations = (
158 | correlations
159 | .reset_index()
160 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
161 | .rename(columns={'index': 'Input'})
162 | )
163 | melted_correlations['Target'] = ['VEGF-A'] * len(melted_correlations)
164 |
165 | print(melted_correlations)
166 | return melted_correlations
167 |
168 |
169 | #plot correlation bars for cr antibodies: cr6261 and cr9114
170 | def plot_hbar(title, melted_correlations, palette = bokeh.palettes.Spectral6, ):
171 |
172 | melted_correlations['cats'] = melted_correlations.apply(lambda x: (x["Target"], x["Input"]), axis = 1)
173 | factors = list(melted_correlations.cats)[::-1]
174 |
175 | p = bokeh.plotting.figure(
176 | height=340,
177 | width=440,
178 | x_axis_label="Spearman Correlation",
179 | x_range=[0, 1],
180 | y_range=bokeh.models.FactorRange(*factors),
181 | tools="save",
182 | title = title
183 | )
184 |
185 |
186 | p.hbar(
187 | source=melted_correlations,
188 | y="cats",
189 | right="Correlation",
190 | height=0.6,
191 | line_color = 'black',
192 | fill_color=bokeh.palettes.Dark2_6[0]
193 |
194 | )
195 |
196 | labels_df = melted_correlations
197 | labels_df['corr_str'] = labels_df['Correlation'].apply(lambda x: round(x, 2)).astype(str)
198 | labels_source = bokeh.models.ColumnDataSource(labels_df)
199 |
200 | labels = bokeh.models.LabelSet(x='Correlation', y='cats', text='corr_str',text_font_size = "10px",
201 | x_offset=12, y_offset=-5, source=labels_source, render_mode='canvas')
202 |
203 | p.ygrid.grid_line_color = None
204 | p.y_range.range_padding = 0.1
205 | p.add_layout(labels)
206 | p.legend.visible = False
207 |
208 | p.output_backend = "svg"
209 | return p
210 |
211 | if __name__ == '__main__':
212 | datasets = [cr9114_dict, cr6261_dict, (g6HC_dict, g6LC_dict) ]
213 |
214 | all_corr_plots = []
215 |
216 | for d in datasets:
217 | if type(d) is tuple:
218 | g6Hc, g6Lc = d
219 | title = g6Hc['ab_name'] + ', ' + g6Hc['expt_type']
220 | g6_combined = get_g6_corr(g6Hc, g6Lc)
221 |
222 | all_corr_plots.append(plot_hbar(title, g6_combined, g6Lc['palette']))
223 |
224 | else:
225 | corr_df = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'])
226 | title = d['ab_name'] + ', ' + d['expt_type']
227 | all_corr_plots.append(plot_hbar(title, corr_df, d['palette']))
228 |
229 |
230 | all_corr_fname = f"output/ab_mutagenesis_expts/mpnn_benchmarks.html"
231 | bokeh.plotting.output_file(all_corr_fname)
232 | bokeh.io.show(bokeh.layouts.gridplot(all_corr_plots, ncols = len(all_corr_plots)))
233 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: struct-evo
2 | channels:
3 | - pytorch
4 | - pyg
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - abseil-cpp=20211102.0=hd4dd3e8_0
11 | - anyio=4.2.0=py39h06a4308_0
12 | - argon2-cffi=21.3.0=pyhd3eb1b0_0
13 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0
14 | - arrow-cpp=14.0.2=h374c478_1
15 | - asttokens=2.0.5=pyhd3eb1b0_0
16 | - attrs=23.1.0=py39h06a4308_0
17 | - aws-c-auth=0.6.19=h5eee18b_0
18 | - aws-c-cal=0.5.20=hdbd6064_0
19 | - aws-c-common=0.8.5=h5eee18b_0
20 | - aws-c-compression=0.2.16=h5eee18b_0
21 | - aws-c-event-stream=0.2.15=h6a678d5_0
22 | - aws-c-http=0.6.25=h5eee18b_0
23 | - aws-c-io=0.13.10=h5eee18b_0
24 | - aws-c-mqtt=0.7.13=h5eee18b_0
25 | - aws-c-s3=0.1.51=hdbd6064_0
26 | - aws-c-sdkutils=0.1.6=h5eee18b_0
27 | - aws-checksums=0.1.13=h5eee18b_0
28 | - aws-crt-cpp=0.18.16=h6a678d5_0
29 | - aws-sdk-cpp=1.10.55=h721c034_0
30 | - backcall=0.2.0=pyhd3eb1b0_0
31 | - beautifulsoup4=4.12.3=py39h06a4308_0
32 | - blas=1.0=mkl
33 | - bleach=4.1.0=pyhd3eb1b0_0
34 | - bokeh=2.4.3=py39h06a4308_0
35 | - boost-cpp=1.82.0=hdb19cb5_2
36 | - bottleneck=1.3.7=py39ha9d4c09_0
37 | - brotli=1.0.9=h5eee18b_8
38 | - brotli-bin=1.0.9=h5eee18b_8
39 | - brotli-python=1.0.9=py39h5a03fae_7
40 | - bzip2=1.0.8=h5eee18b_6
41 | - c-ares=1.19.1=h5eee18b_0
42 | - ca-certificates=2024.3.11=h06a4308_0
43 | - certifi=2024.6.2=py39h06a4308_0
44 | - cffi=1.16.0=py39h5eee18b_1
45 | - charset-normalizer=3.3.2=pyhd8ed1ab_0
46 | - click=8.1.7=py39h06a4308_0
47 | - cloudpickle=2.2.1=py39h06a4308_0
48 | - colorama=0.4.6=pyhd8ed1ab_0
49 | - colorcet=3.0.0=py39h06a4308_0
50 | - comm=0.2.1=py39h06a4308_0
51 | - contourpy=1.2.0=py39hdb19cb5_0
52 | - cudatoolkit=11.3.1=h2bc3f7f_2
53 | - cycler=0.11.0=pyhd3eb1b0_0
54 | - cyrus-sasl=2.1.28=h52b45da_1
55 | - cytoolz=0.12.2=py39h5eee18b_0
56 | - dask=2023.11.0=py39h06a4308_0
57 | - dask-core=2023.11.0=py39h06a4308_0
58 | - datashader=0.14.1=py39h06a4308_1
59 | - datashape=0.5.4=py39h06a4308_1
60 | - dbus=1.13.18=hb2f20db_0
61 | - debugpy=1.6.7=py39h6a678d5_0
62 | - decorator=5.1.1=pyhd3eb1b0_0
63 | - defusedxml=0.7.1=pyhd3eb1b0_0
64 | - distributed=2023.11.0=py39h06a4308_0
65 | - entrypoints=0.4=py39h06a4308_0
66 | - et_xmlfile=1.1.0=py39h06a4308_0
67 | - exceptiongroup=1.2.0=py39h06a4308_0
68 | - executing=0.8.3=pyhd3eb1b0_0
69 | - expat=2.6.2=h6a678d5_0
70 | - fontconfig=2.14.1=h4c34cd2_2
71 | - fonttools=4.51.0=py39h5eee18b_0
72 | - freetype=2.12.1=h4a9f257_0
73 | - fsspec=2024.3.1=py39h06a4308_0
74 | - gflags=2.2.2=h6a678d5_1
75 | - glib=2.78.4=h6a678d5_0
76 | - glib-tools=2.78.4=h6a678d5_0
77 | - glog=0.5.0=h6a678d5_1
78 | - grpc-cpp=1.48.2=he1ff14a_1
79 | - gst-plugins-base=1.14.1=h6a678d5_1
80 | - gstreamer=1.14.1=h5eee18b_1
81 | - heapdict=1.0.1=pyhd3eb1b0_0
82 | - holoviews=1.15.0=py39h06a4308_1
83 | - hvplot=0.8.0=py39h06a4308_0
84 | - icu=73.1=h6a678d5_0
85 | - idna=3.7=pyhd8ed1ab_0
86 | - importlib-metadata=7.0.1=py39h06a4308_0
87 | - importlib_resources=6.1.1=py39h06a4308_1
88 | - intel-openmp=2023.1.0=hdb19cb5_46306
89 | - ipykernel=6.28.0=py39h06a4308_0
90 | - ipython=8.15.0=py39h06a4308_0
91 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
92 | - jedi=0.18.1=py39h06a4308_1
93 | - jinja2=2.11.3=pyhd8ed1ab_2
94 | - joblib=1.4.2=pyhd8ed1ab_0
95 | - jpeg=9e=h5eee18b_1
96 | - jsonschema=4.19.2=py39h06a4308_0
97 | - jsonschema-specifications=2023.7.1=py39h06a4308_0
98 | - jupyter_client=7.4.9=py39h06a4308_0
99 | - jupyter_core=5.7.2=py39h06a4308_0
100 | - jupyter_events=0.10.0=py39h06a4308_0
101 | - jupyter_server=2.10.0=py39h06a4308_0
102 | - jupyter_server_terminals=0.4.4=py39h06a4308_1
103 | - jupyterlab_pygments=0.2.2=py39h06a4308_0
104 | - kiwisolver=1.4.4=py39h6a678d5_0
105 | - krb5=1.20.1=h143b758_1
106 | - lcms2=2.12=h3be6417_0
107 | - ld_impl_linux-64=2.38=h1181459_1
108 | - lerc=3.0=h295c915_0
109 | - libboost=1.82.0=h109eef0_2
110 | - libbrotlicommon=1.0.9=h5eee18b_8
111 | - libbrotlidec=1.0.9=h5eee18b_8
112 | - libbrotlienc=1.0.9=h5eee18b_8
113 | - libclang=14.0.6=default_hc6dbbc7_1
114 | - libclang13=14.0.6=default_he11475f_1
115 | - libcups=2.4.2=h2d74bed_1
116 | - libcurl=8.7.1=h251f7ec_0
117 | - libdeflate=1.17=h5eee18b_1
118 | - libedit=3.1.20230828=h5eee18b_0
119 | - libev=4.33=h7f8727e_1
120 | - libevent=2.1.12=hdbd6064_1
121 | - libffi=3.4.4=h6a678d5_1
122 | - libgcc-ng=11.2.0=h1234567_1
123 | - libgfortran-ng=11.2.0=h00389a5_1
124 | - libgfortran5=11.2.0=h1234567_1
125 | - libglib=2.78.4=hdc74915_0
126 | - libgomp=11.2.0=h1234567_1
127 | - libiconv=1.16=h5eee18b_3
128 | - libllvm14=14.0.6=hdb19cb5_3
129 | - libnghttp2=1.57.0=h2d74bed_0
130 | - libpng=1.6.39=h5eee18b_0
131 | - libpq=12.17=hdbd6064_0
132 | - libprotobuf=3.20.3=he621ea3_0
133 | - libsodium=1.0.18=h7b6447c_0
134 | - libssh2=1.11.0=h251f7ec_0
135 | - libstdcxx-ng=11.2.0=h1234567_1
136 | - libthrift=0.15.0=h1795dd8_2
137 | - libtiff=4.5.1=h6a678d5_0
138 | - libuuid=1.41.5=h5eee18b_0
139 | - libuv=1.44.2=h5eee18b_0
140 | - libwebp-base=1.3.2=h5eee18b_0
141 | - libxcb=1.15=h7f8727e_0
142 | - libxkbcommon=1.0.1=h5eee18b_1
143 | - libxml2=2.10.4=hfdd30dd_2
144 | - llvmlite=0.42.0=py39h6a678d5_0
145 | - locket=1.0.0=py39h06a4308_0
146 | - lz4=4.3.2=py39h5eee18b_0
147 | - lz4-c=1.9.4=h6a678d5_1
148 | - markdown=3.4.1=py39h06a4308_0
149 | - markupsafe=1.1.1=py39h3811e60_3
150 | - matplotlib=3.8.0=py39h06a4308_0
151 | - matplotlib-base=3.8.0=py39h1128e8f_0
152 | - matplotlib-inline=0.1.6=py39h06a4308_0
153 | - mistune=0.8.4=py39h27cfd23_1000
154 | - mkl=2023.1.0=h213fc3f_46344
155 | - mkl-service=2.4.0=py39h5eee18b_1
156 | - mkl_fft=1.3.8=py39h5eee18b_0
157 | - mkl_random=1.2.4=py39hdb19cb5_0
158 | - msgpack-python=1.0.3=py39hd09550d_0
159 | - multipledispatch=0.6.0=py39h06a4308_0
160 | - mysql=5.7.24=h721c034_2
161 | - nbclassic=1.1.0=py39h06a4308_0
162 | - nbclient=0.5.13=py39h06a4308_0
163 | - nbconvert=6.4.4=py39h06a4308_0
164 | - nbformat=5.9.2=py39h06a4308_0
165 | - ncurses=6.4=h6a678d5_0
166 | - nest-asyncio=1.6.0=py39h06a4308_0
167 | - notebook=6.5.7=py39h06a4308_0
168 | - notebook-shim=0.2.3=py39h06a4308_0
169 | - numba=0.59.1=py39h6a678d5_0
170 | - numexpr=2.8.7=py39h85018f9_0
171 | - numpy=1.22.3=py39hf6e8229_2
172 | - numpy-base=1.22.3=py39h060ed82_2
173 | - openjpeg=2.4.0=h9ca470c_1
174 | - openpyxl=3.0.10=py39h5eee18b_0
175 | - openssl=3.0.14=h5eee18b_0
176 | - orc=1.7.4=hb3bc3d3_1
177 | - overrides=7.4.0=py39h06a4308_0
178 | - packaging=23.2=py39h06a4308_0
179 | - pandocfilters=1.5.0=pyhd3eb1b0_0
180 | - panel=0.13.1=py39h06a4308_0
181 | - param=1.12.0=pyhd3eb1b0_0
182 | - parso=0.8.3=pyhd3eb1b0_0
183 | - partd=1.4.1=py39h06a4308_0
184 | - pcre2=10.42=hebb0a14_1
185 | - pexpect=4.8.0=pyhd3eb1b0_3
186 | - pickleshare=0.7.5=pyhd3eb1b0_1003
187 | - pillow=10.3.0=py39h5eee18b_0
188 | - pip=24.0=py39h06a4308_0
189 | - platformdirs=3.10.0=py39h06a4308_0
190 | - ply=3.11=py39h06a4308_0
191 | - prometheus_client=0.14.1=py39h06a4308_0
192 | - prompt-toolkit=3.0.43=py39h06a4308_0
193 | - psutil=5.9.0=py39h5eee18b_0
194 | - ptyprocess=0.7.0=pyhd3eb1b0_2
195 | - pure_eval=0.2.2=pyhd3eb1b0_0
196 | - pyarrow=14.0.2=py39h1eedbd7_0
197 | - pybind11-abi=4=hd8ed1ab_3
198 | - pycparser=2.21=pyhd3eb1b0_0
199 | - pyct=0.5.0=py39h06a4308_0
200 | - pyg=2.1.0=py39_torch_1.11.0_cu113
201 | - pygments=2.15.1=py39h06a4308_1
202 | - pyparsing=3.0.9=py39h06a4308_0
203 | - pyqt=5.15.10=py39h6a678d5_0
204 | - pyqt5-sip=12.13.0=py39h5eee18b_0
205 | - pysocks=1.7.1=pyha2e5f31_6
206 | - python=3.9.19=h955ad1f_1
207 | - python-dateutil=2.9.0post0=py39h06a4308_2
208 | - python-fastjsonschema=2.16.2=py39h06a4308_0
209 | - python-json-logger=2.0.7=py39h06a4308_0
210 | - python-lmdb=1.4.1=py39h6a678d5_0
211 | - python_abi=3.9=2_cp39
212 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
213 | - pytorch-cluster=1.6.0=py39_torch_1.11.0_cu113
214 | - pytorch-mutex=1.0=cuda
215 | - pytorch-scatter=2.0.9=py39_torch_1.11.0_cu113
216 | - pytorch-sparse=0.6.15=py39_torch_1.11.0_cu113
217 | - pytz=2024.1=py39h06a4308_0
218 | - pyviz_comms=3.0.2=py39h06a4308_0
219 | - pyyaml=6.0.1=py39h5eee18b_0
220 | - pyzmq=24.0.1=py39h5eee18b_0
221 | - qt-main=5.15.2=h53bd1ea_10
222 | - re2=2022.04.01=h295c915_0
223 | - readline=8.2=h5eee18b_0
224 | - referencing=0.30.2=py39h06a4308_0
225 | - requests=2.32.3=pyhd8ed1ab_0
226 | - rfc3339-validator=0.1.4=py39h06a4308_0
227 | - rfc3986-validator=0.1.1=py39h06a4308_0
228 | - rpds-py=0.10.6=py39hb02cf49_0
229 | - s2n=1.3.27=hdbd6064_0
230 | - scikit-learn=1.4.2=py39h1128e8f_1
231 | - scipy=1.11.4=py39h5f9d8c6_0
232 | - send2trash=1.8.2=py39h06a4308_0
233 | - setuptools=69.5.1=py39h06a4308_0
234 | - sip=6.7.12=py39h6a678d5_0
235 | - six=1.16.0=pyhd3eb1b0_1
236 | - snappy=1.1.10=h6a678d5_1
237 | - sniffio=1.3.0=py39h06a4308_0
238 | - sortedcontainers=2.4.0=pyhd3eb1b0_0
239 | - soupsieve=2.5=py39h06a4308_0
240 | - sqlite=3.45.3=h5eee18b_0
241 | - stack_data=0.2.0=pyhd3eb1b0_0
242 | - tbb=2021.8.0=hdb19cb5_0
243 | - tblib=1.7.0=pyhd3eb1b0_0
244 | - terminado=0.17.1=py39h06a4308_0
245 | - testpath=0.6.0=py39h06a4308_0
246 | - threadpoolctl=3.5.0=pyhc1e730c_0
247 | - tk=8.6.14=h39e8969_0
248 | - tomli=2.0.1=py39h06a4308_0
249 | - toolz=0.12.0=py39h06a4308_0
250 | - tornado=6.4.1=py39h5eee18b_0
251 | - tqdm=4.66.4=pyhd8ed1ab_0
252 | - traitlets=5.14.3=py39h06a4308_0
253 | - typing-extensions=4.11.0=py39h06a4308_0
254 | - typing_extensions=4.11.0=py39h06a4308_0
255 | - unicodedata2=15.1.0=py39h5eee18b_0
256 | - urllib3=2.2.2=pyhd8ed1ab_0
257 | - utf8proc=2.6.1=h5eee18b_1
258 | - wcwidth=0.2.5=pyhd3eb1b0_0
259 | - webencodings=0.5.1=py39h06a4308_1
260 | - websocket-client=1.8.0=py39h06a4308_0
261 | - wheel=0.43.0=py39h06a4308_0
262 | - xarray=2023.6.0=py39h06a4308_0
263 | - xz=5.4.6=h5eee18b_1
264 | - yaml=0.2.5=h7b6447c_0
265 | - zeromq=4.3.5=h6a678d5_0
266 | - zict=3.0.0=py39h06a4308_0
267 | - zipp=3.17.0=py39h06a4308_0
268 | - zlib=1.2.13=h5eee18b_1
269 | - zstd=1.5.5=hc292b87_2
270 | - pip:
271 | - ablang==0.3.1
272 | - biopython==1.78
273 | - biotite==0.34.1
274 | - git+https://github.com/facebookresearch/esm.git
275 | - iqplot==0.3.3
276 | - msgpack==1.0.8
277 | - natsort==8.4.0
278 | - networkx==3.2.1
279 | - pandas==2.0.3
280 | - tzdata==2024.1
281 | prefix: /data/home/vrs/miniconda3/envs/struct-evo
282 |
--------------------------------------------------------------------------------
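Usage note for the environment file above (standard conda workflow; not documented elsewhere in this dump): the pinned environment can be recreated with "conda env create -f environment.yml" and activated with "conda activate struct-evo". The trailing "prefix:" line records the original author's install path and is ignored when the environment is created by name; the pytorch/cudatoolkit pins assume a Linux host with CUDA 11.3-compatible drivers.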
/bin/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import biotite.structure
4 | from biotite.structure.io import pdbx, pdb
5 | from biotite.structure.residues import get_residues
6 | from biotite.structure import filter_backbone
7 | from biotite.structure import get_chains
8 | from biotite.sequence import ProteinSequence
9 | import numpy as np
10 | from scipy.spatial import transform
11 | from scipy.stats import special_ortho_group
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.utils.data as data
16 | from typing import Sequence, Tuple, List
17 | from esm.data import BatchConverter
18 |
19 |
20 | def load_structure(fpath, chain=None):
21 | """
22 | Args:
23 | fpath: filepath to either pdb or cif file
24 | chain: the chain id or list of chain ids to load
25 | Returns:
26 | biotite.structure.AtomArray
27 | """
28 |     if fpath.endswith('cif'):
29 |         with open(fpath) as fin:
30 |             structure = pdbx.get_structure(pdbx.PDBxFile.read(fin), model=1)
31 |     elif fpath.endswith('pdb'):
32 |         with open(fpath) as fin:
33 |             structure = pdb.get_structure(pdb.PDBFile.read(fin), model=1)
34 |     else:
35 |         raise ValueError(f'Unsupported structure file {fpath}; expected .pdb or .cif')
36 | bbmask = filter_backbone(structure)
37 | structure = structure[bbmask]
38 | all_chains = get_chains(structure)
39 | if len(all_chains) == 0:
40 | raise ValueError('No chains found in the input file.')
41 | if chain is None:
42 | chain_ids = all_chains
43 | elif isinstance(chain, list):
44 | chain_ids = chain
45 | else:
46 | chain_ids = [chain]
47 | for chain in chain_ids:
48 | if chain not in all_chains:
49 | raise ValueError(f'Chain {chain} not found in input file')
50 | chain_filter = [a.chain_id in chain_ids for a in structure]
51 | structure = structure[chain_filter]
52 | return structure
53 |
54 |
55 | def extract_coords_from_structure(structure: biotite.structure.AtomArray):
56 | """
57 | Args:
58 | structure: An instance of biotite AtomArray
59 | Returns:
60 | Tuple (coords, seq)
61 | - coords is an L x 3 x 3 array for N, CA, C coordinates
62 | - seq is the extracted sequence
63 | """
64 | coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
65 | residue_identities = get_residues(structure)[1]
66 | seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
67 | return coords, seq
68 |
69 |
70 | def load_coords(fpath, chain):
71 | """
72 | Args:
73 | fpath: filepath to either pdb or cif file
74 | chain: the chain id
75 | Returns:
76 | Tuple (coords, seq)
77 | - coords is an L x 3 x 3 array for N, CA, C coordinates
78 | - seq is the extracted sequence
79 | """
80 | structure = load_structure(fpath, chain)
81 | return extract_coords_from_structure(structure)
82 |
83 |
84 | def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
85 | """
86 | Example for atoms argument: ["N", "CA", "C"]
87 | """
88 | def filterfn(s, axis=None):
89 | filters = np.stack([s.atom_name == name for name in atoms], axis=1)
90 |         counts = filters.sum(0)
91 |         if not np.all(counts <= np.ones(filters.shape[1])):
92 | raise RuntimeError("structure has multiple atoms with same name")
93 | index = filters.argmax(0)
94 | coords = s[index].coord
95 |         coords[counts == 0] = float("nan")
96 | return coords
97 |
98 | return biotite.structure.apply_residue_wise(struct, struct, filterfn)
99 |
100 | # Modified from upstream: device handling updated for consistent GPU execution.
101 | def get_sequence_loss(model, alphabet, coords, seq):
102 | device = next(model.parameters()).device
103 | batch_converter = CoordBatchConverter(alphabet)
104 | batch = [(coords, None, seq)]
105 | coords, confidence, strs, tokens, padding_mask = batch_converter(
106 | batch, device=device)
107 |
108 | prev_output_tokens = tokens[:, :-1].to(device)
109 | target = tokens[:, 1:]
110 | target_padding_mask = (target == alphabet.padding_idx)
111 | logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)
112 | loss = F.cross_entropy(logits, target, reduction='none')
113 | loss = loss[0].cpu().detach().numpy()
114 | target_padding_mask = target_padding_mask[0].cpu().numpy()
115 | return loss, target_padding_mask
116 |
117 |
118 | def score_sequence(model, alphabet, coords, seq):
119 | loss, target_padding_mask = get_sequence_loss(model, alphabet, coords, seq)
120 | ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask)
121 | # Also calculate average when excluding masked portions
122 | coord_mask = np.all(np.isfinite(coords), axis=(-1, -2))
123 | ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
124 | return ll_fullseq, ll_withcoord
125 |
126 |
127 | def get_encoder_output(model, alphabet, coords):
128 | device = next(model.parameters()).device
129 | batch_converter = CoordBatchConverter(alphabet)
130 | batch = [(coords, None, None)]
131 | coords, confidence, strs, tokens, padding_mask = batch_converter(
132 | batch, device=device)
133 | encoder_out = model.encoder.forward(coords, padding_mask, confidence,
134 | return_all_hiddens=False)
135 | # remove beginning and end (bos and eos tokens)
136 | return encoder_out['encoder_out'][0][1:-1, 0]
137 |
138 |
139 | def rotate(v, R):
140 | """
141 | Rotates a vector by a rotation matrix.
142 |
143 | Args:
144 | v: 3D vector, tensor of shape (length x batch_size x channels x 3)
145 | R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)
146 |
147 | Returns:
148 | Rotated version of v by rotation matrix R.
149 | """
150 | R = R.unsqueeze(-3)
151 | v = v.unsqueeze(-1)
152 | return torch.sum(v * R, dim=-2)
153 |
154 |
155 | def get_rotation_frames(coords):
156 | """
157 | Returns a local rotation frame defined by N, CA, C positions.
158 |
159 | Args:
160 | coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
161 | where the third dimension is in order of N, CA, C
162 |
163 | Returns:
164 | Local relative rotation frames in shape (batch_size x length x 3 x 3)
165 | """
166 | v1 = coords[:, :, 2] - coords[:, :, 1]
167 | v2 = coords[:, :, 0] - coords[:, :, 1]
168 | e1 = normalize(v1, dim=-1)
169 | u2 = v2 - e1 * torch.sum(e1 * v2, dim=-1, keepdim=True)
170 | e2 = normalize(u2, dim=-1)
171 | e3 = torch.cross(e1, e2, dim=-1)
172 | R = torch.stack([e1, e2, e3], dim=-2)
173 | return R
174 |
175 |
176 | def nan_to_num(ts, val=0.0):
177 | """
178 | Replaces nans in tensor with a fixed value.
179 | """
180 | val = torch.tensor(val, dtype=ts.dtype, device=ts.device)
181 | return torch.where(~torch.isfinite(ts), val, ts)
182 |
183 |
184 | def rbf(values, v_min, v_max, n_bins=16):
185 | """
186 | Returns RBF encodings in a new dimension at the end.
187 | """
188 | rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
189 | rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
190 | rbf_std = (v_max - v_min) / n_bins
191 |     v_expand = torch.unsqueeze(values, -1)
192 |     z = (v_expand - rbf_centers) / rbf_std
193 | return torch.exp(-z ** 2)
194 |
195 |
196 | def norm(tensor, dim, eps=1e-8, keepdim=False):
197 | """
198 | Returns L2 norm along a dimension.
199 | """
200 | return torch.sqrt(
201 | torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim) + eps)
202 |
203 |
204 | def normalize(tensor, dim=-1):
205 | """
206 | Normalizes a tensor along a dimension after removing nans.
207 | """
208 | return nan_to_num(
209 | torch.div(tensor, norm(tensor, dim=dim, keepdim=True))
210 | )
211 |
212 | # Modified from the upstream ESM implementation (see the NaN-padding note below).
213 | class CoordBatchConverter(BatchConverter):
214 | def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
215 | """
216 | Args:
217 | raw_batch: List of tuples (coords, confidence, seq)
218 | In each tuple,
219 | coords: list of floats, shape L x 3 x 3
220 | confidence: list of floats, shape L; or scalar float; or None
221 | seq: string of length L
222 | Returns:
223 | coords: Tensor of shape batch_size x L x 3 x 3
224 | confidence: Tensor of shape batch_size x L
225 | strs: list of strings
226 | tokens: LongTensor of shape batch_size x L
227 | padding_mask: ByteTensor of shape batch_size x L
228 | """
229 |         self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")
230 | batch = []
231 | for coords, confidence, seq in raw_batch:
232 | if confidence is None:
233 | confidence = 1.
234 | if isinstance(confidence, float) or isinstance(confidence, int):
235 | confidence = [float(confidence)] * len(coords)
236 | if seq is None:
237 | seq = 'X' * len(coords)
238 | batch.append(((coords, confidence), seq))
239 |
240 | coords_and_confidence, strs, tokens = super().__call__(batch)
241 |
242 | # pad beginning and end of each protein due to legacy reasons
243 | coords = [
244 |             F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.nan)  # modified: pad value set to NaN (upstream used np.inf)
245 | for cd, _ in coords_and_confidence
246 | ]
247 | confidence = [
248 | F.pad(torch.tensor(cf), (1, 1), value=-1.)
249 | for _, cf in coords_and_confidence
250 | ]
251 | coords = self.collate_dense_tensors(coords, pad_v=np.nan)
252 | confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
253 | if device is not None:
254 | coords = coords.to(device)
255 | confidence = confidence.to(device)
256 | tokens = tokens.to(device)
257 | padding_mask = torch.isnan(coords[:,:,0,0])
258 | coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
259 | confidence = confidence * coord_mask + (-1.) * padding_mask
260 | return coords, confidence, strs, tokens, padding_mask
261 |
262 | def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
263 | """
264 | Args:
265 | coords_list: list of length batch_size, each item is a list of
266 | floats in shape L x 3 x 3 to describe a backbone
267 | confidence_list: one of
268 | - None, default to highest confidence
269 | - list of length batch_size, each item is a scalar
270 | - list of length batch_size, each item is a list of floats of
271 | length L to describe the confidence scores for the backbone
272 | with values between 0. and 1.
273 | seq_list: either None or a list of strings
274 | Returns:
275 | coords: Tensor of shape batch_size x L x 3 x 3
276 | confidence: Tensor of shape batch_size x L
277 | strs: list of strings
278 | tokens: LongTensor of shape batch_size x L
279 | padding_mask: ByteTensor of shape batch_size x L
280 | """
281 | batch_size = len(coords_list)
282 | if confidence_list is None:
283 | confidence_list = [None] * batch_size
284 | if seq_list is None:
285 | seq_list = [None] * batch_size
286 | raw_batch = zip(coords_list, confidence_list, seq_list)
287 | return self.__call__(raw_batch, device)
288 |
289 | @staticmethod
290 | def collate_dense_tensors(samples, pad_v):
291 | """
292 | Takes a list of tensors with the following dimensions:
293 | [(d_11, ..., d_1K),
294 | (d_21, ..., d_2K),
295 | ...,
296 | (d_N1, ..., d_NK)]
297 |         and stacks + pads them into a single tensor of shape:
298 |         (N, max_i d_i1, ..., max_i d_iK)
299 | """
300 | if len(samples) == 0:
301 | return torch.Tensor()
302 | if len(set(x.dim() for x in samples)) != 1:
303 | raise RuntimeError(
304 | f"Samples has varying dimensions: {[x.dim() for x in samples]}"
305 | )
306 | (device,) = tuple(set(x.device for x in samples)) # assumes all on same device
307 | max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
308 | result = torch.empty(
309 | len(samples), *max_shape, dtype=samples[0].dtype, device=device
310 | )
311 | result.fill_(pad_v)
312 | for i in range(len(samples)):
313 | result_i = result[i]
314 | t = samples[i]
315 | result_i[tuple(slice(0, k) for k in t.shape)] = t
316 | return result
--------------------------------------------------------------------------------
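A minimal usage sketch for the scoring helpers in bin/util.py above. This is not repo code: the PDB path and chain ID are hypothetical placeholders, and it assumes bin/ is on the Python path and that the pinned esm package (installed from GitHub per environment.yml) provides the ESM-IF1 loader.

    import esm
    from util import load_coords, score_sequence

    # load the ESM-IF1 inverse folding model and its alphabet
    model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    model = model.eval()

    # N/CA/C backbone coordinates and native sequence for a chosen chain
    coords, native_seq = load_coords('example.pdb', chain='H')  # placeholder path/chain

    # average per-residue log-likelihood: over all positions, and over
    # coordinate-resolved positions only
    ll_fullseq, ll_withcoord = score_sequence(model, alphabet, coords, native_seq)
    print(f'{ll_fullseq:.3f} (full sequence), {ll_withcoord:.3f} (resolved positions only)')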
/bin/dms_enrichment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from biotite.sequence.io.fasta import FastaFile, get_sequences
4 | from Bio import pairwise2
5 | import pandas as pd
6 | import shutil
7 | import os
8 | import numpy as np
9 | import bokeh.layouts  # used for gridplot below (pandas is already imported above)
10 | import bokeh.io
11 | import bokeh.plotting
12 | import bokeh.palettes
13 | from bokeh.transform import factor_cmap
14 |
15 | from bokeh.io import export_svg
16 | import iqplot
17 | from colorcet import glasbey_category10
18 |
19 | import subprocess
20 |
21 | bla_dict = {
22 | 'protein':'bla',
23 | 'pdbfile':'bla_1m40_a.pdb',
24 | 'chain':'A',
25 | 'dmsfile': 'dms_bla.csv',
26 | 'fitness_col': 'DMS_amp_2500_(b)',
27 | 'threshold' : 0.01
28 | }
29 |
30 | calm1_dict = {
31 | 'protein':'CALM1',
32 | 'pdbfile':'calm1_5v03_r.pdb',
33 | 'chain':'R',
34 | 'dmsfile': 'dms_calm1.csv',
35 | 'fitness_col':'DMS',
36 | 'threshold': 1
37 | }
38 |
39 | haeiiim_dict = {
40 | 'protein':'haeIIIM',
41 | 'pdbfile':'haeiiim_3ubt_b.pdb',
42 | 'chain':'B',
43 | 'dmsfile': 'dms_haeiiim.csv',
44 | 'fitness_col':'DMS_G3',
45 | 'threshold' : 0.01
46 | }
47 |
48 | gal4_dict = {
49 | 'protein':'GAL4',
50 | 'pdbfile':'gal4_3coq_b.pdb',
51 | 'chain':'B',
52 | 'dmsfile': 'dms_gal4.csv',
53 | 'fitness_col':'DMS_nonsel_24',
54 | 'threshold': 1
55 | }
56 |
57 | hras_dict = {
58 | 'protein':'HRAS',
59 | 'pdbfile':'hras_2ce2_x.pdb',
60 | 'chain':'X',
61 | 'dmsfile': 'dms_hras.csv',
62 | 'fitness_col':'DMS_unregulated',
63 | 'threshold': 1
64 | }
65 |
66 | mapk1_dict = {
67 | 'protein':'MAPK1',
68 | 'pdbfile':'mapk1_4zzn_a.pdb',
69 | 'chain':'A',
70 | 'dmsfile': 'dms_mapk1.csv',
71 | 'fitness_col':'DMS_VRT',
72 | 'threshold': 1
73 | }
74 |
75 | tpk1_dict = {
76 | 'protein':'TPK1',
77 | 'pdbfile':'tpk1_3s4y_a.pdb',
78 | 'chain':'A',
79 | 'dmsfile': 'dms_tpk1.csv',
80 | 'fitness_col':'DMS',
81 | 'threshold': 1
82 | }
83 |
84 | tpmt_dict = {
85 | 'protein':'TPMT',
86 | 'pdbfile':'tpmt_2bzg_a.pdb',
87 | 'chain':'A',
88 | 'dmsfile': 'dms_tpmt.csv',
89 | 'fitness_col':'DMS',
90 | 'threshold': 1
91 | }
92 |
93 | ube2i_dict = {
94 | 'protein':'UBE2I',
95 | 'pdbfile':'ube2i_5f6e_a.pdb',
96 | 'chain':'A',
97 |     'dmsfile': 'dms_ube2i.csv',
98 | 'fitness_col':'DMS',
99 | 'threshold': 1
100 | }
101 |
102 | ubi4_dict = {
103 | 'protein':'UBI4',
104 | 'pdbfile':'ubi4_4q5e_b.pdb',
105 | 'chain':'B',
106 | 'dmsfile': 'dms_ubi4.csv',
107 | 'fitness_col':'DMS_limiting_(b)',
108 | 'threshold': 1
109 | }
110 |
111 | studies = [
112 | bla_dict,
113 | calm1_dict,
114 | gal4_dict,
115 | haeiiim_dict,
116 | hras_dict,
117 | tpmt_dict,
118 | tpk1_dict,
119 | mapk1_dict,
120 | ube2i_dict,
121 | ubi4_dict,
122 | ]
123 |
124 | def score_if(prot_dict):
125 | prefix = 'data/dms/'
126 | fasta_outpath = f"dms_libs/{prot_dict['pdbfile'][:-4]}_dmsLib.fasta"
127 | score_if_cmd = [
128 | 'python',
129 | 'bin/score_log_likelihoods.py',
130 | prefix + f"structures/{prot_dict['pdbfile']}",
131 | '--seqpath', prefix + fasta_outpath,
132 |         '--outpath', f"output/dms/if_results/dms_{prot_dict['protein']}_if.csv",
133 |         '--chain', prot_dict['chain'],
134 | ]
135 | subprocess.run(score_if_cmd, check=True)
136 |
137 | def get_stats(prot_dict, n, top_per):
138 |
139 | top_per = top_per/100
140 | prefix = 'output/dms/'
141 | esm1v_outpath = 'dms_esm1vScored/' + f"dms_{prot_dict['protein']}" + '_maskMarginals.csv'
142 | if_results_df = pd.read_csv( prefix + f"if_results/dms_{prot_dict['protein']}_if.csv")
143 | dms_df = pd.read_csv(prefix + esm1v_outpath)
144 |
145 | unprofiled_df = dms_df[dms_df[prot_dict['fitness_col']].isna()]
146 | unprofiled_variants = unprofiled_df['variant'].to_list()
147 |
148 | dms_df = dms_df.dropna(subset = [prot_dict['fitness_col']])
149 | pop_n = len(dms_df)
150 | if_results_df = if_results_df[~if_results_df['seqid'].isin(unprofiled_variants)]
151 |
152 | assert (len(if_results_df) == len(dms_df))
153 |
154 | threshold = dms_df[prot_dict['fitness_col']].quantile(1-top_per)
155 | highFit_subset_variants = dms_df[dms_df[prot_dict['fitness_col']] >= threshold]['variant']
156 | n_highFit = len(highFit_subset_variants)
157 |
158 |
159 | dms_df['esm1v'] = dms_df.loc[:, dms_df.columns.str.startswith('esm1v')].mean(axis=1)
160 | top_esm1v_df = dms_df.sort_values(by = 'esm1v', ascending = False)[:n]
161 | top_esm1v_variants = top_esm1v_df['variant'].to_list()
162 |
163 | top_if_df = if_results_df.sort_values(by = 'log_likelihood', ascending = False)[:n]
164 | top_if_variants = top_if_df['seqid'].to_list()
165 |
166 | esm1v_hits_df = top_esm1v_df[top_esm1v_df['variant'].isin(highFit_subset_variants)]
167 | if_hits_df = top_if_df[top_if_df['seqid'].isin(highFit_subset_variants)]
168 |
169 | n_esm1v_hits = len(esm1v_hits_df)
170 | n_if_hits = len(if_hits_df)
171 |
172 | if_hit_enrich = (n_if_hits / n) / top_per
173 | esm1v_hit_enrich = (n_esm1v_hits / n) / top_per
174 |
175 | if_hit_rate = (n_if_hits / n)
176 | esm1v_hit_rate = (n_esm1v_hits / n)
177 |
178 | return (prot_dict['protein'], pop_n, n_highFit, n_if_hits, n_esm1v_hits, if_hit_enrich, esm1v_hit_enrich, if_hit_rate, esm1v_hit_rate )
179 |
180 | def plot_bars(melted_df):
181 | p = bokeh.plotting.figure(
182 | height=350,
183 | width=1100,
184 | y_axis_label="High Fitness \n Prediction Precision",
185 | x_axis_label = "Functional Percentile Threshold \n for High Fitness Classification",
186 |         x_range=bokeh.models.FactorRange(*factors, group_padding = 1.2),  # 'factors' is set at module level in __main__ before this is called
187 | tools="save",
188 | title=''
189 | )
190 | p.output_backend = "svg"
191 |
192 | p.vbar(
193 | source=melted_df,
194 | x="cats",
195 | top="Hit Rate",
196 | width = 1,
197 | line_color='black',
198 | alpha = 'alpha',
199 | legend_field='legend_label',
200 | fill_color=bokeh.transform.factor_cmap(
201 | 'Method',
202 | palette=['#999999', '#43a2ca'],
203 | factors=list(melted_df['Method'].unique()),
204 | start=1,
205 | end=2
206 | )
207 | )
208 |
209 | p.xgrid.grid_line_color = None
210 | p.ygrid.grid_line_color = None
211 | p.x_range.range_padding = 0.03
212 | p.legend.location = "top_right"
213 | p.legend.spacing = 20
214 | p.legend.label_text_font_size = "11pt"
215 | p.legend.label_text_color = "black"
216 | p.legend.orientation = "horizontal"
217 | p.xaxis.major_label_orientation = 1.2
218 | p.xaxis.separator_line_alpha = 0
219 | p.xaxis.group_text_font_size = '11pt'
220 | p.xaxis.subgroup_text_font_size = '0pt'
221 | p.xaxis.axis_label_text_font_size = '11pt'
222 | p.xaxis.major_label_text_font_size = '10pt'
223 | p.xaxis.axis_label_text_font_style = 'normal'
224 | p.xaxis.major_label_text_color = 'black'
225 | p.xaxis.axis_label_text_color = 'black'
226 |
227 | p.yaxis.axis_label_text_font_size = '11pt'
228 | p.yaxis.axis_label_text_font_style = 'normal'
229 | p.yaxis.axis_label_text_color = 'black'
230 | p.yaxis.major_label_text_font_size = '10pt'
231 | p.xaxis.group_text_color = 'dimgrey'
232 |
233 | return p
234 |
235 | def plot_ecdf_hits(prot_dict, n, top_per_lst):
236 |
237 | #use least stringent threshold to plot all hits
238 | top_per = max(top_per_lst)/100
239 | prefix = 'output/dms/'
240 |
241 | esm1v_outpath = 'dms_esm1vScored/' + f"dms_{prot_dict['protein']}" + '_maskMarginals.csv'
242 | if_results_df = pd.read_csv( prefix + f"if_results/dms_{prot_dict['protein']}_if.csv")
243 | dms_df = pd.read_csv(prefix + esm1v_outpath)
244 |
245 |
246 | unprofiled_df = dms_df[dms_df[prot_dict['fitness_col']].isna()]
247 | unprofiled_variants = unprofiled_df['variant'].to_list()
248 |
249 | dms_df = dms_df.dropna(subset = [prot_dict['fitness_col']])
250 | dms_df['percentile'] = dms_df[prot_dict['fitness_col']].rank(pct=True, method='average')
251 | dms_df['zscore'] = (dms_df[prot_dict['fitness_col']] - dms_df[prot_dict['fitness_col']].mean())/dms_df[prot_dict['fitness_col']].std()
252 | if_results_df = if_results_df[~if_results_df['seqid'].isin(unprofiled_variants)]
253 |
254 | assert (len(if_results_df) == len(dms_df))
255 |
256 | threshold = dms_df[prot_dict['fitness_col']].quantile(1-top_per)
257 | highFit_subset_variants = dms_df[dms_df[prot_dict['fitness_col']] >= threshold]['variant']
258 | n_highFit = len(highFit_subset_variants)
259 |
260 | dms_df['esm1v'] = dms_df.loc[:, dms_df.columns.str.startswith('esm1v')].mean(axis=1)
261 | top_esm1v_df = dms_df.sort_values(by = 'esm1v', ascending = False)[:n]
262 | top_esm1v_variants = top_esm1v_df['variant'].to_list()
263 |
264 | top_if_df = if_results_df.sort_values(by = 'log_likelihood', ascending = False)[:n]
265 | top_if_variants = top_if_df['seqid'].to_list()
266 |
267 | esm1v_hits_df = top_esm1v_df[top_esm1v_df['variant'].isin(highFit_subset_variants)]
268 | if_hits_df = top_if_df[top_if_df['seqid'].isin(highFit_subset_variants)]
269 | if_hits_df = pd.merge(if_hits_df, dms_df[['variant', 'zscore', 'percentile', prot_dict['fitness_col']]], left_on='seqid', right_on='variant', how='left')
270 |
271 | p_1v = iqplot.ecdf(
272 | data=dms_df,
273 | q=prot_dict['fitness_col'],
274 | style="staircase",
275 | palette = 'lightgrey',
276 | line_kwargs= {'line_width':6},
277 | x_axis_label = prot_dict['protein']+' Fitness Measure',
278 | #title = prot_dict['protein']
279 | )
280 | p_if = iqplot.ecdf(
281 | data=dms_df,
282 | q=prot_dict['fitness_col'],
283 | style="staircase",
284 | palette = 'lightgrey',
285 | line_kwargs= {'line_width':6},
286 | x_axis_label = None, #prot_dict['protein']+'Fitness Measure',
287 | title = prot_dict['protein']
288 | )
289 | colors = ['#1f78b4', '#33a02c', '#ff7f00',]
290 | for i, top_per in enumerate(top_per_lst):
291 | top_per = top_per/100
292 | hline = bokeh.models.Span(location=1-top_per, dimension='width', line_color=colors[i], line_dash='dashed', line_width=4, )
293 | p_1v.add_layout(hline)
294 | p_if.add_layout(hline)
295 |
296 |
297 | p_1v.circle(
298 | source = esm1v_hits_df,
299 | x = prot_dict['fitness_col'],
300 | y = 'percentile',
301 | color = 'dimgrey',
302 | size = 11,
303 | line_color = 'lightgrey',
304 | alpha = 0.8,
305 | )
306 |
307 | p_if.circle(
308 | source = if_hits_df,
309 | x = prot_dict['fitness_col'],
310 | y = 'percentile',
311 | color = '#43a2ca',
312 | size = 11,
313 | line_color = 'white',
314 | alpha = 0.8,
315 | )
316 |
317 | p_if.title.text_font_style = 'normal'
318 | p_if.title.text_color = 'black'
319 | p_if.title.align = 'center'
320 | p_if.title.text_font_size = '16pt'
321 | p_1v.xaxis.axis_label_text_font_size = '14pt'
322 |
323 |
324 | for p in [p_if, p_1v]:
325 | p.xaxis.axis_label_text_font_style = 'normal'
326 | p.yaxis.axis_label_text_font_style = 'normal'
327 | p.xaxis.axis_label_text_font_size = '14pt'
328 | p.yaxis.axis_label_text_font_size = '14pt'
329 | p.xaxis.axis_label_text_color = 'black'
330 | p.yaxis.axis_label_text_color = 'black'
331 | p.xaxis.major_label_text_font_size = '14pt'
332 | p.yaxis.major_label_text_font_size = '14pt'
333 | p.xaxis.major_label_text_color = 'black'
334 | p.yaxis.major_label_text_color = 'black'
335 | p.output_backend = "svg"
336 |
337 | return [p_if, p_1v]
338 |
339 | if __name__ == '__main__':
340 | for s in studies:
341 | score_if(s)
342 | shutil.copytree("data/dms/dms_esm1vScored", "output/dms/dms_esm1vScored", dirs_exist_ok=True)
343 |
344 | name, n_pop, n_popHits, n_if_hits, n_esm1v_hits, if_hit_enrich, esm1v_hit_enrich, if_hit_rate, esm1v_hit_rate, p_threshold = ([] for _ in range(10))
345 | data_lists = [name, n_pop, n_popHits, n_if_hits, n_esm1v_hits, if_hit_enrich, esm1v_hit_enrich, if_hit_rate, esm1v_hit_rate, p_threshold]
346 | for s in studies:
347 | for p in [5, 10, 20 ]:
348 | outputs = get_stats(s, 10, p)
349 | for i in range(len(outputs)):
350 | data_lists[i].append(outputs[i])
351 |             p_threshold.append(p)
352 |     col_names = ['Protein', 'Total Library Size', 'Library Hits (Variants Above Percentile Threshold)', 'Inverse Folding Hits (in Top 10)', 'ESM1v Hits (in Top 10)', 'Inverse Folding Hit Enrichment', 'ESM1v Hit Enrichment','Inverse Folding Hit Rate', 'ESM1v Hit Rate','Percentile Threshold' ]
353 | results_df = pd.DataFrame({col_names[i]: data_lists[i] for i in range(len(data_lists))})
354 | # Melt the dataframe
355 | melted_df = results_df.melt(id_vars=['Protein', 'Percentile Threshold'], value_vars=[ 'ESM1v Hit Rate', 'Inverse Folding Hit Rate',],
356 | var_name='Method', value_name='Hit Rate')
357 | melted_df['Method'] = melted_df['Method'].replace({
358 | 'Inverse Folding Hit Rate': 'Inverse Folding',
359 | 'ESM1v Hit Rate': 'Language Model'})
360 | melted_df['alpha'] = (1/melted_df['Percentile Threshold'].to_numpy())*3 + 0.4
361 | melted_df['Percentile Threshold'] = melted_df['Percentile Threshold'].astype(str)
362 | melted_df['legend_label'] = ['Structure-Informed Language Model' if x=='Inverse Folding' else 'Language Model' for x in melted_df['Method']]
363 |
364 | melted_df['cats'] = melted_df.apply(lambda x: (x["Protein"], x["Method"], x['Percentile Threshold'],), axis = 1)
365 | factors = list(melted_df.cats)
366 |
367 | p = plot_bars(melted_df)
368 | comparePrecision_fname = 'output/dms/precision_comparison.html'
369 | bokeh.plotting.output_file(comparePrecision_fname)
370 | bokeh.io.show(p)
371 |
372 | ecdf_plots = []
373 | for s in studies:
374 | ecdf_plots.extend(plot_ecdf_hits(s, 10, [20,10,5]))
375 | mid = int(len(ecdf_plots)/2)
376 | ecdf_fname = 'output/dms/fitness_ecdfs.html'
377 | bokeh.plotting.output_file(ecdf_fname)
378 | bokeh.io.show(bokeh.layouts.gridplot([ecdf_plots[:mid:2],
379 | ecdf_plots[1:mid:2],
380 | ecdf_plots[mid::2],
381 | ecdf_plots[mid+1::2]])
382 | )
383 |
--------------------------------------------------------------------------------
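The hit-rate and enrichment bookkeeping in get_stats() above reduces to a short calculation. A self-contained sketch with hypothetical numbers: if the high-fitness class is the top top_per fraction of the library, a random picker's expected hit rate is top_per, so enrichment is hit_rate / top_per.

    # All numbers below are made up for illustration.
    n = 10          # variants nominated by the model (top 10 by score)
    n_hits = 6      # nominations that fall in the high-fitness class
    top_per = 0.05  # high-fitness class = top 5% of measured variants

    hit_rate = n_hits / n            # 0.6
    enrichment = hit_rate / top_per  # 12.0, i.e. 12x better than random selection
    print(hit_rate, enrichment)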
/bin/plot_esm1v_benchmarks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import bokeh.io
4 | import bokeh.plotting
5 | import bokeh.palettes
6 | from bokeh.transform import factor_cmap
7 | import datashader
8 | import holoviews as hv
9 | import holoviews.operation.datashader
10 | hv.extension("bokeh")
11 |
12 | import warnings
13 | # Suppress FutureWarning messages
14 | warnings.simplefilter(action='ignore', category=FutureWarning)
15 |
16 | cr9114_dict = {
17 | 'ab_name' : 'CR9114',
18 | 'files' : {
19 | 'ESM-1v Ab-Ag': 'output/ab_mutagenesis_expts/cr9114/cr9114_esm1vAbAg_exp_data_maskMargLabeled.csv',
20 | 'ESM-1v Ab only': 'output/ab_mutagenesis_expts/cr9114/cr9114_esm1vbothchains_exp_data_maskMargLabeled.csv',
21 | 'ESM-1v Ab VH only': 'output/ab_mutagenesis_expts/cr9114/cr9114_exp_data_maskMargLabeled.csv',
22 | },
23 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr9114/cr9114_exp_data.csv',
24 | dtype = {'genotype': str},
25 | ).rename(columns={'h1_mean': 'H1', 'h3_mean' : 'H3'}),
26 | 'ag_columns': ['H1', 'H3'],
27 | 'expt_type': 'Combinatorial Mutagenesis for Affinity',
28 | 'palette': bokeh.palettes.Spectral6
29 | }
30 |
31 | cr6261_dict = {
32 | 'ab_name' : 'CR6261',
33 | 'files' : {
34 | 'ESM-1v Ab-Ag': 'output/ab_mutagenesis_expts/cr6261/cr6261_esm1vAbAg_exp_data_maskMargLabeled.csv',
35 | 'ESM-1v Ab only': 'output/ab_mutagenesis_expts/cr6261/cr6261_esm1vbothchains_exp_data_maskMargLabeled.csv',
36 | 'ESM-1v Ab VH only': 'output/ab_mutagenesis_expts/cr6261/cr6261_exp_data_maskMargLabeled.csv',
37 | },
38 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr6261/cr6261_exp_data.csv',
39 | dtype={'genotype': str},
40 | ).rename(columns={'h1_mean': 'H1', 'h9_mean' : 'H9'}),
41 | 'ag_columns': ['H1', 'H9'],
42 | 'expt_type': 'Combinatorial Mutagenesis for Affinity',
43 | 'palette': bokeh.palettes.Dark2_6,
44 | }
45 |
46 | g6LC_dict = {
47 | 'ab_name' : 'g6',
48 | 'files' : {
49 | 'LM Ab-Ag': 'output/ab_mutagenesis_expts/g6/g6Lc_esm1vAbAg_exp_data_maskMargLabeled.csv',
50 | 'LM Ab only': 'output/ab_mutagenesis_expts/g6/g6Lc_esm1vbothchains_exp_data_maskMargLabeled.csv',
51 | 'LM Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/g6Lc_exp_data_maskMargLabeled.csv',
52 | },
53 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv'),
54 | 'ag_columns': ['norm_binding'],
55 | 'expt_type': 'Deep Mutational Scan for Binding',
56 | 'palette': bokeh.palettes.Pastel1_6,
57 | 'chain': 'VL'
58 | }
59 |
60 | g6HC_dict = {
61 | 'ab_name' : 'g6',
62 | 'files' : {
63 | 'LM Ab-Ag': 'output/ab_mutagenesis_expts/g6/g6Hc_esm1vAbAg_exp_data_maskMargLabeled.csv',
64 | 'LM Ab only': 'output/ab_mutagenesis_expts/g6/g6Hc_esm1vbothchains_exp_data_maskMargLabeled.csv',
65 | 'LM Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/g6Hc_exp_data_maskMargLabeled.csv',
66 | },
67 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv'),
68 | 'ag_columns': ['norm_binding'],
69 | 'expt_type': 'Deep Mutational Scan for Binding',
70 | 'palette': bokeh.palettes.Pastel1_4,
71 | 'chain': 'VH'
72 | }
73 |
74 | def apply_mask_and_average(scores, genotype):
75 |     #scores here should already be in log space
76 |
77 | if '1' in genotype:
78 | masked_list = []
79 | for s, mask_char in zip(scores, genotype):
80 | if mask_char == '1':
81 | masked_list.append(s)
82 | multi_avg = np.mean(masked_list)
83 | else:
84 | #if genotype is all zeros = wt
85 | multi_avg = 0
86 |
87 | return multi_avg
88 |
89 |
90 | def transform_single_to_multi( dms_df, singleMuts_df, condition, sort_col, ascending ):
91 | multi_scores = []
92 |     #the scoring method is passed as `condition` because `key` is a kwarg of pd.sort_values
93 |     #condition should be one of: esm1v, abysis, or AbLang
94 |
95 |     #for cr: sorts the single-mutation dataframe in residue order, i.e. binary '10000' before '00010' (ascending=False)
96 |     #for g6: sorts the single-mutation dataframe in position order, i.e. by residue index (ascending=True)
97 |
98 | singleMuts_sorted = singleMuts_df.sort_values(sort_col, key=lambda x: x.astype(int), ascending=ascending)
99 | singles_scores = singleMuts_sorted[condition].to_list()
100 |
101 | for g in dms_df['genotype']:
102 | multi_scores.append(apply_mask_and_average(singles_scores, g ))
103 |
104 | return multi_scores
105 |
106 | #get melted correlations matrix of structure input x target antigen, for a given antibody
107 | def get_corr(ab_name, files, dms_df, ag_columns, dropLOQ = False ):
108 |
109 |     #retrieve esm1v scores and compute the ensemble average across models
110 | for key, filepath in files.items():
111 | singleMuts_df = pd.read_csv(filepath)
112 | # Get the columns for all esm1v models
113 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')]
114 | # Calculate the average for each esm column
115 | esm_avg_values = singleMuts_df[esm_columns].mean(axis=1)
116 | singleMuts_df[key] = esm_avg_values
117 | if ab_name == 'g6':
118 | #g6 data set is only single mutations
119 | dms_df[key] = esm_avg_values
120 | else:
121 |             #cr datasets contain multiple mutations per variant
122 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'genotype', ascending = False)
123 |
124 |
125 | conditions = list(files.keys())
126 |
127 | if not dropLOQ:
128 | correlations = dms_df[ag_columns + conditions].corr(method='spearman')
129 | correlations = correlations.drop(conditions, axis= 1)
130 | correlations = correlations.drop(ag_columns, axis= 0)
131 |
132 | # Melt the correlations DataFrame and rename the index column
133 | melted_correlations = (
134 | correlations
135 | .reset_index()
136 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
137 | .rename(columns={'index': 'Input'})
138 | )
139 |         melted_correlations['Method'] = 'LM'
140 |
141 | return melted_correlations
142 | #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like
143 | else:
144 | all_ag_melted = pd.DataFrame({})
145 | for ag in ag_columns:
146 | filt_df = dms_df[dms_df[ag] > min(dms_df[ag])]
147 |
148 | correlations = filt_df[[ag] + conditions].corr(method='spearman')
149 | correlations = correlations.drop(conditions, axis= 1)
150 | correlations = correlations.drop(ag, axis= 0)
151 |
152 | # Melt the correlations DataFrame and rename the index column
153 | melted_correlations = (
154 | correlations
155 | .reset_index()
156 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
157 | ).rename(columns={'index': 'Input'})
158 |
159 | all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True)
160 |         all_ag_melted['Method'] = 'LM'
161 |
162 | return all_ag_melted
163 |
164 | #get melted correlations matrix of structure input x target antigen, for a given antibody
165 | def get_g6_corr(g6Hc_dict, g6LC_dict, dropLOQ = False ):
166 |
167 | ab_name = 'g6'
168 | ag_columns = g6LC_dict['ag_columns']
169 | vh_and_vl_dms_df = pd.DataFrame({})
170 |
171 |     for d in [g6Hc_dict, g6LC_dict]:
172 |
173 | files = d['files']
174 | dms_df = d['dms_df']
175 |
176 |         #retrieve esm1v scores and compute the ensemble average across models
177 | for key, filepath in files.items():
178 | singleMuts_df = pd.read_csv(filepath)
179 | # Get the columns for all esm1v models
180 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')]
181 | # Calculate the average for each esm column
182 | esm_avg_values = singleMuts_df[esm_columns].mean(axis=1)
183 | singleMuts_df[key] = esm_avg_values
184 | if ab_name == 'g6':
185 | #g6 data set is only single mutations
186 | dms_df[key] = esm_avg_values
187 |
188 |
189 | vh_and_vl_dms_df = pd.concat([vh_and_vl_dms_df, dms_df], ignore_index= True)
190 |
191 | conditions = list(files.keys())
192 |
193 | #
194 | if not dropLOQ:
195 | correlations = vh_and_vl_dms_df[ag_columns + conditions].corr(method='spearman')
196 | correlations = correlations.drop(conditions, axis= 1)
197 | correlations = correlations.drop(ag_columns, axis= 0)
198 |
199 | # Melt the correlations DataFrame and rename the index column
200 | melted_correlations = (
201 | correlations
202 | .reset_index()
203 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
204 | .rename(columns={'index': 'Input'})
205 | )
206 |         melted_correlations['Method'] = 'LM'
207 |         melted_correlations['Target'] = 'VEGF-A'
208 |
209 | return melted_correlations
210 | #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like
211 | else:
212 | all_ag_melted = pd.DataFrame({})
213 | for ag in ag_columns:
214 | filt_df = vh_and_vl_dms_df[vh_and_vl_dms_df[ag] > min(vh_and_vl_dms_df[ag])]
215 |
216 | correlations = filt_df[[ag] + conditions].corr(method='spearman')
217 | correlations = correlations.drop(conditions, axis= 1)
218 | correlations = correlations.drop(ag, axis= 0)
219 |
220 | # Melt the correlations DataFrame and rename the index column
221 | melted_correlations = (
222 | correlations
223 | .reset_index()
224 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
225 | ).rename(columns={'index': 'Input'})
226 |
227 | all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True)
228 |         all_ag_melted['Method'] = 'LM'
229 |         all_ag_melted['Target'] = 'VEGF-A'
230 |
231 | return all_ag_melted
232 |
233 | #plot correlation bars for cr antibodies: cr6261 and cr9114
234 | def plot_hbar(title, melted_correlations, palette = bokeh.palettes.Spectral6, compare = 'model'):
235 |
236 | if compare == 'model':
237 | melted_correlations.replace('Ab-Ag','InverseFolding', inplace = True)
238 | color_by = 'Input'
239 | elif compare == 'IF':
240 | color_by = 'Method'
241 | else:
242 | color_by = 'Target'
243 |
244 |
245 | melted_correlations['cats'] = melted_correlations.apply(lambda x: (x["Target"], x["Input"]), axis = 1)
246 | factors = list(melted_correlations.cats)[::-1]
247 |
248 | p = bokeh.plotting.figure(
249 | height=340,
250 | width=440,
251 | x_axis_label="Spearman Correlation",
252 | x_range=[0, 1],
253 | y_range=bokeh.models.FactorRange(*factors),
254 | tools="save",
255 | title = title
256 | )
257 |
258 |
259 | p.hbar(
260 | source=melted_correlations,
261 | y="cats",
262 | right="Correlation",
263 | height=0.6,
264 | line_color = 'black',
265 | legend_field = color_by,
266 |         # use the palette to colormap based on the x[1:2] values
267 | # fill_color=factor_cmap(
268 | # color_by,
269 | # palette= palette,
270 | # factors = list(melted_correlations[color_by].unique()),
271 | # start=1,
272 | # end=3)
273 | fill_color=bokeh.palettes.Dark2_6[1]
274 |
275 | )
276 |
277 | #labels_df = melted_correlations[melted_correlations['Input'].isin(['Ab-Ag','ESM-1v', 'abYsis', 'InverseFold'])]
278 | labels_df = melted_correlations
279 | labels_df['corr_str'] = labels_df['Correlation'].apply(lambda x: round(x, 2)).astype(str)
280 | labels_source = bokeh.models.ColumnDataSource(labels_df)
281 |
282 | labels = bokeh.models.LabelSet(x='Correlation', y='cats', text='corr_str',text_font_size = "10px",
283 | x_offset=12, y_offset=-5, source=labels_source, render_mode='canvas')
284 |
285 | p.ygrid.grid_line_color = None
286 | p.y_range.range_padding = 0.1
287 | p.add_layout(labels)
288 | p.legend.visible = False
289 |
290 | p.output_backend = "svg"
291 | return p
292 |
293 | if __name__ == '__main__':
294 | datasets = [cr9114_dict, cr6261_dict, (g6HC_dict, g6LC_dict) ]
295 |
296 |     #set dropLOQ to [False, True] to also compute with points at the lower limit of quantitation ignored
297 | for dropLOQ in [False]:
298 |
299 | all_corr_plots = []
300 |
301 | for d in datasets:
302 | if type(d) is tuple:
303 | g6Hc, g6Lc = d
304 | title = g6Hc['ab_name'] + ', ' + g6Hc['expt_type']
305 | g6_combined = get_g6_corr(g6Hc, g6Lc, dropLOQ= dropLOQ)
306 |
307 | all_corr_plots.append(plot_hbar(title, g6_combined, g6Lc['palette'], 'IF'))
308 |
309 | else:
310 | corr_df = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
311 | title = d['ab_name'] + ', ' + d['expt_type']
312 | all_corr_plots.append(plot_hbar(title, corr_df, d['palette'], 'Target'))
313 |
314 |
315 | all_corr_fname = f"output/ab_mutagenesis_expts/esm1v_benchmarks{'_dropLOQ' if dropLOQ else ''}.html"
316 | bokeh.plotting.output_file(all_corr_fname)
317 | bokeh.io.show(bokeh.layouts.gridplot(all_corr_plots, ncols = len(all_corr_plots)))
318 |
--------------------------------------------------------------------------------
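A small worked example (scores made up) of the genotype-masking scheme shared by apply_mask_and_average() and transform_single_to_multi() above: a '1' in the binary genotype string marks a single mutation the variant carries, the variant's score is the mean of the corresponding single-mutant scores, and the all-zeros genotype (wild type) scores 0 by convention.

    import numpy as np

    single_scores = [-0.8, -0.1, -2.3, 0.4, -0.5]  # hypothetical per-mutation scores

    def masked_average(scores, genotype):
        # average the scores at positions flagged '1'; wild type ('000...') -> 0
        picked = [s for s, g in zip(scores, genotype) if g == '1']
        return np.mean(picked) if picked else 0.0

    print(masked_average(single_scores, '10010'))  # mean(-0.8, 0.4) = -0.2
    print(masked_average(single_scores, '00000'))  # wild type -> 0.0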
/bin/plot_ab-binding_benchmarks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import bokeh.io
4 | import bokeh.plotting
5 | import bokeh.palettes
6 | from bokeh.transform import factor_cmap
7 | import datashader
8 | import holoviews as hv
9 | import holoviews.operation.datashader
10 | hv.extension("bokeh")
11 |
12 | import warnings
13 | # Suppress FutureWarning messages
14 | warnings.simplefilter(action='ignore', category=FutureWarning)
15 |
16 |
17 |
18 | if_conds = ['ag+ab', 'ab only', 'ab']
19 |
20 | cr9114_dict = {
21 | 'ab_name' : 'CR9114',
22 | 'files' : {
23 | 'Ab-Ag': 'output/ab_mutagenesis_expts/cr9114/4fqi_ablh_scores.csv',
24 | 'Ab only': 'output/ab_mutagenesis_expts/cr9114/4fqi_lh_scores.csv',
25 | 'Ab VH only': 'output/ab_mutagenesis_expts/cr9114/4fqi_h_scores.csv',
26 | 'ESM-1v': 'output/ab_mutagenesis_expts/cr9114/cr9114_exp_data_maskMargLabeled.csv',
27 | 'AbLang': 'output/ab_mutagenesis_expts/cr9114/cr9114_hc_ablangScores.csv',
28 | 'abYsis': 'output/ab_mutagenesis_expts/cr9114/abysis_counts_cr9114_vh.txt',
29 | },
30 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr9114/cr9114_exp_data.csv',
31 | dtype = {'genotype': str},
32 | ).rename(columns={'h1_mean': 'H1', 'h3_mean' : 'H3'}),
33 | 'ag_columns': ['H1', 'H3'],
34 | 'expt_type': 'Combinatorial Mutagenesis for Affinity',
35 | 'palette': bokeh.palettes.Spectral6
36 | }
37 |
38 | cr6261_dict = {
39 | 'ab_name' : 'CR6261',
40 | 'files' : {
41 | 'Ab-Ag': 'output/ab_mutagenesis_expts/cr6261/3gbn_ablh_scores.csv',
42 | 'Ab only': 'output/ab_mutagenesis_expts/cr6261/3gbn_lh_scores.csv',
43 | 'Ab VH only': 'output/ab_mutagenesis_expts/cr6261/3gbn_h_scores.csv',
44 | 'ESM-1v': 'output/ab_mutagenesis_expts/cr6261/cr6261_exp_data_maskMargLabeled.csv',
45 | 'AbLang': 'output/ab_mutagenesis_expts/cr6261/cr6261_hc_ablangScores.csv',
46 | 'abYsis': 'output/ab_mutagenesis_expts/cr6261/abysis_counts_cr6261_vh.txt',
47 | },
48 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/cr6261/cr6261_exp_data.csv',
49 | dtype={'genotype': str},
50 | ).rename(columns={'h1_mean': 'H1', 'h9_mean' : 'H9'}),
51 | 'ag_columns': ['H1', 'H9'],
52 | 'expt_type': 'Combinatorial Mutagenesis for Affinity',
53 | 'palette': bokeh.palettes.Dark2_6,
54 | }
55 |
56 |
57 | g6LC_dict = {
58 | 'ab_name' : 'g6',
59 | 'files' : {
60 | 'Ab-Ag': 'output/ab_mutagenesis_expts/g6/2fjg_vlh_lc_scores.csv',
61 | 'Ab only': 'output/ab_mutagenesis_expts/g6/2fjg_lh_lc_scores.csv',
62 | 'Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/2fjg_l_lc_scores.csv',
63 | 'ESM-1v': 'output/ab_mutagenesis_expts/g6/g6Lc_exp_data_maskMargLabeled.csv',
64 | 'AbLang': 'output/ab_mutagenesis_expts/g6/g6_lc_ablangScores.csv',
65 | 'abYsis': 'output/ab_mutagenesis_expts/g6/abysis_counts_g6_vl.txt',
66 | },
67 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_lc_exp_data.csv'),
68 | 'ag_columns': ['norm_binding'],
69 | 'expt_type': 'Deep Mutational Scan for Binding',
70 | 'palette': bokeh.palettes.Pastel1_6,
71 | 'chain': 'VL'
72 | }
73 |
74 | g6HC_dict = {
75 | 'ab_name' : 'g6',
76 | 'files' : {
77 | 'Ab-Ag': 'output/ab_mutagenesis_expts/g6/2fjg_vlh_hc_scores.csv',
78 | 'Ab only': 'output/ab_mutagenesis_expts/g6/2fjg_lh_hc_scores.csv',
79 | 'Ab VH/VL only': 'output/ab_mutagenesis_expts/g6/2fjg_h_hc_scores.csv',
80 | 'ESM-1v': 'output/ab_mutagenesis_expts/g6/g6Hc_exp_data_maskMargLabeled.csv',
81 | 'AbLang': 'output/ab_mutagenesis_expts/g6/g6_hc_ablangScores.csv',
82 | 'abYsis': 'output/ab_mutagenesis_expts/g6/abysis_counts_g6_vh.txt',
83 | },
84 | 'dms_df' : pd.read_csv('data/ab_mutagenesis_expts/g6/g6_hc_exp_data.csv'),
85 | 'ag_columns': ['norm_binding'],
86 | 'expt_type': 'Deep Mutational Scan for Binding',
87 | 'palette': bokeh.palettes.Pastel1_4,
88 | 'chain': 'VH'
89 | }
90 | def apply_mask_and_average(scores, genotype):
91 |     #scores here should already be in log space
92 |
93 | if '1' in genotype:
94 | masked_list = []
95 | for s, mask_char in zip(scores, genotype):
96 | if mask_char == '1':
97 | masked_list.append(s)
98 | multi_avg = np.mean(masked_list)
99 | else:
100 | #if genotype is all zeros = wt
101 | multi_avg = 0
102 |
103 | return multi_avg
104 |
105 |
106 | def transform_single_to_multi( dms_df, singleMuts_df, condition, sort_col, ascending ):
107 | multi_scores = []
108 |     #the scoring method is passed as `condition` because `key` is a kwarg of pd.sort_values
109 |     #condition should be one of: esm1v, abysis, or AbLang
110 |
111 |     #for cr: sorts the single-mutation dataframe in residue order, i.e. binary '10000' before '00010' (ascending=False)
112 |     #for g6: sorts the single-mutation dataframe in position order, i.e. by residue index (ascending=True)
113 |
114 | singleMuts_sorted = singleMuts_df.sort_values(sort_col, key=lambda x: x.astype(int), ascending=ascending)
115 | singles_scores = singleMuts_sorted[condition].to_list()
116 |
117 | for g in dms_df['genotype']:
118 | multi_scores.append(apply_mask_and_average(singles_scores, g ))
119 |
120 | return multi_scores
121 |
122 | def get_abysis_single_scores_df(singleMut_ids, abysis_df):
123 | single_scores = []
124 | all_pos = []
125 |
126 |     for mut_id in singleMut_ids:
127 |         wt, pos, mt = mut_id[0], int(mut_id[1:-1]), mut_id[-1]
128 | likelihood_ratio = abysis_df.loc[(abysis_df['pos'] == pos) & (abysis_df['wt'] == wt) & (abysis_df['mt'] == mt)]['likelihood_ratio'].to_list()[0]
129 | log_like_ratio = np.log10(likelihood_ratio)
130 | single_scores.append(log_like_ratio)
131 | all_pos.append(pos)
132 |
133 | singleMuts_df = pd.DataFrame({'pos': all_pos, 'abYsis': single_scores})
134 |
135 |
136 | return singleMuts_df
137 |
138 | def get_ablang_single_scores_df(singleMut_ids, ablang_df):
139 | single_scores = []
140 | all_pos = []
141 |
142 |     for mut_id in singleMut_ids:
143 |         wt, pos, mt = mut_id[0], int(mut_id[1:-1]), mut_id[-1]
144 | log_likelihood_ratio = ablang_df.loc[(ablang_df['pos'] == pos) & (ablang_df['wt'] == wt) & (ablang_df['mt'] == mt)]['log_likelihood_ratio'].to_list()[0]
145 | single_scores.append(log_likelihood_ratio)
146 | all_pos.append(pos)
147 |
148 | singleMuts_df = pd.DataFrame({'pos': all_pos, 'AbLang': single_scores})
149 |
150 |
151 | return singleMuts_df
152 |
153 |
154 | #get melted correlations matrix of structure input x target antigen, for a given antibody
155 | def get_corr(ab_name, files, dms_df, ag_columns, dropLOQ = False ):
156 |
157 | inv_fold_files = {key: value for key, value in files.items() if key not in ['ESM-1v', 'AbLang', 'abYsis']}
158 | other_files = {key: value for key, value in files.items() if key in ['ESM-1v', 'AbLang', 'abYsis']}
159 |
160 |     #retrieve inverse folding scores
161 | for key, filepath in inv_fold_files.items():
162 | df = pd.read_csv(filepath) # Read the CSV file into a DataFrame
163 | column_label = key
164 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column
165 | dms_df[column_label] = log_likelihood_column
166 |
167 |     #retrieve abYsis, AbLang, and ESM-1v scores
168 | for key, filepath in other_files.items():
169 | if key == 'ESM-1v':
170 | singleMuts_df = pd.read_csv(filepath)
171 | # Get the columns for all esm1v models
172 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')]
173 | # Calculate the average for each esm column
174 | esm_avg_values = singleMuts_df[esm_columns].mean(axis=1)
175 | singleMuts_df[key] = esm_avg_values
176 | if ab_name == 'g6':
177 | #g6 data set is only single mutations
178 | dms_df[key] = esm_avg_values
179 | else:
180 |                 #cr datasets contain multiple mutations per variant
181 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'genotype', ascending = False)
182 | elif key == 'abYsis':
183 | abysis_df = pd.read_csv(filepath, sep = '\t', dtype = {'pos': int})
184 | if ab_name == 'g6':
185 | dms_df[key] = get_abysis_single_scores_df(dms_df['mutant'], abysis_df)[key]
186 | else:
187 |                 #cr datasets contain multiple mutations; use the same scoring strategy as the ESM-1v paper
188 | singleMut_ids = pd.read_csv(files['ESM-1v'])['mutant'].to_list()
189 | singleMuts_df = get_abysis_single_scores_df(singleMut_ids, abysis_df)
190 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'pos', ascending = True)
191 | elif key == 'AbLang':
192 | ablang_df = pd.read_csv(filepath, dtype = {'pos': int})
193 | if ab_name == 'g6':
194 | dms_df[key] = get_ablang_single_scores_df(dms_df['mutant'], ablang_df)[key]
195 | else:
196 |                 #cr datasets contain multiple mutations; use the same scoring strategy as the ESM-1v paper
197 | singleMut_ids = pd.read_csv(files['ESM-1v'])['mutant'].to_list()
198 | singleMuts_df = get_ablang_single_scores_df(singleMut_ids, ablang_df)
199 | dms_df[key] = transform_single_to_multi(dms_df, singleMuts_df, key, 'pos', ascending = True)
200 |
201 | conditions = list(files.keys())
202 |
203 | #
204 | if not dropLOQ:
205 | correlations = dms_df[ag_columns + conditions].corr(method='spearman')
206 | correlations = correlations.drop(conditions, axis= 1)
207 | correlations = correlations.drop(ag_columns, axis= 0)
208 |
209 | # Melt the correlations DataFrame and rename the index column
210 | melted_correlations = (
211 | correlations
212 | .reset_index()
213 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
214 | .rename(columns={'index': 'Input'})
215 | )
216 | melted_correlations['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i))
217 | else 'IF' if 'Ab' in i
218 | else 'MSA'
219 | for i in melted_correlations['Input']
220 | ]
221 |
222 | return melted_correlations
223 | #if we wish to drop the samples on the lower limit of quantitation, drop samples from each ag group individually and compute correlation with IF log_like
224 | else:
225 | all_ag_melted = pd.DataFrame({})
226 | for ag in ag_columns:
227 | filt_df = dms_df[dms_df[ag] > min(dms_df[ag])]
228 |
229 | correlations = filt_df[[ag] + conditions].corr(method='spearman')
230 | correlations = correlations.drop(conditions, axis= 1)
231 | correlations = correlations.drop(ag, axis= 0)
232 |
233 | # Melt the correlations DataFrame and rename the index column
234 | melted_correlations = (
235 | correlations
236 | .reset_index()
237 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
238 | ).rename(columns={'index': 'Input'})
239 |
240 | all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True)
241 | all_ag_melted['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i))
242 | else 'IF' if 'Ab' in i
243 | else 'MSA'
244 | for i in all_ag_melted['Input']
245 | ]
246 |
247 | return all_ag_melted
248 |
249 |
250 | #get melted correlations matrix of structure input x target antigen, for a given antibody
251 | def get_g6_corr(g6Hc_dict, g6LC_dict, dropLOQ = False ):
252 |
253 | ab_name = 'g6'
254 | ag_columns = g6LC_dict['ag_columns']
255 | vh_and_vl_dms_df = pd.DataFrame({})
256 |
257 |     for d in [g6Hc_dict, g6LC_dict]:
258 |
259 | files = d['files']
260 | dms_df = d['dms_df']
261 |
262 | inv_fold_files = {key: value for key, value in files.items() if key not in ['ESM-1v', 'AbLang', 'abYsis']}
263 | other_files = {key: value for key, value in files.items() if key in ['ESM-1v', 'AbLang', 'abYsis']}
264 |
265 | # retrieve log-likelihood scores from the inverse folding outputs
266 | for key, filepath in inv_fold_files.items():
267 | df = pd.read_csv(filepath) # Read the CSV file into a DataFrame
268 | column_label = key
269 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column
270 | dms_df[column_label] = log_likelihood_column
271 |
272 | # retrieve scores from abYsis, AbLang, and ESM-1v
273 | for key, filepath in other_files.items():
274 | if key == 'ESM-1v':
275 | singleMuts_df = pd.read_csv(filepath)
276 | # Get the columns for all esm1v models
277 | esm_columns = [col for col in singleMuts_df.columns if col.startswith('esm')]
278 | # Calculate the average for each esm column
279 | esm_avg_values = singleMuts_df[esm_columns].mean(axis=1)
280 | singleMuts_df[key] = esm_avg_values
281 | if ab_name == 'g6':
282 | # g6 dataset contains only single mutations
283 | dms_df[key] = esm_avg_values
284 | elif key == 'abYsis':
285 | abysis_df = pd.read_csv(filepath, sep = '\t', dtype = {'pos': int})
286 | if ab_name == 'g6':
287 | dms_df[key] = get_abysis_single_scores_df(dms_df['mutant'], abysis_df)[key]
288 | elif key == 'AbLang':
289 | ablang_df = pd.read_csv(filepath, dtype = {'pos': int})
290 | if ab_name == 'g6':
291 | dms_df[key] = get_ablang_single_scores_df(dms_df['mutant'], ablang_df)[key]
292 |
293 |
294 | vh_and_vl_dms_df = pd.concat([vh_and_vl_dms_df, dms_df], ignore_index= True)
295 |
296 | conditions = list(files.keys())
297 |
298 | # correlations over the full dataset, with points at the lower limit of quantitation (LOQ) retained
299 | if not dropLOQ:
300 | correlations = vh_and_vl_dms_df[ag_columns + conditions].corr(method='spearman')
301 | correlations = correlations.drop(conditions, axis= 1)
302 | correlations = correlations.drop(ag_columns, axis= 0)
303 |
304 | # Melt the correlations DataFrame and rename the index column
305 | melted_correlations = (
306 | correlations
307 | .reset_index()
308 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
309 | .rename(columns={'index': 'Input'})
310 | )
311 | melted_correlations['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i))
312 | else 'IF' if 'Ab' in i
313 | else 'MSA'
314 | for i in melted_correlations['Input']
315 | ]
316 | melted_correlations['Target'] = ['VEGF-A'] * len(melted_correlations)
317 |
318 | return melted_correlations
319 | # to drop samples at the lower limit of quantitation, filter each antigen group individually and compute correlations with the log-likelihoods
320 | else:
321 | all_ag_melted = pd.DataFrame({})
322 | for ag in ag_columns:
323 | filt_df = vh_and_vl_dms_df[vh_and_vl_dms_df[ag] > min(vh_and_vl_dms_df[ag])]
324 |
325 | correlations = filt_df[[ag] + conditions].corr(method='spearman')
326 | correlations = correlations.drop(conditions, axis= 1)
327 | correlations = correlations.drop(ag, axis= 0)
328 |
329 | # Melt the correlations DataFrame and rename the index column
330 | melted_correlations = (
331 | correlations
332 | .reset_index()
333 | .melt(id_vars='index', var_name='Target', value_name='Correlation')
334 | ).rename(columns={'index': 'Input'})
335 |
336 | all_ag_melted = pd.concat([all_ag_melted, melted_correlations], ignore_index= True)
337 | all_ag_melted['Method'] = ['LM' if (('ESM' in i) or ('AbLang' in i))
338 | else 'IF' if 'Ab' in i
339 | else 'MSA'
340 | for i in all_ag_melted['Input']
341 | ]
342 | all_ag_melted['Target'] = ['VEGF-A'] * len(all_ag_melted)
343 |
344 | return all_ag_melted
345 |
346 |
347 | # plot horizontal bars of Spearman correlations for one antibody benchmark
348 | def plot_hbar(title, melted_correlations, palette = bokeh.palettes.Spectral6, compare = 'model'):
349 |
350 | if compare == 'model':
351 | melted_correlations.replace('Ab-Ag','InverseFolding', inplace = True)
352 | color_by = 'Input'
353 | elif compare == 'IF':
354 | color_by = 'Method'
355 | else:
356 | color_by = 'Target'
357 |
358 |
359 | melted_correlations['cats'] = melted_correlations.apply(lambda x: (x["Target"], x["Input"]), axis = 1)
360 | factors = list(melted_correlations.cats)[::-1]
361 |
362 | if 'CR6261' in title:
363 | x_min = -0.15
364 | else:
365 | x_min = 0
366 |
367 | p = bokeh.plotting.figure(
368 | height=340,
369 | width=440,
370 | x_axis_label="Spearman Correlation",
371 | x_range=[x_min, 1],
372 | y_range=bokeh.models.FactorRange(*factors),
373 | tools="save",
374 | title = title
375 | )
376 |
377 |
378 | p.hbar(
379 | source=melted_correlations,
380 | y="cats",
381 | right="Correlation",
382 | height=0.6,
383 | line_color = 'black',
384 | legend_field = color_by,
385 | # use the palette to colormap based on the x[1:2] factor values
386 | fill_color=factor_cmap(
387 | color_by,
388 | palette= palette,
389 | factors = list(melted_correlations[color_by].unique()),
390 | start=1,
391 | end=3)
392 | )
393 |
394 | #labels_df = melted_correlations[melted_correlations['Input'].isin(['Ab-Ag','ESM-1v', 'abYsis', 'InverseFold'])]
395 | labels_df = melted_correlations.copy()  # copy so the added label column does not mutate the caller's frame
396 | labels_df['corr_str'] = labels_df['Correlation'].round(2).astype(str)
397 | labels_source = bokeh.models.ColumnDataSource(labels_df)
398 |
399 | labels = bokeh.models.LabelSet(x='Correlation', y='cats', text='corr_str',text_font_size = "10px",
400 | x_offset=12, y_offset=-5, source=labels_source, render_mode='canvas')
401 |
402 | p.ygrid.grid_line_color = None
403 | p.y_range.range_padding = 0.1
404 | p.add_layout(labels)
405 |
406 | if compare == 'IF':
407 | p.legend.orientation = "vertical"
408 | p.legend.location = "bottom_right"
409 | p.legend.label_text_font_size = "8pt"
410 | else:
411 | p.legend.visible = False
412 |
413 | p.output_backend = "svg"
414 | return p
415 |
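# Minimal usage sketch for plot_hbar (inputs are illustrative; see __main__
# below for the real driver). It expects a melted correlation frame with
# 'Target', 'Input', 'Correlation', and 'Method' columns:
#
#   demo = pd.DataFrame({'Target': ['H1'], 'Input': ['Ab-Ag'],
#                        'Correlation': [0.55], 'Method': ['IF']})
#   bokeh.io.show(plot_hbar('demo', demo, compare='IF'))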
416 | def plot_g6_scatter(g6Hc_df, g6Lc_df, ag_columns, dropLOQ):
417 | ag = ag_columns[0]
418 | if dropLOQ:
419 | g6Lc_df = g6Lc_df[g6Lc_df[ag] > np.min(g6Lc_df[ag])]
420 | g6Hc_df = g6Hc_df[g6Hc_df[ag] > np.min(g6Hc_df[ag])]
421 |
422 |
423 | hv.extension("bokeh")
424 |
425 | plots = []
426 | for (plot_df,title) in [(g6Hc_df, 'g6.31 VH'),(g6Lc_df, 'g6.31 VL')]:
427 |
428 | # Generate HoloViews Points Element
429 | points = hv.Points(
430 | data=plot_df,
431 | kdims=[ag, 'Ab-Ag'],
432 | )
433 |
434 | # Datashade with spreading of points
435 | p = hv.operation.datashader.dynspread(
436 | hv.operation.datashader.rasterize(
437 | points
438 | ).opts(
439 | cmap='Magma',
440 | cnorm='linear',
441 | )
442 | ).opts(
443 | frame_width=350,
444 | frame_height=300,
445 | padding=0.05,
446 | show_grid=False,
447 | colorbar = True
448 | )
449 |
450 | p = hv.render(p)
451 | p.output_backend = "svg"
452 |
453 | plots.append(p)
454 |
455 | return plots
456 |
457 | # scatter plot colored by number of somatic mutations, for CR antibodies
458 | def single_cr_scatterMut(plot_df, ag_columns, x, y, title):
459 | x_axis_label = x + '-logKd' if x in ag_columns else 'log likelihood'
460 | y_axis_label = y + '-logKd' if y in ag_columns else 'log likelihood'
461 |
462 |
463 | #set up figure
464 | p = bokeh.plotting.figure(
465 | width=500,
466 | height=400,
467 | x_axis_label= x_axis_label,
468 | y_axis_label= y_axis_label,
469 | title = title
470 | )
471 |
472 | #Set up color mapper to color by number of mutations from germline
473 | mapper = bokeh.transform.linear_cmap(field_name='som_mut', palette=bokeh.palettes.inferno(16), low=min(plot_df.som_mut), high=max(plot_df.som_mut))
474 | p.circle(
475 | source=plot_df,
476 | x= x,
477 | y= y,
478 | alpha = 0.2,
479 | line_color=mapper,
480 | color=mapper,
481 | size = 4
482 | )
483 |
484 | # overlay germline and mature variants
485 | p.diamond(
486 | # germline genotype contains no '1's (all 0's)
487 | source = plot_df[plot_df['genotype'].str.contains('1') == False],
488 | x= x,
489 | y= y,
490 | color = 'dodgerblue',
491 | size = 12,
492 | legend_label = 'germline'
493 | )
494 | p.star(
495 | # mature genotype contains no '0's (all 1's)
496 | source = plot_df[plot_df['genotype'].str.contains('0') == False],
497 | x= x,
498 | y= y,
499 | color = 'green',
500 | size = 12,
501 | legend_label = 'mature'
502 | )
503 | p.legend.location = "top_left"
504 | p.output_backend = "svg"
505 |
506 | return p, mapper
507 |
508 | # single scatter plot for CR antibodies, with datashading and dynamic spreading
509 | def single_cr_scatter(plot_df, ag_columns, x, y, title):
510 | x_axis_label = x + '-logKd' if x in ag_columns else 'log likelihood'
511 | y_axis_label = y + '-logKd' if y in ag_columns else 'log likelihood'
512 |
513 | hv.extension("bokeh")
514 |
515 | # Generate HoloViews Points Element
516 | points = hv.Points(
517 | data=plot_df,
518 | kdims=[x, y],
519 | )
520 |
521 | # Datashade with spreading of points
522 | p = hv.operation.datashader.dynspread(
523 | hv.operation.datashader.rasterize(
524 | points
525 | ).opts(
526 | cmap='Magma',
527 | cnorm='linear',
528 | )
529 | ).opts(
530 | frame_width=350,
531 | frame_height=300,
532 | padding=0.05,
533 | show_grid=False,
534 | colorbar = True
535 | )
536 |
537 | p = hv.render(p)
538 |
539 | p.grid.visible = False
540 | # overlay germline and mature variants on the rasterized scatter
541 | p.diamond(
542 | # germline genotype contains no '1's (all 0's)
543 | source = plot_df[plot_df['genotype'].str.contains('1') == False],
544 | x= x,
545 | y= y,
546 | color = 'dodgerblue',
547 | size = 12,
548 | #legend_label = 'germline'
549 | )
550 | p.star(
551 | # mature genotype contains no '0's (all 1's)
552 | source = plot_df[plot_df['genotype'].str.contains('0') == False],
553 | x= x,
554 | y= y,
555 | color = 'green',
556 | size = 12,
557 | #legend_label = 'mature'
558 | )
559 | #p.legend.location = "top_left"
560 | p.output_backend = "svg"
561 |
562 | return p
563 |
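# single_cr_scatterMut and single_cr_scatter render the same binding landscape
# two ways: the former colors every point by somatic mutation count, while this
# datashaded variant rasterizes point density (more legible for the large
# combinatorial CR libraries) and overlays only the germline/mature endpoints.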
564 | def plot_CR_scatter_colorMuts(title, files, dms_df, ag_columns, dropLOQ = False):
565 |
566 | agAb_files = {key: value for key, value in files.items() if key in ['Ab-Ag',]}
567 |
568 | for key, filepath in agAb_files.items():
569 | df = pd.read_csv(filepath) # Read the CSV file into a DataFrame
570 | column_label = key
571 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column
572 | dms_df[column_label] = log_likelihood_column
573 |
574 | plots = []
575 | #predictions vs experiment
576 | for ag in ag_columns:
577 |
578 | if dropLOQ:
579 | plot_df = dms_df[dms_df[ag] > min(dms_df[ag])]
580 | # add back the wild type (which sits at the LOQ) for visualization
581 | plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
582 | else:
583 | plot_df = dms_df
584 |
585 | p, _ = single_cr_scatterMut(plot_df, ag_columns, x = ag, y = 'Ab-Ag', title = f'{title} against {ag}')
586 | plots.append(p)
587 |
588 | #experimental vs experimental
589 | plot_df = dms_df
590 | if dropLOQ:
591 | for ag in ag_columns:
592 | plot_df = plot_df[plot_df[ag] > min(plot_df[ag])]
593 | # add back the wild type (which sits at the LOQ) for visualization
594 | plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
595 |
596 | p_exp, mapper = single_cr_scatterMut(plot_df, ag_columns, x= ag_columns[0], y = ag_columns[1], title = 'Experimental Cross-Reactive Binding Landscape')
597 | p_exp.legend.location = "top_left"
598 | color_bar = bokeh.models.ColorBar(color_mapper=mapper['transform'], width=8, title = 'Amino Acid Mutations')
599 | #p_exp.add_layout(color_bar, 'right')
600 |
601 | plots.append(p_exp)
602 | p_exp.output_backend = "svg"
603 |
604 | return plots
605 |
606 | def plot_CR_scatterplots(title, files, dms_df, ag_columns, dropLOQ = False):
607 |
608 | agAb_files = {key: value for key, value in files.items() if key in ['Ab-Ag',]}
609 |
610 | for key, filepath in agAb_files.items():
611 | df = pd.read_csv(filepath) # Read the CSV file into a DataFrame
612 | column_label = key
613 | log_likelihood_column = df['log_likelihood'] # Extract the 'log_likelihood' column
614 | dms_df[column_label] = log_likelihood_column
615 |
616 | plots = []
617 | #predictions vs experiment
618 | for ag in ag_columns:
619 |
620 | if dropLOQ:
621 | plot_df = dms_df[dms_df[ag] > min(dms_df[ag])]
622 | # add back the wild type (which sits at the LOQ) for visualization
623 | plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
624 | else:
625 | plot_df = dms_df
626 |
627 | p = single_cr_scatter(plot_df, ag_columns, x = ag, y = 'Ab-Ag', title = f'{title} against {ag}')
628 | plots.append(p)
629 |
630 | #experimental vs experimental
631 | plot_df = dms_df
632 | if dropLOQ:
633 | for ag in ag_columns:
634 | plot_df = plot_df[plot_df[ag] > min(plot_df[ag])]
635 | # add back the wild type (which sits at the LOQ) for visualization
636 | plot_df = pd.concat([plot_df, dms_df[dms_df['som_mut'] == 0]], ignore_index = True, axis = 0)
637 |
638 | p_exp = single_cr_scatter(plot_df, ag_columns, x= ag_columns[0], y = ag_columns[1], title = 'Experimental Cross-Reactive Binding Landscape')
639 |
640 | plots.append(p_exp)
641 | p_exp.output_backend = "svg"
642 |
643 | return plots
644 |
645 | # plot paired correlation bars (LOQ points included vs excluded) for the CR antibodies: CR6261 and CR9114
646 | def plot_LOQ_comparison(ab_name, unfilt_corrs, filt_corrs, palette = bokeh.palettes.Spectral6,):
647 | unfilt_corrs['LOQ'] = ['Included'] * len(unfilt_corrs)
648 | filt_corrs['LOQ'] = ['Excluded'] * len(filt_corrs)
649 | combined_melt = pd.concat([unfilt_corrs,filt_corrs], ignore_index= True)
650 |
651 | combined_melt['cats'] = combined_melt.apply(lambda x: (x["Target"], x["Input"], x["LOQ"]), axis = 1)
652 | factors = list(combined_melt.cats)[::-1]
653 |
654 | p = bokeh.plotting.figure(
655 | height=850,
656 | width=400,
657 | x_axis_label="Spearman Correlation",
658 | x_range=[-.25, 1],
659 | y_range=bokeh.models.FactorRange(*factors),
660 | tools="save",
661 | title = f'{ab_name}, Impact of Ignoring Lower LOQ'
662 | )
663 |
664 | p.hbar(
665 | source=combined_melt,
666 | y="cats",
667 | right="Correlation",
668 | height=0.7,
669 | line_color = 'black',
670 | legend_field = 'LOQ',
671 | # use the palette to colormap based on the x[1:2] factor values
672 | fill_color=factor_cmap(
673 | 'LOQ',
674 | palette= palette,
675 | factors = list(combined_melt['LOQ'].unique()),
676 | start=1,
677 | end=2)
678 | )
679 |
680 | p.ygrid.grid_line_color = None
681 | p.y_range.range_padding = 0.1
682 | p.legend.location = "bottom_right"
683 | p.legend.orientation = "vertical"
684 | p.yaxis.axis_label = "Model Input"
685 | p.output_backend = "svg"
686 |
687 | return p
688 |
689 |
690 |
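# Driver below assumes the dataset dictionaries (cr9114_dict, cr6261_dict,
# g6HC_dict, g6LC_dict) are defined earlier in this script, each carrying
# 'ab_name', 'expt_type', 'files', 'dms_df', 'ag_columns', and 'palette'
# entries, and that the output/ab_mutagenesis_expts/ directory already exists.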
691 | if __name__ == '__main__':
692 |
693 | datasets = [cr9114_dict, cr6261_dict, (g6HC_dict, g6LC_dict)]
694 | results = []
695 |
696 | # compute benchmarks twice: with points at the lower limit of quantitation included, then excluded
697 | for dropLOQ in [False, True]:
698 |
699 | g6_corrs = []
700 | all_corr_plots = []
701 |
702 | for d in datasets:
703 | if isinstance(d, tuple):
704 | g6Hc, g6Lc = d
705 | title = g6Hc['ab_name'] + ', ' + g6Hc['expt_type']
706 | g6_combined = get_g6_corr(g6Hc, g6Lc, dropLOQ= dropLOQ)
707 |
708 | all_corr_plots.append(plot_hbar(title, g6_combined, g6Lc['palette'], 'IF'))
709 |
710 | #also plot scatter plots
711 | scatter_fname = f"output/ab_mutagenesis_expts/g6_scatterplot_colorMuts{'_dropLOQ' if dropLOQ else ''}.html"
712 | bokeh.plotting.output_file(scatter_fname)
713 | g6_scatterMuts_plots = plot_g6_scatter(g6Hc['dms_df'], g6Lc['dms_df'], g6Hc['ag_columns'], dropLOQ= dropLOQ)
714 | bokeh.io.show(bokeh.layouts.gridplot(g6_scatterMuts_plots, ncols = 2))
715 |
716 | if not dropLOQ:
717 | results.append(g6_combined)
718 |
719 | else:
720 | corr_df = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
721 | title = d['ab_name'] + ', ' + d['expt_type']
722 |
723 | all_corr_plots.append(plot_hbar(title, corr_df, d['palette'], 'IF'))
724 |
725 | #also plot scatter plots for CR9114, CR6261
726 | scatter_fname = f"output/ab_mutagenesis_expts/{d['ab_name']}_scatter_colorMuts{'_dropLOQ' if dropLOQ else ''}.html"
727 | bokeh.plotting.output_file(scatter_fname)
728 | cr_scatterMuts_plots = plot_CR_scatter_colorMuts(title, d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
729 | bokeh.io.show(bokeh.layouts.gridplot(cr_scatterMuts_plots, ncols = 3))
730 |
731 | #also scatter plots with datashading/dynamic spreading for CR9114, CR6261
732 | scatter_fname = f"output/ab_mutagenesis_expts/{d['ab_name']}_scatter{'_dropLOQ' if dropLOQ else ''}.html"
733 | bokeh.plotting.output_file(scatter_fname)
734 | cr_scatter_plots = plot_CR_scatterplots(title, d['files'], d['dms_df'], d['ag_columns'], dropLOQ= dropLOQ)
735 | bokeh.io.show(bokeh.layouts.gridplot(cr_scatter_plots, ncols = 3))
736 |
737 | if not dropLOQ:
738 | results.append(corr_df)
739 |
740 | all_corr_fname = f"output/ab_mutagenesis_expts/benchmarks{'_dropLOQ' if dropLOQ else ''}.html"
741 | bokeh.plotting.output_file(all_corr_fname)
742 | bokeh.io.show(bokeh.layouts.gridplot(all_corr_plots, ncols = len(all_corr_plots)))
743 |
744 | compareLOQ_fname = 'output/ab_mutagenesis_expts/LOQ_comparison.html'
745 | bokeh.plotting.output_file(compareLOQ_fname)
746 | compareLOQ_plots = []
747 | for d in datasets:
748 | if not isinstance(d, tuple) and d['ab_name'].startswith('CR'):
749 | corr_withLOQ = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'])
750 | corr_withoutLOQ = get_corr(d['ab_name'], d['files'], d['dms_df'], d['ag_columns'], dropLOQ= True)  # explicit, rather than relying on the loop variable leaking out of scope
751 |
752 | compareLOQ_plots.append(plot_LOQ_comparison(d['ab_name'], corr_withLOQ, corr_withoutLOQ , d['palette']))
753 | bokeh.io.show(bokeh.layouts.gridplot(compareLOQ_plots, ncols = len(compareLOQ_plots)))
754 |
755 | pd.concat(results, ignore_index=True).to_csv('output/ab_mutagenesis_expts/results.csv', index=False)
756 |
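# Example invocation (assuming this file is bin/plot_ab-binding_benchmarks.py
# per the repository layout, run after the benchmarking scripts have written
# the model score files):
#   python bin/plot_ab-binding_benchmarks.py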
--------------------------------------------------------------------------------