├── CalcMetrics.xml ├── IteratedConvergence.xml ├── IteratedConvergence_GFP.xml ├── README.md ├── RelaxSimpleDesign_gfp.xml ├── SimpleDesign.xml ├── calc_metrics.sh ├── converge_it.sh ├── converge_it_gfp.sh ├── design.options ├── design_analysis.ipynb ├── design_analysis_known_mutations_only.ipynb ├── emi ├── ESM03.sh ├── ESM1.sh ├── ESM15.sh ├── FastDesign.sh ├── IC_ESM1.sh ├── IC_MIFST1.sh ├── IC_calc_metrics.sh ├── MIFST03.sh ├── MIFST1.sh ├── MIFST15.sh ├── RelaxDesign_calc_metrics.sh ├── RelaxSimpleMPNN1.sh ├── SimpleMPNN03.sh ├── SimpleMPNN1.sh ├── SimpleMPNN15.sh ├── avg03.sh ├── avg1.sh ├── avg15.sh ├── calc_metrics.sh ├── emi_LDA_ANT.joblib ├── emi_LDA_OVA.joblib ├── emi_binding.csv └── emi_designs.csv ├── environment.yaml ├── gb1 ├── ESM03.sh ├── ESM1.sh ├── ESM15.sh ├── FastDesign.sh ├── IC_ESM1.sh ├── IC_MIFST1.sh ├── IC_calc_metrics.sh ├── MIFST03.sh ├── MIFST1.sh ├── MIFST15.sh ├── RelaxDesign_calc_metrics.sh ├── RelaxSimpleMPNN1.sh ├── SimpleMPNN03.sh ├── SimpleMPNN1.sh ├── SimpleMPNN15.sh ├── avg03.sh ├── avg1.sh ├── avg15.sh ├── calc_metrics.sh ├── gb1_designs.csv ├── gb1_mutations_full_data.csv └── gb1_ridge.joblib ├── gfp ├── ESM03.sh ├── ESM1.sh ├── ESM15.sh ├── FastDesign.sh ├── IC_ESM1.sh ├── IC_MIFST1.sh ├── MIFST03.sh ├── MIFST1.sh ├── MIFST15.sh ├── RelaxSimpleMPNN1.sh ├── SimpleMPNN03.sh ├── SimpleMPNN1.sh ├── SimpleMPNN15.sh ├── avg03.sh ├── avg1.sh ├── avg15.sh ├── calc_metrics.sh ├── gfp_data.csv └── gfp_designs.csv ├── herceptin ├── ESM03.sh ├── ESM1.sh ├── ESM15.sh ├── FastDesign.sh ├── IC_ESM1.sh ├── IC_MIFST1.sh ├── IC_calc_metrics.sh ├── MIFST03.sh ├── MIFST1.sh ├── MIFST15.sh ├── RelaxDesign_calc_metrics.sh ├── RelaxSimpleMPNN1.sh ├── SimpleMPNN03.sh ├── SimpleMPNN1.sh ├── SimpleMPNN15.sh ├── avg03.sh ├── avg1.sh ├── avg15.sh ├── calc_metrics.sh ├── herceptin_designs.csv ├── lda_herceptin.joblib ├── mHER_H3_AgNeg.csv └── mHER_H3_AgPos.csv ├── model_training.ipynb └── simple_design.sh /CalcMetrics.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /IteratedConvergence.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /IteratedConvergence_GFP.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for "Self-supervised machine learning methods for protein design improve sampling, but not the identification of high-fitness variants" 2 | ![figure_1_overview_horizontal (2)](https://github.com/meilerlab/probabilities_design/assets/59534445/35dd28d5-2c5a-4245-81ad-8fc8c30eaeb7) 3 | 4 | This repo contains the code for reproducing the results of the publication "Self-supervised machine learning methods for protein design 
improve sampling, but not the identification of high-fitness variants" [(link to preprint)](https://www.biorxiv.org/content/10.1101/2024.06.20.599843v1). If you are instead interested in using the implemented features for your own work, an overview of them [can be found in the Rosetta documentation here](https://www.rosettacommons.org/docs/latest/scripting_documentation/RosettaScripts/composite_protocols/Working-with-PerResidueProbabilitiesMetrics), and a tutorial is available from the Meiler Rosetta workshop 2023 ["Tutorial 2: Machine Learning in Rosetta"](https://meilerlab.org/rosetta-workshop-2023/). 5 | 6 | ## Running the different design protocols 7 | Code for running the different design protocols can be found in the folder of each dataset, e.g. `emi/avg03.sh`. All scripts use the RosettaScripts XML files provided in the main folder, which are named after the different protocols shown in the paper. 8 | 9 | ## Sequences and metrics of resulting designs 10 | The unique sequences and calculated metrics of each design protocol are available in the dataset folders (`<dataset>/<dataset>_designs.csv`), e.g. `emi/emi_designs.csv`. 11 | 12 | ## Analysis of designs 13 | The code for analyzing the resulting designs and reproducing the figures can be found in the `design_analysis.ipynb` notebook. To run the Jupyter notebooks, first create a Python environment from the `environment.yaml` file with either conda or mamba: 14 | 15 | ``` 16 | # create environment 17 | conda env create -f environment.yaml 18 | # activate environment 19 | conda activate probs_design 20 | ``` 21 | 22 | ## Oracle model data and training 23 | The code for training and evaluating the oracle models for each dataset can be found in the `model_training.ipynb` notebook. The datasets used for training can be found in each dataset folder, e.g. `gb1/gb1_mutations_full_data.csv`. Pre-trained models are also provided, e.g. `gb1/gb1_ridge.joblib`.
24 | 25 | ## Rosetta code 26 | The Rosetta source code can be found at https://github.com/RosettaCommons/rosetta/. Docker containers for Rosetta (including the Tensorflow/LibTorch extras version) can be found at https://hub.docker.com/r/rosettacommons/rosetta. 27 | -------------------------------------------------------------------------------- /RelaxSimpleDesign_gfp.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /SimpleDesign.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | ​ 9 | 10 | 11 | ​ 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | ​ 44 | 45 | 46 | 47 | 48 | ​ 49 | 50 | 51 | 52 | ​ 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | ​ 61 | ​ 62 | 63 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /calc_metrics.sh: -------------------------------------------------------------------------------- 1 | # Usage: calc_metrics.sh <structure_list> <design_chain> <antigen> <AIFA> <interface> <out_dir> <score_file> 2 | # NOTE(review): ./metric.options is not in the repository tree (only ./design.options is) -- confirm it is supplied separately. 3 | Rosetta/main/source/bin/rosetta_scripts.pytorchtensorflow.linuxgccrelease @ ./metric.options -parser:protocol ./CalcMetrics.xml -l "$1" -parser:script_vars design_chain="$2" antigen="$3" AIFA="$4" interface="$5" -out:path:all "${6}" -out:file:score_only "${7}" 4 | 5 | 
-------------------------------------------------------------------------------- /converge_it.sh: -------------------------------------------------------------------------------- 1 | # Usage: converge_it.sh <input.pdb> <protocol> <design_chain> <antigen> <AIFA> <pos_temp> <aa_temp> <n_muts> <filter> <resfile> <out_dir>   (note: filter is $9, resfile is ${10}) 2 | Rosetta/main/source/bin/rosetta_scripts.pytorchtensorflow.linuxgccrelease @ ./design.options -parser:protocol ./IteratedConvergence.xml -s "$1" -parser:script_vars protocol="$2" design_chain="$3" antigen="$4" AIFA="$5" pos_temp="$6" aa_temp="$7" n_muts="$8" resfile="${10}" filter="$9" -out:path:all "${11}" 3 | 4 | -------------------------------------------------------------------------------- /converge_it_gfp.sh: -------------------------------------------------------------------------------- 1 | # Usage: converge_it_gfp.sh <input.pdb> <protocol> <design_chain> <antigen> <AIFA> <pos_temp> <aa_temp> <n_muts> <filter> <resfile> <out_dir>   (note: filter is $9, resfile is ${10}) 2 | Rosetta/main/source/bin/rosetta_scripts.pytorchtensorflow.linuxgccrelease @ ./design.options -parser:protocol ./IteratedConvergence_GFP.xml -s "$1" -parser:script_vars protocol="$2" design_chain="$3" antigen="$4" AIFA="$5" pos_temp="$6" aa_temp="$7" n_muts="$8" resfile="${10}" filter="$9" -out:path:all "${11}" 3 | 4 | -------------------------------------------------------------------------------- /design.options: -------------------------------------------------------------------------------- 1 | -linmem_ig 5 2 | -ex1 3 | -ex2aro 4 | -beta 5 | -never_rerun_filters 6 | -multiple_processes_writing_to_one_directory 7 | -nstruct 1000 8 | -------------------------------------------------------------------------------- /emi/ESM03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_esm A C 1 0.3 0.3 100 ../emi/resfile.resfile ../emi/output_ESM03/ 2 | -------------------------------------------------------------------------------- /emi/ESM1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_esm A C 1 1 1 100 ../emi/resfile.resfile ../emi/output_ESM1/ 2 | 
-------------------------------------------------------------------------------- /emi/ESM15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_esm A C 1 1.5 1.5 100 ../emi/resfile.resfile ../emi/output_ESM15/ 2 | -------------------------------------------------------------------------------- /emi/FastDesign.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb if_relax A C 1 0.0 0.0 0 ../emi/resfile.resfile ../emi/output_FastDesign/ 2 | -------------------------------------------------------------------------------- /emi/IC_ESM1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_esm A C 1 1 1 1 filt_pp_esm ../emi/resfile.resfile ../emi/output_IC_ESM1/ 2 | -------------------------------------------------------------------------------- /emi/IC_MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mifst A C 1 1 1 1 filt_pp_mifst ../emi/resfile.resfile ../emi/output_IC_MIFST1/ 2 | -------------------------------------------------------------------------------- /emi/IC_calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../emi/results_IC_ESM1.list A C 1 AB_C ../emi/ score_IC_ESM1.sc 2 | -------------------------------------------------------------------------------- /emi/MIFST03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mifst A C 1 0.3 0.3 100 ../emi/resfile.resfile ../emi/output_MIFST03/ 2 | 
-------------------------------------------------------------------------------- /emi/MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mifst A C 1 1 1 100 ../emi/resfile.resfile ../emi/output_MIFST1/ 2 | -------------------------------------------------------------------------------- /emi/MIFST15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mifst A C 1 1.5 1.5 100 ../emi/resfile.resfile ../emi/output_MIFST15/ 2 | -------------------------------------------------------------------------------- /emi/RelaxDesign_calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../emi/results_RelaxSimpleMPNN1.list A C 1 AB_C ../emi/ score_RelaxSimpleMPNN1.sc 2 | -------------------------------------------------------------------------------- /emi/RelaxSimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../relax_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mpnn A C 1 1 1 100 ../emi/resfile.resfile ../emi/output_RelaxSimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /emi/SimpleMPNN03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mpnn A C 1 0.3 0.3 100 ../emi/resfile.resfile ../emi/output_SimpleMPNN03/ 2 | -------------------------------------------------------------------------------- /emi/SimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mpnn A C 1 1 1 100 ../emi/resfile.resfile ../emi/output_SimpleMPNN1/ 2 | 
-------------------------------------------------------------------------------- /emi/SimpleMPNN15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_mpnn A C 1 1.5 1.5 100 ../emi/resfile.resfile ../emi/output_SimpleMPNN15/ 2 | -------------------------------------------------------------------------------- /emi/avg03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_avg A C 1 0.3 0.3 100 ../emi/resfile.resfile ../emi/output_avg03/ 2 | -------------------------------------------------------------------------------- /emi/avg1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_avg A C 1 1 1 100 ../emi/resfile.resfile ../emi/output_avg1/ 2 | -------------------------------------------------------------------------------- /emi/avg15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../emi/emi_sema_complex_relax_best.pdb sample_mutations_avg A C 1 1.5 1.5 100 ../emi/resfile.resfile ../emi/output_avg15/ 2 | -------------------------------------------------------------------------------- /emi/calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../emi/results_FastDesign.list A C 1 AB_C ../emi/ score_FastDesign.sc 2 | -------------------------------------------------------------------------------- /emi/emi_LDA_ANT.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meilerlab/probabilities_design/31916a3b5792737b49d0833a3396d1da546460c6/emi/emi_LDA_ANT.joblib -------------------------------------------------------------------------------- 
/emi/emi_LDA_OVA.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meilerlab/probabilities_design/31916a3b5792737b49d0833a3396d1da546460c6/emi/emi_LDA_OVA.joblib -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: probs_design 2 | channels: 3 | - conda-forge 4 | - pkgs/main 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=4.5 8 | - backcall=0.2.0 9 | - backports=1.0 10 | - backports.functools_lru_cache=2.0.0 11 | - bottleneck=1.3.5 12 | - brotli=1.1.0 13 | - brotli-bin=1.1.0 14 | - ca-certificates=2024.6.2 15 | - certifi=2024.6.2 16 | - cycler=0.11.0 17 | - debugpy=1.6.3 18 | - decorator=5.1.1 19 | - entrypoints=0.4 20 | - fonttools=4.38.0 21 | - freetype=2.12.1 22 | - imbalanced-learn=0.8.1 23 | - ipykernel=6.16.2 24 | - ipython=7.33.0 25 | - jedi=0.19.1 26 | - joblib=0.17.0 27 | - jpeg=9e 28 | - jupyter_client=7.4.9 29 | - jupyter_core=4.11.1 30 | - kiwisolver=1.4.4 31 | - lcms2=2.14 32 | - ld_impl_linux-64=2.40 33 | - lerc=4.0.0 34 | - libblas=3.9.0 35 | - libbrotlicommon=1.1.0 36 | - libbrotlidec=1.1.0 37 | - libbrotlienc=1.1.0 38 | - libcblas=3.9.0 39 | - libdeflate=1.14 40 | - libffi=3.3 41 | - libgcc-ng=13.2.0 42 | - libgfortran-ng=13.2.0 43 | - libgfortran5=13.2.0 44 | - libgomp=13.2.0 45 | - liblapack=3.9.0 46 | - libopenblas=0.3.25 47 | - libpng=1.6.43 48 | - libsodium=1.0.18 49 | - libsqlite=3.46.0 50 | - libstdcxx-ng=13.2.0 51 | - libtiff=4.4.0 52 | - libwebp-base=1.4.0 53 | - libxcb=1.13 54 | - libzlib=1.2.13 55 | - matplotlib-base=3.5.3 56 | - matplotlib-inline=0.1.7 57 | - munkres=1.1.4 58 | - ncurses=6.5 59 | - nest-asyncio=1.6.0 60 | - nomkl=1.0 61 | - numexpr=2.8.3 62 | - numpy=1.21.6 63 | - openjpeg=2.5.0 64 | - openssl=1.1.1w 65 | - packaging=23.2 66 | - pandas=1.3.5 67 | - parso=0.8.4 68 | - patsy=0.5.6 69 | - pexpect=4.9.0 
70 | - pickleshare=0.7.5 71 | - pillow=9.2.0 72 | - pip=24.0 73 | - prompt-toolkit=3.0.47 74 | - psutil=5.9.3 75 | - pthread-stubs=0.4 76 | - ptyprocess=0.7.0 77 | - pygments=2.17.2 78 | - pyparsing=3.1.2 79 | - python=3.7.8 80 | - python-dateutil=2.9.0 81 | - python_abi=3.7 82 | - pytz=2024.1 83 | - pyzmq=24.0.1 84 | - readline=8.2 85 | - scikit-learn=1.0.1 86 | - scipy=1.7.3 87 | - seaborn=0.9.0 88 | - setuptools=69.0.3 89 | - six=1.16.0 90 | - sqlite=3.46.0 91 | - statsmodels=0.13.2 92 | - threadpoolctl=3.1.0 93 | - tk=8.6.13 94 | - tornado=6.2 95 | - traitlets=5.9.0 96 | - typing-extensions=4.7.1 97 | - typing_extensions=4.7.1 98 | - unicodedata2=14.0.0 99 | - wcwidth=0.2.10 100 | - wheel=0.42.0 101 | - xorg-libxau=1.0.11 102 | - xorg-libxdmcp=1.1.3 103 | - xz=5.2.6 104 | - zeromq=4.3.5 105 | - zlib=1.2.13 106 | - zstd=1.5.6 107 | 108 | -------------------------------------------------------------------------------- /gb1/ESM03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_esm C A 1 0.3 0.3 100 ../gb1/resfile.resfile ../gb1/output_ESM03/ 2 | -------------------------------------------------------------------------------- /gb1/ESM1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_esm C A 1 1 1 100 ../gb1/resfile.resfile ../gb1/output_ESM1/ 2 | -------------------------------------------------------------------------------- /gb1/ESM15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_esm C A 1 1.5 1.5 100 ../gb1/resfile.resfile ../gb1/output_ESM15/ 2 | -------------------------------------------------------------------------------- /gb1/FastDesign.sh: -------------------------------------------------------------------------------- 1 | 
../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb if_relax C A 1 0.0 0.0 0 ../gb1/resfile.resfile ../gb1/output_FastDesign/ 2 | -------------------------------------------------------------------------------- /gb1/IC_ESM1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_esm C A 1 1 1 1 filt_pp_esm ../gb1/resfile.resfile ../gb1/output_IC_ESM1/ 2 | -------------------------------------------------------------------------------- /gb1/IC_MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mifst C A 1 1 1 1 filt_pp_mifst ../gb1/resfile.resfile ../gb1/output_IC_MIFST1/ 2 | -------------------------------------------------------------------------------- /gb1/IC_calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../gb1/results_IC_ESM1.list C A 1 C_A ../gb1/ score_IC_ESM1.sc 2 | -------------------------------------------------------------------------------- /gb1/MIFST03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mifst C A 1 0.3 0.3 100 ../gb1/resfile.resfile ../gb1/output_MIFST03/ 2 | -------------------------------------------------------------------------------- /gb1/MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mifst C A 1 1 1 100 ../gb1/resfile.resfile ../gb1/output_MIFST1/ 2 | -------------------------------------------------------------------------------- /gb1/MIFST15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mifst C A 1 1.5 1.5 100 
../gb1/resfile.resfile ../gb1/output_MIFST15/ 2 | -------------------------------------------------------------------------------- /gb1/RelaxDesign_calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../gb1/results_RelaxSimpleMPNN1.list C A 1 C_A ../gb1/ score_RelaxSimpleMPNN1.sc 2 | -------------------------------------------------------------------------------- /gb1/RelaxSimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../relax_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mpnn C A 1 1 1 100 ../gb1/resfile.resfile ../gb1/output_RelaxSimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /gb1/SimpleMPNN03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mpnn C A 1 0.3 0.3 100 ../gb1/resfile.resfile ../gb1/output_SimpleMPNN03/ 2 | -------------------------------------------------------------------------------- /gb1/SimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mpnn C A 1 1 1 100 ../gb1/resfile.resfile ../gb1/output_SimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /gb1/SimpleMPNN15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_mpnn C A 1 1.5 1.5 100 ../gb1/resfile.resfile ../gb1/output_SimpleMPNN15/ 2 | -------------------------------------------------------------------------------- /gb1/avg03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_avg C A 1 0.3 0.3 100 ../gb1/resfile.resfile 
../gb1/output_avg03/ 2 | -------------------------------------------------------------------------------- /gb1/avg1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_avg C A 1 1 1 100 ../gb1/resfile.resfile ../gb1/output_avg1/ 2 | -------------------------------------------------------------------------------- /gb1/avg15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gb1/gb1_IgG1FC_relax_best.pdb sample_mutations_avg C A 1 1.5 1.5 100 ../gb1/resfile.resfile ../gb1/output_avg15/ 2 | -------------------------------------------------------------------------------- /gb1/calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../gb1/results_FastDesign.list C A 1 C_A ../gb1/ score_FastDesign.sc 2 | -------------------------------------------------------------------------------- /gb1/gb1_ridge.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meilerlab/probabilities_design/31916a3b5792737b49d0833a3396d1da546460c6/gb1/gb1_ridge.joblib -------------------------------------------------------------------------------- /gfp/ESM03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_esm A A 0 0.3 0.3 4 ../gfp/resfile.resfile ../gfp/output_ESM03/ 2 | -------------------------------------------------------------------------------- /gfp/ESM1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_esm A A 0 1 1 5 ../gfp/resfile.resfile ../gfp/output_ESM1/ 2 | -------------------------------------------------------------------------------- /gfp/ESM15.sh: 
-------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_esm A A 0 1.5 1.5 4 ../gfp/resfile.resfile ../gfp/output_ESM15/ 2 | -------------------------------------------------------------------------------- /gfp/FastDesign.sh: -------------------------------------------------------------------------------- 1 | ../simple_design_gfp.sh ../gfp/avGFP_F64L_relax_best.pdb if_relax A A 0 0.0 0.0 0 ../gfp/resfile.resfile ../gfp/output_FastDesign/ 2 | -------------------------------------------------------------------------------- /gfp/IC_ESM1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it_gfp.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_esm A A 0 1 1 1 filt_pp_esm ../gfp/resfile.resfile ../gfp/output_IC_ESM1/ 2 | -------------------------------------------------------------------------------- /gfp/IC_MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it_gfp.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mifst A A 0 1 1 1 filt_pp_mifst ../gfp/resfile.resfile ../gfp/output_IC_MIFST1/ 2 | -------------------------------------------------------------------------------- /gfp/MIFST03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mifst A A 0 0.3 0.3 4 ../gfp/resfile.resfile ../gfp/output_MIFST03/ 2 | -------------------------------------------------------------------------------- /gfp/MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mifst A A 0 1 1 5 ../gfp/resfile.resfile ../gfp/output_MIFST1/ 2 | -------------------------------------------------------------------------------- /gfp/MIFST15.sh: 
-------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mifst A A 0 1.5 1.5 4 ../gfp/resfile.resfile ../gfp/output_MIFST15/ 2 | -------------------------------------------------------------------------------- /gfp/RelaxSimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../relax_design_gfp.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mpnn A A 0 1 1 5 ../gfp/resfile.resfile ../gfp/output_RelaxSimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /gfp/SimpleMPNN03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mpnn A A 0 0.3 0.3 4 ../gfp/resfile.resfile ../gfp/output_SimpleMPNN03/ 2 | -------------------------------------------------------------------------------- /gfp/SimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mpnn A A 0 1 1 5 ../gfp/resfile.resfile ../gfp/output_SimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /gfp/SimpleMPNN15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_mpnn A A 0 1.5 1.5 4 ../gfp/resfile.resfile ../gfp/output_SimpleMPNN15/ 2 | -------------------------------------------------------------------------------- /gfp/avg03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_avg A A 0 0.3 0.3 4 ../gfp/resfile.resfile ../gfp/output_avg03/ 2 | -------------------------------------------------------------------------------- /gfp/avg1.sh: 
-------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_avg A A 0 1 1 5 ../gfp/resfile.resfile ../gfp/output_avg1/ 2 | -------------------------------------------------------------------------------- /gfp/avg15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../gfp/avGFP_F64L_relax_best.pdb sample_mutations_avg A A 0 1.5 1.5 4 ../gfp/resfile.resfile ../gfp/output_avg15/ 2 | -------------------------------------------------------------------------------- /gfp/calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../gfp/results_FastDesign.list A A 0 B_A ../gfp/ score_FastDesign_v2.sc 2 | -------------------------------------------------------------------------------- /herceptin/ESM03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_esm E A 1 0.3 0.3 100 ../herceptin/resfile.resfile ../herceptin/output_ESM03/ 2 | -------------------------------------------------------------------------------- /herceptin/ESM1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_esm E A 1 1 1 100 ../herceptin/resfile.resfile ../herceptin/output_ESM1/ 2 | -------------------------------------------------------------------------------- /herceptin/ESM15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_esm E A 1 1.5 1.5 100 ../herceptin/resfile.resfile ../herceptin/output_ESM15/ 2 | -------------------------------------------------------------------------------- /herceptin/FastDesign.sh: 
-------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb if_relax E A 1 0.0 0.0 0 ../herceptin/resfile.resfile ../herceptin/output_FastDesign/ 2 | -------------------------------------------------------------------------------- /herceptin/IC_ESM1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_esm E A 1 1 1 1 filt_pp_esm ../herceptin/resfile.resfile ../herceptin/output_IC_ESM1/ 2 | -------------------------------------------------------------------------------- /herceptin/IC_MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../converge_it.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mifst E A 1 1 1 1 filt_pp_mifst ../herceptin/resfile.resfile ../herceptin/output_IC_MIFST1/ 2 | -------------------------------------------------------------------------------- /herceptin/IC_calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../herceptin/results_IC_ESM1.list E A 1 ED_A ../herceptin/ score_IC_ESM1.sc 2 | -------------------------------------------------------------------------------- /herceptin/MIFST03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mifst E A 1 0.3 0.3 100 ../herceptin/resfile.resfile ../herceptin/output_MIFST03/ 2 | -------------------------------------------------------------------------------- /herceptin/MIFST1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mifst E A 1 1 1 100 ../herceptin/resfile.resfile ../herceptin/output_MIFST1/ 2 | 
-------------------------------------------------------------------------------- /herceptin/MIFST15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mifst E A 1 1.5 1.5 100 ../herceptin/resfile.resfile ../herceptin/output_MIFST15/ 2 | -------------------------------------------------------------------------------- /herceptin/RelaxDesign_calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../herceptin/results_RelaxSimpleMPNN1.list E A 1 ED_A ../herceptin/ score_RelaxSimpleMPNN1.sc 2 | -------------------------------------------------------------------------------- /herceptin/RelaxSimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../relax_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mpnn E A 1 1 1 100 ../herceptin/resfile.resfile ../herceptin/output_RelaxSimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /herceptin/SimpleMPNN03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mpnn E A 1 0.3 0.3 100 ../herceptin/resfile.resfile ../herceptin/output_SimpleMPNN03/ 2 | -------------------------------------------------------------------------------- /herceptin/SimpleMPNN1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mpnn E A 1 1 1 100 ../herceptin/resfile.resfile ../herceptin/output_SimpleMPNN1/ 2 | -------------------------------------------------------------------------------- /herceptin/SimpleMPNN15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh 
../herceptin/herceptin_her2_relax_best.pdb sample_mutations_mpnn E A 1 1.5 1.5 100 ../herceptin/resfile.resfile ../herceptin/output_SimpleMPNN15/ 2 | -------------------------------------------------------------------------------- /herceptin/avg03.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_avg E A 1 0.3 0.3 100 ../herceptin/resfile.resfile ../herceptin/output_avg03/ 2 | -------------------------------------------------------------------------------- /herceptin/avg1.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_avg E A 1 1 1 100 ../herceptin/resfile.resfile ../herceptin/output_avg1/ 2 | -------------------------------------------------------------------------------- /herceptin/avg15.sh: -------------------------------------------------------------------------------- 1 | ../simple_design.sh ../herceptin/herceptin_her2_relax_best.pdb sample_mutations_avg E A 1 1.5 1.5 100 ../herceptin/resfile.resfile ../herceptin/output_avg15/ 2 | -------------------------------------------------------------------------------- /herceptin/calc_metrics.sh: -------------------------------------------------------------------------------- 1 | ../calc_metrics.sh ../herceptin/results_FastDesign.list E A 1 ED_A ../herceptin/ score_FastDesign.sc 2 | -------------------------------------------------------------------------------- /herceptin/lda_herceptin.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meilerlab/probabilities_design/31916a3b5792737b49d0833a3396d1da546460c6/herceptin/lda_herceptin.joblib -------------------------------------------------------------------------------- /model_training.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": { 16 | "tags": [] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "from sklearn.ensemble import RandomForestRegressor\n", 24 | "from sklearn.preprocessing import OneHotEncoder\n", 25 | "from sklearn.model_selection import train_test_split\n", 26 | "from sklearn.metrics import mean_squared_error\n", 27 | "from sklearn.linear_model import Ridge\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import seaborn as sns\n", 30 | "from scipy.stats import spearmanr\n", 31 | "from sklearn import svm\n", 32 | "from joblib import dump, load\n", 33 | "from imblearn.over_sampling import RandomOverSampler\n", 34 | "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", 35 | "from sklearn.metrics import accuracy_score, matthews_corrcoef" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "RANDOM_STATE = 42" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "metadata": { 51 | "tags": [] 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "def one_hot_encode_sequences(df, column_name):\n", 56 | " # Define a mapping from amino acids to integers\n", 57 | " amino_acids = 'ACDEFGHIKLMNPQRSTVWY'\n", 58 | " amino_acid_to_int = {aa: i for i, aa in enumerate(amino_acids)}\n", 59 | " num_amino_acids = len(amino_acids)\n", 60 | "\n", 61 | " encoded_sequences = []\n", 62 | "\n", 63 | " for sequence in df[column_name]:\n", 64 | " # Initialize a matrix of zeros\n", 65 | " encoded_matrix = np.zeros((len(sequence), num_amino_acids), dtype=int)\n", 66 | "\n", 67 | " 
for i, aa in enumerate(sequence):\n", 68 | " if aa in amino_acid_to_int:\n", 69 | " # Set the corresponding column to 1\n", 70 | " encoded_matrix[i, amino_acid_to_int[aa]] = 1\n", 71 | " else:\n", 72 | " raise ValueError(f\"Invalid amino acid '{aa}' found in sequence.\")\n", 73 | "\n", 74 | " encoded_sequences.append(encoded_matrix)\n", 75 | "\n", 76 | " return encoded_sequences" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# GFP" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": { 90 | "tags": [] 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "df_GFP = pd.read_csv('gfp/gfp_data.csv')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 6, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/html": [ 105 | "
\n", 106 | "\n", 119 | "\n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | "
SequenceDescriptionLigandDataUnitsAssay/Protocol
0MSEGEELFAGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...K1E+T7A+V53E+M231KNaN1.301unitlessBrightness
1MSEGEELFAGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...K1E+T7A+M76L+M231TNaN3.702unitlessBrightness
2MSEGEELFAGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...K1E+T7A+N133DNaN3.689unitlessBrightness
3MSEGEELFPGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTPKF...K1E+T7P+L42P+Y180N+T184S+A204TNaN1.301unitlessBrightness
4MSEGEELFSGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...K1E+T7S+F98Y+K154R+E170GNaN3.647unitlessBrightness
\n", 179 | "
" 180 | ], 181 | "text/plain": [ 182 | " Sequence \\\n", 183 | "0 MSEGEELFAGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF... \n", 184 | "1 MSEGEELFAGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF... \n", 185 | "2 MSEGEELFAGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF... \n", 186 | "3 MSEGEELFPGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTPKF... \n", 187 | "4 MSEGEELFSGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF... \n", 188 | "\n", 189 | " Description Ligand Data Units Assay/Protocol \n", 190 | "0 K1E+T7A+V53E+M231K NaN 1.301 unitless Brightness \n", 191 | "1 K1E+T7A+M76L+M231T NaN 3.702 unitless Brightness \n", 192 | "2 K1E+T7A+N133D NaN 3.689 unitless Brightness \n", 193 | "3 K1E+T7P+L42P+Y180N+T184S+A204T NaN 1.301 unitless Brightness \n", 194 | "4 K1E+T7S+F98Y+K154R+E170G NaN 3.647 unitless Brightness " 195 | ] 196 | }, 197 | "execution_count": 6, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "df_GFP.head()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 7, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# GFP sequence and truncated version to match structure/dataset\n", 213 | "wt = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'\n", 214 | "wt_trunc = 'KGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'\n" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 8, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# Function to parse mutations and count\n", 224 | "def count_mutations(df, wt_sequence):\n", 225 | " # Initialize a dictionary to count mutations\n", 226 | " mutation_count = {i: 0 for 
i in range(1, len(wt_sequence) + 1)}\n", 227 | "\n", 228 | " # Iterate over each row in the DataFrame\n", 229 | " for index, row in df.iterrows():\n", 230 | " # Split the mutations by '+'\n", 231 | " mutations = row['Description'].split('+')\n", 232 | "\n", 233 | " # Iterate over each mutation\n", 234 | " for mutation in mutations:\n", 235 | " # Extract the position and compare with wildtype\n", 236 | " position = int(''.join(filter(str.isdigit, mutation)))\n", 237 | " wt_amino_acid = wt_sequence[position - 1]\n", 238 | " mut_amino_acid = mutation[-1]\n", 239 | "\n", 240 | " # Check if mutation is different from wildtype\n", 241 | " if wt_amino_acid != mut_amino_acid:\n", 242 | " mutation_count[position] += 1\n", 243 | "\n", 244 | " return mutation_count\n", 245 | "\n", 246 | "# Count the mutations\n", 247 | "mutation_counts = count_mutations(df_GFP, wt_trunc)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 9, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "text/plain": [ 258 | "[117, 118, 236]" 259 | ] 260 | }, 261 | "execution_count": 9, 262 | "metadata": {}, 263 | "output_type": "execute_result" 264 | } 265 | ], 266 | "source": [ 267 | "# Extract positions with zero mutations, we keep those fixed during design as we have no information on them\n", 268 | "zero_mutation_positions = [position for position, count in mutation_counts.items() if count == 0]\n", 269 | "\n", 270 | "# Print the positions with zero mutations\n", 271 | "zero_mutation_positions\n", 272 | "\n" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 10, 278 | "metadata": { 279 | "tags": [] 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "# one hot encode\n", 284 | "df_GFP['encoded'] = one_hot_encode_sequences(df_GFP, 'Sequence')\n", 285 | "# Flatten the encoded sequence\n", 286 | "df_GFP['Flattened_Encoded'] = df_GFP['encoded'].apply(lambda x: x.flatten())\n", 287 | "# Create a feature matrix X and target vector 
y\n", 288 | "X = np.stack(df_GFP['Flattened_Encoded'].values)\n", 289 | "y = df_GFP['Data'].values" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 11, 295 | "metadata": { 296 | "tags": [] 297 | }, 298 | "outputs": [], 299 | "source": [ 300 | "# Split the data into training and testing sets\n", 301 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=RANDOM_STATE)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 12, 307 | "metadata": { 308 | "scrolled": true, 309 | "tags": [] 310 | }, 311 | "outputs": [ 312 | { 313 | "data": { 314 | "text/plain": [ 315 | "Ridge(max_iter=1000000, solver='lsqr', tol=0.0001)" 316 | ] 317 | }, 318 | "execution_count": 12, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "model = Ridge(alpha=1.0, solver='lsqr', tol=1e-4, max_iter=1000000)\n", 325 | "model.fit(X_train, y_train)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 13, 331 | "metadata": { 332 | "tags": [] 333 | }, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "Spearman Correlation: 0.7676198648740054\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "# Predict on the test set\n", 345 | "y_pred = model.predict(X_test)\n", 346 | "spearman_corr, p_value = spearmanr(y_pred, y_test)\n", 347 | "print(\"Spearman Correlation:\", spearman_corr)" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 14, 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "text/plain": [ 358 | "['gfp/gfp_ridge.joblib']" 359 | ] 360 | }, 361 | "execution_count": 14, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "dump(model, 'gfp/gfp_ridge.joblib') # save model for later use (also provided as file in the repo)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | 
"metadata": {}, 373 | "source": [ 374 | "# GB1" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 15, 380 | "metadata": {}, 381 | "outputs": [ 382 | { 383 | "name": "stderr", 384 | "output_type": "stream", 385 | "text": [ 386 | "/home/me/conda/envs/probs_design/lib/python3.7/site-packages/IPython/core/interactiveshell.py:3552: DtypeWarning: Columns (8,10,12) have mixed types.Specify dtype option on import or set low_memory=False.\n", 387 | " exec(code_obj, self.user_global_ns, self.user_ns)\n" 388 | ] 389 | } 390 | ], 391 | "source": [ 392 | "df_GB1 = pd.read_csv('gb1/gb1_mutations_full_data.csv')" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 16, 398 | "metadata": {}, 399 | "outputs": [ 400 | { 401 | "data": { 402 | "text/html": [ 403 | "
\n", 404 | "\n", 417 | "\n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | "
VariantsHDCount inputCount selectedFitnesssequencekeepone_vs_restone_vs_rest_validationtwo_vs_resttwo_vs_rest_validationthree_vs_restthree_vs_rest_validationsampledsampled_validationlow_vs_highlow_vs_high_validation
0VDGV0927353383461.000000MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYD...TruetrainNaNtrainNaNtrainNaNtrainNaNtestNaN
1ADGV134430.061910MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGADGEWTYD...FalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2CDGV18506410.242237MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGCDGEWTYD...FalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3DDGV163630.006472MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGDDGEWTYD...FalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
4EDGV18411900.032719MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGEDGEWTYD...FalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", 543 | "
" 544 | ], 545 | "text/plain": [ 546 | " Variants HD Count input Count selected Fitness \\\n", 547 | "0 VDGV 0 92735 338346 1.000000 \n", 548 | "1 ADGV 1 34 43 0.061910 \n", 549 | "2 CDGV 1 850 641 0.242237 \n", 550 | "3 DDGV 1 63 63 0.006472 \n", 551 | "4 EDGV 1 841 190 0.032719 \n", 552 | "\n", 553 | " sequence keep one_vs_rest \\\n", 554 | "0 MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYD... True train \n", 555 | "1 MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGADGEWTYD... False NaN \n", 556 | "2 MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGCDGEWTYD... False NaN \n", 557 | "3 MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGDDGEWTYD... False NaN \n", 558 | "4 MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGEDGEWTYD... False NaN \n", 559 | "\n", 560 | " one_vs_rest_validation two_vs_rest two_vs_rest_validation three_vs_rest \\\n", 561 | "0 NaN train NaN train \n", 562 | "1 NaN NaN NaN NaN \n", 563 | "2 NaN NaN NaN NaN \n", 564 | "3 NaN NaN NaN NaN \n", 565 | "4 NaN NaN NaN NaN \n", 566 | "\n", 567 | " three_vs_rest_validation sampled sampled_validation low_vs_high \\\n", 568 | "0 NaN train NaN test \n", 569 | "1 NaN NaN NaN NaN \n", 570 | "2 NaN NaN NaN NaN \n", 571 | "3 NaN NaN NaN NaN \n", 572 | "4 NaN NaN NaN NaN \n", 573 | "\n", 574 | " low_vs_high_validation \n", 575 | "0 NaN \n", 576 | "1 NaN \n", 577 | "2 NaN \n", 578 | "3 NaN \n", 579 | "4 NaN " 580 | ] 581 | }, 582 | "execution_count": 16, 583 | "metadata": {}, 584 | "output_type": "execute_result" 585 | } 586 | ], 587 | "source": [ 588 | "df_GB1.head()" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "### here the keep variable is used to balance the dataset, as most mutations destroy fitness as seen in the boxplot below" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": 17, 601 | "metadata": {}, 602 | "outputs": [ 603 | { 604 | "data": { 605 | "text/plain": [ 606 | "" 607 | ] 608 | }, 609 | "execution_count": 17, 610 | "metadata": {}, 611 | "output_type": 
"execute_result" 612 | }, 613 | { 614 | "data": { 615 | "image/png": "", 616 | "text/plain": [ 617 | "
" 618 | ] 619 | }, 620 | "metadata": {}, 621 | "output_type": "display_data" 622 | } 623 | ], 624 | "source": [ 625 | "sns.boxplot(df_GB1.Fitness)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 18, 631 | "metadata": {}, 632 | "outputs": [ 633 | { 634 | "data": { 635 | "text/plain": [ 636 | "" 637 | ] 638 | }, 639 | "execution_count": 18, 640 | "metadata": {}, 641 | "output_type": "execute_result" 642 | }, 643 | { 644 | "data": { 645 | "image/png": "", 646 | "text/plain": [ 647 | "
" 648 | ] 649 | }, 650 | "metadata": {}, 651 | "output_type": "display_data" 652 | } 653 | ], 654 | "source": [ 655 | "sns.boxplot(df_GB1[df_GB1.keep == True].Fitness)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 19, 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "name": "stdout", 665 | "output_type": "stream", 666 | "text": [ 667 | "8733\n", 668 | "149361\n" 669 | ] 670 | } 671 | ], 672 | "source": [ 673 | "print(len(df_GB1[df_GB1.keep == True]))\n", 674 | "print(len(df_GB1))" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 20, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "df_GB1 = df_GB1[df_GB1.keep == True].copy()" 684 | ] 685 | }, 686 | { 687 | "cell_type": "code", 688 | "execution_count": 21, 689 | "metadata": {}, 690 | "outputs": [], 691 | "source": [ 692 | "df_GB1['trunc_seq'] = df_GB1['sequence'].apply(lambda seq: 'MTYKLIL'+seq[7:56]) # truncate it to the PDB sequence\n", 693 | "df_GB1['encoded'] = one_hot_encode_sequences(df_GB1, 'trunc_seq')\n", 694 | "df_GB1['Flattened_Encoded'] = df_GB1['encoded'].apply(lambda x: x.flatten())" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 22, 700 | "metadata": {}, 701 | "outputs": [], 702 | "source": [ 703 | "# Create a feature matrix X and target vector y\n", 704 | "X = np.stack(df_GB1['Flattened_Encoded'].values)\n", 705 | "y = df_GB1['Fitness'].values" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": 23, 711 | "metadata": {}, 712 | "outputs": [], 713 | "source": [ 714 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=RANDOM_STATE)" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 24, 720 | "metadata": {}, 721 | "outputs": [ 722 | { 723 | "data": { 724 | "text/plain": [ 725 | "Ridge(max_iter=1000000, solver='lsqr', tol=0.0001)" 726 | ] 727 | }, 728 | "execution_count": 24, 729 | "metadata": {}, 730 | 
"output_type": "execute_result" 731 | } 732 | ], 733 | "source": [ 734 | "model = Ridge(alpha=1.0, solver='lsqr', tol=1e-4, max_iter=1000000)\n", 735 | "model.fit(X_train, y_train)" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 25, 741 | "metadata": {}, 742 | "outputs": [], 743 | "source": [ 744 | "y_pred = model.predict(X_test)" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 26, 750 | "metadata": {}, 751 | "outputs": [ 752 | { 753 | "name": "stdout", 754 | "output_type": "stream", 755 | "text": [ 756 | "Spearman Correlation: 0.8098051820702165\n", 757 | "P-value: 3.810811979116243e-204\n" 758 | ] 759 | } 760 | ], 761 | "source": [ 762 | "# Evaluate the model\n", 763 | "spearman_corr, p_value = spearmanr(y_pred, y_test)\n", 764 | "print(\"Spearman Correlation:\", spearman_corr)\n", 765 | "print(\"P-value:\", p_value)" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": 27, 771 | "metadata": {}, 772 | "outputs": [ 773 | { 774 | "data": { 775 | "text/plain": [ 776 | "['gb1/gb1_ridge.joblib']" 777 | ] 778 | }, 779 | "execution_count": 27, 780 | "metadata": {}, 781 | "output_type": "execute_result" 782 | } 783 | ], 784 | "source": [ 785 | "dump(model, 'gb1/gb1_ridge.joblib') # save the model (already provided in github repo as well)" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "# Emi" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 28, 798 | "metadata": {}, 799 | "outputs": [ 800 | { 801 | "data": { 802 | "text/html": [ 803 | "
\n", 804 | "\n", 817 | "\n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | "
VH SequenceANT BindingOVA BindingpI_seq
0QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYYMHWVRQAPGQGLE...018.64
1QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYYMHWVRQAPGQGLE...118.96
2QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYFMHWVRQAPGQGLE...017.96
3QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYSMHWVRQAPGQGLE...118.60
4QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYFMHWVRQAPGQGLE...017.96
\n", 865 | "
" 866 | ], 867 | "text/plain": [ 868 | " VH Sequence ANT Binding \\\n", 869 | "0 QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYYMHWVRQAPGQGLE... 0 \n", 870 | "1 QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYYMHWVRQAPGQGLE... 1 \n", 871 | "2 QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYFMHWVRQAPGQGLE... 0 \n", 872 | "3 QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYSMHWVRQAPGQGLE... 1 \n", 873 | "4 QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYFMHWVRQAPGQGLE... 0 \n", 874 | "\n", 875 | " OVA Binding pI_seq \n", 876 | "0 1 8.64 \n", 877 | "1 1 8.96 \n", 878 | "2 1 7.96 \n", 879 | "3 1 8.60 \n", 880 | "4 1 7.96 " 881 | ] 882 | }, 883 | "execution_count": 28, 884 | "metadata": {}, 885 | "output_type": "execute_result" 886 | } 887 | ], 888 | "source": [ 889 | "df_emi = pd.read_csv('emi/emi_binding.csv')\n", 890 | "df_emi.head()" 891 | ] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "execution_count": 29, 896 | "metadata": {}, 897 | "outputs": [], 898 | "source": [ 899 | "df_emi['encoded'] = one_hot_encode_sequences(df_emi, 'VH Sequence')\n", 900 | "df_emi['Flattened_Encoded'] = df_emi['encoded'].apply(lambda x: x.flatten())" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": 30, 906 | "metadata": {}, 907 | "outputs": [], 908 | "source": [ 909 | "wt = 'QVQLVQSGAEVKKPGASVKVSCKASGYTFTDYYMHWVRQAPGQGLEWMGRVNPNRRGTTYNQKFEGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARANWLDYWGQGTTVTVSS'" 910 | ] 911 | }, 912 | { 913 | "cell_type": "code", 914 | "execution_count": 31, 915 | "metadata": {}, 916 | "outputs": [], 917 | "source": [ 918 | "# Create a feature matrix X and target vector y\n", 919 | "X = np.stack(df_emi['Flattened_Encoded'].values)\n", 920 | "y = df_emi[['ANT Binding', 'OVA Binding']].values\n", 921 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=RANDOM_STATE)" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": 32, 927 | "metadata": {}, 928 | "outputs": [ 929 | { 930 | "data": { 931 | "text/plain": [ 932 | "LinearDiscriminantAnalysis()" 933 | ] 934 | }, 
935 | "execution_count": 32, 936 | "metadata": {}, 937 | "output_type": "execute_result" 938 | } 939 | ], 940 | "source": [ 941 | "lda_ANT = LinearDiscriminantAnalysis()\n", 942 | "lda_ANT.fit(X_train, y_train[:,0])" 943 | ] 944 | }, 945 | { 946 | "cell_type": "code", 947 | "execution_count": 33, 948 | "metadata": {}, 949 | "outputs": [ 950 | { 951 | "name": "stdout", 952 | "output_type": "stream", 953 | "text": [ 954 | "accuracy: 0.9375 mcc: 0.8732685075504191\n" 955 | ] 956 | } 957 | ], 958 | "source": [ 959 | "y_pred = lda_ANT.predict(X_test)\n", 960 | "accuracy = accuracy_score(y_test[:,0], y_pred)\n", 961 | "mcc = matthews_corrcoef(y_test[:,0], y_pred)\n", 962 | "print('accuracy:', accuracy, 'mcc:', mcc)" 963 | ] 964 | }, 965 | { 966 | "cell_type": "code", 967 | "execution_count": 34, 968 | "metadata": {}, 969 | "outputs": [ 970 | { 971 | "data": { 972 | "text/plain": [ 973 | "LinearDiscriminantAnalysis()" 974 | ] 975 | }, 976 | "execution_count": 34, 977 | "metadata": {}, 978 | "output_type": "execute_result" 979 | } 980 | ], 981 | "source": [ 982 | "lda_OVA = LinearDiscriminantAnalysis()\n", 983 | "lda_OVA.fit(X_train, y_train[:,1])" 984 | ] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "execution_count": 35, 989 | "metadata": {}, 990 | "outputs": [ 991 | { 992 | "name": "stdout", 993 | "output_type": "stream", 994 | "text": [ 995 | "accuracy: 0.92 mcc: 0.8403386677035108\n" 996 | ] 997 | } 998 | ], 999 | "source": [ 1000 | "y_pred = lda_OVA.predict(X_test)\n", 1001 | "accuracy = accuracy_score(y_test[:,1], y_pred)\n", 1002 | "mcc = matthews_corrcoef(y_test[:,1], y_pred)\n", 1003 | "print('accuracy:', accuracy, 'mcc:', mcc)" 1004 | ] 1005 | }, 1006 | { 1007 | "cell_type": "code", 1008 | "execution_count": 36, 1009 | "metadata": {}, 1010 | "outputs": [ 1011 | { 1012 | "data": { 1013 | "text/plain": [ 1014 | "['emi/emi_LDA_ANT.joblib']" 1015 | ] 1016 | }, 1017 | "execution_count": 36, 1018 | "metadata": {}, 1019 | "output_type": "execute_result" 1020 | } 
1021 | ], 1022 | "source": [ 1023 | "# save both models (provided in github repo as well)\n", 1024 | "dump(lda_OVA, 'emi/emi_LDA_OVA.joblib')\n", 1025 | "dump(lda_ANT, 'emi/emi_LDA_ANT.joblib')" 1026 | ] 1027 | }, 1028 | { 1029 | "cell_type": "markdown", 1030 | "metadata": {}, 1031 | "source": [ 1032 | "# Herceptin" 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "code", 1037 | "execution_count": 37, 1038 | "metadata": {}, 1039 | "outputs": [], 1040 | "source": [ 1041 | "df_herceptin_neg = pd.read_csv('herceptin/mHER_H3_AgNeg.csv', index_col=0)\n", 1042 | "df_herceptin_pos = pd.read_csv('herceptin/mHER_H3_AgPos.csv', index_col=0)" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": 38, 1048 | "metadata": {}, 1049 | "outputs": [], 1050 | "source": [ 1051 | "df_herceptin = df_herceptin_neg.append(df_herceptin_pos).copy()" 1052 | ] 1053 | }, 1054 | { 1055 | "cell_type": "code", 1056 | "execution_count": 39, 1057 | "metadata": {}, 1058 | "outputs": [], 1059 | "source": [ 1060 | "h_chain = 'EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVS'\n", 1061 | "cdr3 = 'WGGDGFYAMD'" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "code", 1066 | "execution_count": 40, 1067 | "metadata": {}, 1068 | "outputs": [], 1069 | "source": [ 1070 | "full_seq_list = []\n", 1071 | "for seq in df_herceptin.AASeq:\n", 1072 | " full_seq_list.append(h_chain.replace(cdr3, seq))\n", 1073 | "df_herceptin['full_seq'] = full_seq_list" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "code", 1078 | "execution_count": 41, 1079 | "metadata": {}, 1080 | "outputs": [ 1081 | { 1082 | "data": { 1083 | "text/html": [ 1084 | "
\n", 1085 | "\n", 1098 | "\n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | "
CountFractionNucSeqAASeqAgClassfull_seq
070.000007TGTAGCAGGTACACTATCTGCAGTTTCTACAAGCTCCAGTATTGGYTICSFYKLQ0EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
1950.000041TGTAGCAGGTGGTTCCTCTGCGGCTTCTACCAGAACATGTATTGGWFLCGFYQNM0EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
230.000001TGTAGCAGGTTCGGCAACATCAGCTCCTTCGCGATCGCGTATTGGFGNISSFAIA0EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
3100.000005TGTAGCAGGTTCAAGGTCAACGGTCTGTTCCCGCACCTCTATTGGFKVNGLFPHL0EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
4160.000016TGTAGCAGGTACACTATCTGCAGTATGTACGAGTTCGATTATTGGYTICSMYEFD0EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
\n", 1158 | "
" 1159 | ], 1160 | "text/plain": [ 1161 | " Count Fraction NucSeq AASeq \\\n", 1162 | "0 7 0.000007 TGTAGCAGGTACACTATCTGCAGTTTCTACAAGCTCCAGTATTGG YTICSFYKLQ \n", 1163 | "1 95 0.000041 TGTAGCAGGTGGTTCCTCTGCGGCTTCTACCAGAACATGTATTGG WFLCGFYQNM \n", 1164 | "2 3 0.000001 TGTAGCAGGTTCGGCAACATCAGCTCCTTCGCGATCGCGTATTGG FGNISSFAIA \n", 1165 | "3 10 0.000005 TGTAGCAGGTTCAAGGTCAACGGTCTGTTCCCGCACCTCTATTGG FKVNGLFPHL \n", 1166 | "4 16 0.000016 TGTAGCAGGTACACTATCTGCAGTATGTACGAGTTCGATTATTGG YTICSMYEFD \n", 1167 | "\n", 1168 | " AgClass full_seq \n", 1169 | "0 0 EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE... \n", 1170 | "1 0 EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE... \n", 1171 | "2 0 EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE... \n", 1172 | "3 0 EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE... \n", 1173 | "4 0 EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE... " 1174 | ] 1175 | }, 1176 | "execution_count": 41, 1177 | "metadata": {}, 1178 | "output_type": "execute_result" 1179 | } 1180 | ], 1181 | "source": [ 1182 | "df_herceptin.head()" 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "code", 1187 | "execution_count": 42, 1188 | "metadata": {}, 1189 | "outputs": [], 1190 | "source": [ 1191 | "df_herceptin['encoded'] = one_hot_encode_sequences(df_herceptin, 'full_seq')\n", 1192 | "df_herceptin['Flattened_Encoded'] = df_herceptin['encoded'].apply(lambda x: x.flatten())" 1193 | ] 1194 | }, 1195 | { 1196 | "cell_type": "code", 1197 | "execution_count": 43, 1198 | "metadata": {}, 1199 | "outputs": [], 1200 | "source": [ 1201 | "# Create a feature matrix X and target vector y\n", 1202 | "X = np.stack(df_herceptin['Flattened_Encoded'].values)\n", 1203 | "y = df_herceptin['AgClass'].values\n", 1204 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=RANDOM_STATE)" 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "code", 1209 | "execution_count": 44, 1210 | "metadata": {}, 1211 | "outputs": [], 1212 | "source": [ 1213 | "# we randomly 
oversample the training dataset to balance\n", 1214 | "ros = RandomOverSampler(random_state=RANDOM_STATE)\n", 1215 | "X_train_resampled, y_train_resampled = ros.fit_resample(X_train, y_train)" 1216 | ] 1217 | }, 1218 | { 1219 | "cell_type": "code", 1220 | "execution_count": 45, 1221 | "metadata": {}, 1222 | "outputs": [ 1223 | { 1224 | "data": { 1225 | "text/plain": [ 1226 | "LinearDiscriminantAnalysis()" 1227 | ] 1228 | }, 1229 | "execution_count": 45, 1230 | "metadata": {}, 1231 | "output_type": "execute_result" 1232 | } 1233 | ], 1234 | "source": [ 1235 | "lda_herceptin = LinearDiscriminantAnalysis()\n", 1236 | "lda_herceptin.fit(X_train_resampled, y_train_resampled)" 1237 | ] 1238 | }, 1239 | { 1240 | "cell_type": "code", 1241 | "execution_count": 46, 1242 | "metadata": {}, 1243 | "outputs": [ 1244 | { 1245 | "name": "stdout", 1246 | "output_type": "stream", 1247 | "text": [ 1248 | "accuracy: 0.7893923789907312 mcc: 0.5677487791517641\n" 1249 | ] 1250 | } 1251 | ], 1252 | "source": [ 1253 | "y_pred = lda_herceptin.predict(X_test)\n", 1254 | "accuracy = accuracy_score(y_test, y_pred)\n", 1255 | "mcc = matthews_corrcoef(y_test, y_pred)\n", 1256 | "print('accuracy:', accuracy, 'mcc:', mcc)" 1257 | ] 1258 | }, 1259 | { 1260 | "cell_type": "code", 1261 | "execution_count": 47, 1262 | "metadata": {}, 1263 | "outputs": [ 1264 | { 1265 | "data": { 1266 | "text/plain": [ 1267 | "['herceptin/lda_herceptin.joblib']" 1268 | ] 1269 | }, 1270 | "execution_count": 47, 1271 | "metadata": {}, 1272 | "output_type": "execute_result" 1273 | } 1274 | ], 1275 | "source": [ 1276 | "# save the model (file is provided in the github repo as well)\n", 1277 | "dump(lda_herceptin, 'herceptin/lda_herceptin.joblib')" 1278 | ] 1279 | }, 1280 | { 1281 | "cell_type": "code", 1282 | "execution_count": null, 1283 | "metadata": {}, 1284 | "outputs": [], 1285 | "source": [] 1286 | }, 1287 | { 1288 | "cell_type": "code", 1289 | "execution_count": null, 1290 | "metadata": {}, 1291 | "outputs": [], 
1292 | "source": [] 1293 | }, 1294 | { 1295 | "cell_type": "code", 1296 | "execution_count": null, 1297 | "metadata": {}, 1298 | "outputs": [], 1299 | "source": [] 1300 | }, 1301 | { 1302 | "cell_type": "code", 1303 | "execution_count": null, 1304 | "metadata": {}, 1305 | "outputs": [], 1306 | "source": [] 1307 | }, 1308 | { 1309 | "cell_type": "code", 1310 | "execution_count": null, 1311 | "metadata": {}, 1312 | "outputs": [], 1313 | "source": [] 1314 | } 1315 | ], 1316 | "metadata": { 1317 | "kernelspec": { 1318 | "display_name": "probs_design", 1319 | "language": "python", 1320 | "name": "python3" 1321 | }, 1322 | "language_info": { 1323 | "codemirror_mode": { 1324 | "name": "ipython", 1325 | "version": 3 1326 | }, 1327 | "file_extension": ".py", 1328 | "mimetype": "text/x-python", 1329 | "name": "python", 1330 | "nbconvert_exporter": "python", 1331 | "pygments_lexer": "ipython3", 1332 | "version": "3.7.8" 1333 | } 1334 | }, 1335 | "nbformat": 4, 1336 | "nbformat_minor": 5 1337 | } 1338 | -------------------------------------------------------------------------------- /simple_design.sh: -------------------------------------------------------------------------------- 1 | # Run the Rosetta rosetta_scripts binary with the ./design.options flags file and the ./SimpleDesign.xml protocol. 2 | # Usage: simple_design.sh <input_structure> <protocol> <design_chain> <antigen> <AIFA> <pos_temp> <aa_temp> <n_muts> <resfile> <out_path> 3 | # All paths below are relative, so run from the directory that contains Rosetta/, design.options and SimpleDesign.xml. 4 | # Every positional parameter is quoted so values containing spaces or glob characters reach Rosetta as a single argument. 5 | Rosetta/main/source/bin/rosetta_scripts.pytorchtensorflow.linuxgccrelease @ ./design.options -parser:protocol ./SimpleDesign.xml -s "$1" -parser:script_vars protocol="$2" design_chain="$3" antigen="$4" AIFA="$5" pos_temp="$6" aa_temp="$7" n_muts="$8" resfile="$9" -out:path:all "${10}" 6 | 7 | --------------------------------------------------------------------------------