├── .DS_Store ├── LICENSE ├── README.md ├── SETUP.md ├── common ├── atoms.py ├── logger.py └── run_manager.py ├── imgs ├── ex2_results.png ├── ex3_results.png ├── ex5_results.png ├── ex6_results.png └── tim.gif ├── load_and_save_bb_coords.py ├── load_and_save_coords.py ├── pdbs ├── 1acf_gt.pdb ├── 1acf_gt_crelax.pdb ├── 1bkr_gt.pdb ├── 1bkr_gt_crelax.pdb ├── 1cc8_gt.pdb ├── 1cc8_gt_crelax.pdb ├── 3mx7_gt.pdb ├── 3mx7_gt_crelax.pdb └── tim10.pdb ├── requirements.txt ├── run.py ├── seq_des ├── __init__.py ├── models.py ├── sampler.py └── util │ ├── README.md │ ├── __init__.py │ ├── acc_util.py │ ├── canonicalize.py │ ├── data.py │ ├── pyrosetta_util.py │ ├── resfile_util.py │ ├── sampler_util.py │ └── voxelize.py ├── seq_des_info.pdf ├── train_autoreg_chi.py ├── train_autoreg_chi_baseline.py └── txt ├── resfiles ├── NATRO_all.txt ├── PIKAA_all_one_AA.txt ├── full_example.txt ├── generate_resfile.py ├── init_seq_1acf_gt.txt ├── init_seq_1bkr_gt.txt ├── init_seq_1cc8_gt.txt ├── init_seq_3mx7_gt.txt ├── resfile_1acf_gt_ex8.txt ├── resfile_1bkr_gt_ex6.txt ├── resfile_3mx7_gt_ex1.txt ├── resfile_3mx7_gt_ex2.txt ├── some_PIKAA_one.txt └── testing_TPIKAA_TNOTAA.txt ├── test_domains_s95.txt ├── test_idx.txt └── train_domains_s95.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020 The Leland Stanford Junior University (Stanford University), Namrata Anand-Achim. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Protein sequence design with a learned potential
2 | 
3 | Code for the algorithm in our paper
4 | 
5 | > Namrata Anand-Achim, Raphael R. Eguchi, Alexander Derry, Russ B. Altman, and Possu Huang. "Protein sequence design with a learned potential." bioRxiv (2020).
6 | > [[biorxiv]](https://www.biorxiv.org/content/10.1101/2020.01.06.895466v1) [[cite]](#citation)
7 | 
8 | ![Model design trajectory](imgs/tim.gif)
9 | 
10 | An entirely AI-designed four-fold symmetric TIM barrel
11 | 
12 | ## Requirements
13 | 
14 | * Python 3
15 | * [PyTorch](https://pytorch.org)
16 | * [PyRosetta4](http://www.pyrosetta.org/dow)
17 | * Python packages in requirements.txt
18 | * Download pretrained models [here](https://drive.google.com/file/d/1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt/view?usp=sharing)
19 | 
20 | See [here](https://github.com/nanand2/protein_seq_des/blob/master/SETUP.md) for setup instructions on Ubuntu 18.04 with Miniconda, Python 3.7, PyTorch 1.1.0, and CUDA 9.0.
21 | 
22 | 
23 | ## Design
24 | 
25 | If you'd like to use the pre-trained models to run design, jump to [[this section]](#running-design).
26 | 
27 | ## Generating data
28 | Data is available [here](https://drive.google.com/drive/folders/1MD-tu32SoYtZGag04HwntuxcuOnYPDXs). See the README in the drive for more information about the uploaded files. For the files used to generate the above coordinates, see the .txt files with the domain IDs (txt/train_domains_s95.txt and txt/test_domains_s95.txt). These are the inputs needed to regenerate the dataset. If you don't have the PDB files downloaded, the scripts will download them and save them to the directory given by --pdb_dir.
29 | 
30 | If you'd like to generate the dataset or change the underlying data, run the following commands.
31 | 
32 | To load and save coordinates for the backbone (BB) only model:
33 | ```
34 | python load_and_save_bb_coords.py --save_dir PATH_TO_SAVE_DATA --pdb_dir PATH_TO_PDB_FILES --log_dir PATH_TO_LOG_DIR --txt PATH_TO_DOMAIN_TXT_FILE
35 | ```
36 | 
37 | To load and save coordinates for the main model:
38 | ```
39 | python load_and_save_coords.py --save_dir PATH_TO_SAVE_DATA --pdb_dir PATH_TO_PDB_FILES --log_dir PATH_TO_LOG_DIR --txt PATH_TO_DOMAIN_TXT_FILE
40 | ```
41 | 
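Each chunk the loader scripts write is a tuple of five aligned arrays -- canonicalized coordinates, atom metadata, residue labels, domain IDs, and chi angles -- saved with torch.save every --chunk_size examples. A quick way to sanity-check one chunk (a sketch; the path assumes the default --save_dir of ./coords from common/run_manager.py):

```
import torch

# Each data_XXXX.pt file holds up to --chunk_size shuffled examples.
coords, atom_data, labels, domain_ids, chis = torch.load("coords/data_0000.pt")
print(len(coords), len(labels))  # aligned per-example arrays
print(domain_ids[0])             # domain ID the first example was drawn from
```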
42 | ## Training the models
43 | 
44 | Pretrained models are available [here](https://drive.google.com/file/d/1cHoyeI0H_Jo9bqgFH4z0dfx2s9as9Jp1/view?usp=sharing), but you can also use the available scripts to train from scratch.
45 | 
46 | To train the baseline model -- residue and autoregressive rotamer prediction conditioned on backbone (BB) atoms only (no side-chains):
47 | ```
48 | python train_autoreg_chi_baseline.py --batchSize 4096 --workers 12 --lr 1.5e-4 --validation_frequency 100 --save_frequency 1000 --log_dir PATH_TO_LOG_DIR --data_dir PATH_TO_DATA
49 | ```
50 | 
51 | To train the main model -- residue and autoregressive rotamer prediction conditioned on neighboring side-chains:
52 | ```
53 | python train_autoreg_chi.py --batchSize 2048 --workers 12 --lr 7.5e-5 --validation_frequency 200 --save_frequency 2000 --log_dir PATH_TO_LOG_DIR --data_dir PATH_TO_DATA
54 | ```
55 | Note that training was originally done across 8 V100 GPUs in DataParallel mode.
56 | 
57 | 
58 | 
59 | ## Running design
60 | 
61 | To run a design trajectory, specify the starting backbone with an input PDB:
62 | 
63 | ```
64 | python run.py --pdb pdbs/3mx7_gt.pdb
65 | ```
66 | 
67 | To run a rotamer repacking trajectory with the model, specify the repack-only option:
68 | ```
69 | python run.py --pdb pdbs/3mx7_gt.pdb --repack_only 1
70 | ```
71 | 
72 | To enforce k-fold symmetry in design or packing, specify the symmetry options:
73 | ```
74 | python run.py --pdb pdbs/tim10.pdb --symmetry 1 --k 4 [--repack_only 1]
75 | ```
76 | 
77 | To constrain a subset of positions to remain fixed, point to a txt file with the fixed residue indices, for example:
78 | ```
79 | python run.py --pdb pdbs/tim10.pdb --fixed_idx txt/test_idx.txt
80 | ```
81 | 
82 | And to constrain design to a subset of positions, keeping all others fixed, point to a txt file with the variable residue indices, for example:
83 | ```
84 | python run.py --pdb pdbs/tim10.pdb --var_idx txt/test_idx.txt
85 | ```
86 | 
87 | See [below](#design-parameters) for additional design parameters.
88 | 
89 | ## Monitoring metrics
90 | Design metrics can be monitored with TensorBoard:
91 | 
92 | ```
93 | tensorboard --logdir ./logs
94 | ```
95 | 
96 | Note that the input PDB sequence and rotamers are considered 'ground-truth' for the sequence and rotamer recovery metrics.
97 | 
98 | 
99 | 
100 | ## Design parameters
101 | 
102 | * Design inputs
103 | ```
104 | --pdb                  Path to input PDB
105 | --model_list           Paths to conditional models. (Default:
106 |                        ['models/conditional_model_0.pt', 'models/conditional_model_1.pt',
107 |                        'models/conditional_model_2.pt', 'models/conditional_model_3.pt'])
108 | --init_model           Path to baseline model for sequence initialization.
109 |                        (Default: 'models/baseline_model.pt')
110 | ```
111 | * Saving / logging
112 | ```
113 | --log_dir              Path to desired output log folder for designed
114 |                        structures. (Default: ./logs)
115 | --seed                 Random seed. Design runs are non-deterministic.
116 |                        (Default: 2)
117 | --save_rate            How often to save intermediate designed structures
118 |                        (Default: 10)
119 | 
120 | ```
121 | * Sequence initialization
122 | ```
123 | --randomize {0,1}      Randomize starting sequence/rotamers for design.
124 |                        Toggle to 0 to keep the starting sequence and rotamers.
125 |                        (Default: 1)
126 | --no_init_model {0,1}  Do not use the baseline model to predict the initial sequence/rotamers.
127 |                        (Default: 0)
128 | --ala {0,1}            Initialize sequence with poly-alanine. (Default: 0)
129 | --val {0,1}            Initialize sequence with poly-valine. (Default: 0)
130 | ```
131 | * Rotamer repacking parameters
132 | ```
133 | --repack_only {0,1}    Only run rotamer repacking. (Default: 0)
134 | --use_rosetta_packer {0,1}
135 |                        Use the Rosetta packer instead of the model for
136 |                        rotamer repacking during design. If in symmetry
137 |                        mode, rotamers are not packed symmetrically. (Default: 0)
138 | --pack_radius          Radius in angstroms for Rosetta rotamer packing after
139 |                        residue mutation. Must set --use_rosetta_packer 1.
140 |                        (Default: 5.0)
141 | ```
142 | * Design parameters
143 | ```
144 | --symmetry {0,1}       Enforce symmetry during design (Default: 0)
145 | --k                    Enforce k-fold symmetry. Input pose length must be
146 |                        divisible by k. Requires --symmetry 1 (Default: 4)
147 | --restrict_gly {0,1}   Enforce no glycines at non-loop backbone positions,
148 |                        based on DSSP assignment. (Default: 1)
149 | --no_cys {0,1}         Enforce no cysteines in design (Default: 0)
150 | --no_met {0,1}         Enforce no methionines in design (Default: 0)
151 | --var_idx              Path to txt file listing pose indices that should be
152 |                        designed/packed; all other side-chains will remain
153 |                        fixed. Cannot be specified if a fixed_idx file is given.
154 |                        Not supported with symmetry mode. 0-indexed
155 | --fixed_idx            Path to txt file listing pose indices that should NOT
156 |                        be designed/packed; all other side-chains will be
157 |                        designed/packed. Cannot be specified if a var_idx file is given.
158 |                        Not supported with symmetry mode. 0-indexed (example file below)
159 | --resfile              Enforce a resfile on particular residues. 0-indexed
160 | ```
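The --fixed_idx/--var_idx files are plain-text lists of 0-indexed pose positions (txt/test_idx.txt in the repo is the reference; the parser lives in seq_des/util/sampler_util.py). A minimal sketch for writing one, assuming newline-separated integers:

```
# Hypothetical helper: hold positions 0-2 and 15 fixed during design.
with open("my_fixed_idx.txt", "w") as f:
    f.write("\n".join(str(i) for i in [0, 1, 2, 15]))
```

Then pass it with `python run.py --pdb pdbs/3mx7_gt.pdb --fixed_idx my_fixed_idx.txt`.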
161 | 
162 | Learn more about resfiles [here](https://github.com/ProteinDesignLab/protein_seq_des/tree/master/seq_des/util).
163 | 
164 | * Sampling / optimization parameters (a temperature-schedule sketch follows at the end of this section)
165 | ```
166 | --anneal {0,1}         Option to do simulated annealing of the average negative
167 |                        model pseudo-log-likelihood. Toggle to 0 to do vanilla
168 |                        blocked sampling (Default: 1)
169 | --step_rate            Multiplicative step rate for simulated annealing (Default: 0.995)
170 | --anneal_start_temp    Starting temperature for simulated annealing (Default: 1)
171 | --anneal_final_temp    Final temperature for simulated annealing (Default: 0)
172 | --n_iters              Total number of iterations (Default: 2500)
173 | --threshold            Threshold in angstroms for defining conditionally
174 |                        independent residues for blocked sampling (should be
175 |                        greater than ~17.3) (Default: 20)
176 | ```
177 | 
178 | Additional information:
179 | * The code expects a single-chain PDB input.
180 | * Specifying fixed/variable indices is not currently supported in symmetry mode.
181 | * In symmetry mode the model packs rotamers symmetrically, but the Rosetta packer does not.
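For intuition, the annealing options imply a simple multiplicative temperature schedule. A sketch (assuming the step rate is applied once per iteration and the temperature is floored at the final value):

```
temp = 1.0             # --anneal_start_temp
step_rate = 0.995      # --step_rate
final_temp = 0.0       # --anneal_final_temp
for i in range(2500):  # --n_iters
    # ... blocked sampling at the current temperature ...
    temp = max(temp * step_rate, final_temp)  # decays to ~3.6e-6 by iteration 2500
```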
182 | 183 | ## Citation 184 | If you find our work relevant to your research, please cite: 185 | ``` 186 | @article{anand2020protein, 187 | title={Protein sequence design with a learned potential}, 188 | author={Anand, Namrata and Eguchi, Raphael Ryuichi and Derry, Alexander and Altman, Russ B and Huang, Possu}, 189 | journal={bioRxiv}, 190 | year={2020}, 191 | publisher={Cold Spring Harbor Laboratory} 192 | } 193 | ``` 194 | -------------------------------------------------------------------------------- /SETUP.md: -------------------------------------------------------------------------------- 1 | 2 | ## Setup 3 | 4 | Instructions for set up on Ubuntu 18.04 with Miniconda and Python 3.7 5 | 6 | * Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html) 7 | * Create a conda env for the project 8 | ``` 9 | conda create -y -n seq_des python=3.7 anaconda 10 | conda activate seq_des 11 | ``` 12 | * Install [PyRosetta4](http://www.pyrosetta.org/dow) via conda 13 | * Get a license for PyRosetta 14 | * Add this to ~/.condarc 15 | ``` 16 | channels: 17 | - https://USERNAME:PASSWORD@conda.graylab.jhu.edu 18 | - conda-forge 19 | - defaults 20 | ``` 21 | * Install PyRosetta 22 | ``` 23 | conda install pyrosetta 24 | ``` 25 | * Install PyTorch 1.1.0 with CUDA 9.0 26 | ``` 27 | conda install -y pytorch=1.1.0 torchvision cudatoolkit=9.0 -c pytorch 28 | ``` 29 | * Clone this repo 30 | ``` 31 | git clone https://github.com/nanand2/protein_seq_des.git 32 | ``` 33 | * Install Python packages 34 | ``` 35 | cd protein_seq_des 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | * Download [pretrained models](https://drive.google.com/file/d/1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt/view?usp=sharing) to current directory 40 | 41 | ``` 42 | unzip models.zip 43 | ``` 44 | 45 | -------------------------------------------------------------------------------- /common/atoms.py: -------------------------------------------------------------------------------- 1 | import string 2 | 3 | letters = string.ascii_uppercase 4 | 5 | rename_chains = {i: letters[i] for i in range(26)} # NOTE -- expect errors if you have more than 26 structures 6 | skip_res_list = [ 7 | "HOH", 8 | "GOL", 9 | "EDO", 10 | "SO4", 11 | "EDO", 12 | "NAG", 13 | "PO4", 14 | "ACT", 15 | "PEG", 16 | "MAN", 17 | "BMA", 18 | "DMS", 19 | "MPD", 20 | "MES", 21 | "PG4", 22 | "TRS", 23 | "FMT", 24 | "PGE", 25 | "EPE", 26 | "NO3", 27 | "UNX", 28 | "UNL", 29 | "UNK", 30 | "IPA", 31 | "IMD", 32 | "GLC", 33 | "MLI", 34 | "1PE", 35 | "NO3", 36 | "SCN", 37 | "P6G", 38 | "OXY", 39 | "EOH", 40 | "NH4", 41 | "DTT", 42 | "BEN", 43 | "BCT", 44 | "FUL", 45 | "AZI", 46 | "DOD", 47 | "OH", 48 | "CYN", 49 | "NO", 50 | "NO2", 51 | "SO3", 52 | "H2S", 53 | "MOH", 54 | "URE", 55 | "CO2", 56 | "2NO", 57 | ] # top ions, sugars, small molecules with N/C/O/S/P that appear to be crystal artifacts or common surface bound co-factors -- to ignore 58 | skip_atoms = ["H", "D"] 59 | atoms = ["N", "C", "O", "S", "P", "other"] 60 | aa = [ 61 | "ALA", 62 | "ARG", 63 | "ASN", 64 | "ASP", 65 | "CYS", 66 | "GLN", 67 | "GLU", 68 | "GLY", 69 | "HIS", 70 | "ILE", 71 | "LEU", 72 | "LYS", 73 | "MET", 74 | "PHE", 75 | "PRO", 76 | "SER", 77 | "THR", 78 | "TRP", 79 | "TYR", 80 | "VAL", 81 | "MSE", 82 | ] 83 | res_label_dict = { 84 | "HIS": 0, 85 | "LYS": 1, 86 | "ARG": 2, 87 | "ASP": 3, 88 | "GLU": 4, 89 | "SER": 5, 90 | "THR": 6, 91 | "ASN": 7, 92 | "GLN": 8, 93 | "ALA": 9, 94 | "VAL": 10, 95 | "LEU": 11, 96 | "ILE": 12, 97 | "MET": 13, 98 | "PHE": 14, 99 | "TYR": 15, 100 | "TRP": 16, 101 | 
"PRO": 17, 102 | "GLY": 18, 103 | "CYS": 19, 104 | "MSE": 13, 105 | } # MSE -- same as MET 106 | label_res_dict = { 107 | 0: "HIS", 108 | 1: "LYS", 109 | 2: "ARG", 110 | 3: "ASP", 111 | 4: "GLU", 112 | 5: "SER", 113 | 6: "THR", 114 | 7: "ASN", 115 | 8: "GLN", 116 | 9: "ALA", 117 | 10: "VAL", 118 | 11: "LEU", 119 | 12: "ILE", 120 | 13: "MET", 121 | 14: "PHE", 122 | 15: "TYR", 123 | 16: "TRP", 124 | 17: "PRO", 125 | 18: "GLY", 126 | 19: "CYS", 127 | } # , 20:'MSE'} 128 | 129 | chi_dict = { 130 | "ARG": {"chi_1": "CG", "chi_2": "CD", "chi_3": "NE", "chi_4": "CZ"}, 131 | "LYS": {"chi_1": "CG", "chi_2": "CD", "chi_3": "CE", "chi_4": "NZ"}, 132 | "GLN": {"chi_1": "CG", "chi_2": "CD", "chi_3": "OE1"}, 133 | "GLU": {"chi_1": "CG", "chi_2": "CD", "chi_3": "OE1"}, 134 | "MET": {"chi_1": "CG", "chi_2": "SD", "chi_3": "CE"}, 135 | "ASP": {"chi_1": "CG", "chi_2": "OD1"}, 136 | "ILE": {"chi_1": "CG1", "chi_2": "CD1"}, 137 | "HIS": {"chi_1": "CG", "chi_2": "ND1"}, 138 | "LEU": {"chi_1": "CG", "chi_2": "CD1"}, 139 | "ASN": {"chi_1": "CG", "chi_2": "OD1"}, 140 | "PHE": {"chi_1": "CG", "chi_2": "CD1"}, 141 | "PRO": {"chi_1": "CG", "chi_2": "CD"}, 142 | "TRP": {"chi_1": "CG", "chi_2": "CD1"}, 143 | "TYR": {"chi_1": "CG", "chi_2": "CD1"}, 144 | "VAL": {"chi_1": "CG1"}, 145 | "THR": {"chi_1": "OG1"}, 146 | "SER": {"chi_1": "OG"}, 147 | "CYS": {"chi_1": "SG"}, 148 | "GLY": {}, 149 | "ALA": {}, 150 | } 151 | 152 | 153 | chi_dict_old = { 154 | "ARG": 4, 155 | "LYS": 4, 156 | "GLN": 3, 157 | "GLU": 3, 158 | "MET": 3, 159 | "ASP": 2, 160 | "ILE": 2, 161 | "HIS": 2, 162 | "LEU": 2, 163 | "ASN": 2, 164 | "PHE": 2, 165 | "PRO": 3, 166 | "TRP": 2, 167 | "TYR": 2, 168 | "VAL": 1, 169 | "THR": 1, 170 | "SER": 1, 171 | "CYS": 1, 172 | "GLY": 0, 173 | "ALA": 0, 174 | } 175 | aa_map = { 176 | 0: "H", 177 | 1: "K", 178 | 2: "R", 179 | 3: "D", 180 | 4: "E", 181 | 5: "S", 182 | 6: "T", 183 | 7: "N", 184 | 8: "Q", 185 | 9: "A", 186 | 10: "V", 187 | 11: "L", 188 | 12: "I", 189 | 13: "M", 190 | 14: "F", 191 | 15: "Y", 192 | 16: "W", 193 | 17: "P", 194 | 18: "G", 195 | 19: "C", 196 | } # , 20: "M"} # caution methionine in place of MSE 197 | aa_inv = { 198 | "H": "HIS", 199 | "K": "LYS", 200 | "R": "ARG", 201 | "D": "ASP", 202 | "E": "GLU", 203 | "S": "SER", 204 | "T": "THR", 205 | "N": "ASN", 206 | "Q": "GLN", 207 | "A": "ALA", 208 | "V": "VAL", 209 | "L": "LEU", 210 | "I": "ILE", 211 | "M": "MET", 212 | "F": "PHE", 213 | "Y": "TYR", 214 | "W": "TRP", 215 | "P": "PRO", 216 | "G": "GLY", 217 | "C": "CYS", 218 | } 219 | aa_map_inv = { 220 | "H": 0, 221 | "K": 1, 222 | "R": 2, 223 | "D": 3, 224 | "E": 4, 225 | "S": 5, 226 | "T": 6, 227 | "N": 7, 228 | "Q": 8, 229 | "A": 9, 230 | "V": 10, 231 | "L": 11, 232 | "I": 12, 233 | "M": 13, 234 | "F": 14, 235 | "Y": 15, 236 | "W": 16, 237 | "P": 17, 238 | "G": 18, 239 | "C": 19, 240 | } 241 | aa_to_letter = {aa_inv[k]: k for k in aa_inv.keys()} 242 | label_res_single_dict = { 243 | 0: "H", 244 | 1: "K", 245 | 2: "R", 246 | 3: "D", 247 | 4: "E", 248 | 5: "S", 249 | 6: "T", 250 | 7: "N", 251 | 8: "Q", 252 | 9: "A", 253 | 10: "V", 254 | 11: "L", 255 | 12: "I", 256 | 13: "M", 257 | 14: "F", 258 | 15: "Y", 259 | 16: "W", 260 | 17: "P", 261 | 18: "G", 262 | 19: "C", 263 | } 264 | # resfile commands where values are amino acids allowed by that command 265 | resfile_commands = { 266 | "ALLAA": {'H', 'K', 'R', 'D', 'E', 'S', 'T', 'N', 'Q', 'A', 'V', 'L', 'I', 'M', 'F', 'Y', 'W', 'P', 'G', 'C'}, 267 | "ALLAAwc": {'H', 'K', 'R', 'D', 'E', 'S', 'T', 'N', 'Q', 'A', 'V', 'L', 'I', 'M', 'F', 'Y', 'W', 'P', 
'G', 'C'},
268 |     "ALLAAxc": {'H', 'K', 'R', 'D', 'E', 'S', 'T', 'N', 'Q', 'A', 'V', 'L', 'I', 'M', 'F', 'Y', 'W', 'P', 'G'},
269 |     "POLAR": {'E', 'H', 'K', 'N', 'R', 'Q', 'D', 'S', 'T'},
270 |     "APOLAR": {'P', 'M', 'Y', 'V', 'F', 'L', 'I', 'A', 'C', 'W', 'G'},
271 | }
272 | 
--------------------------------------------------------------------------------
/common/logger.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | import numpy as np
3 | import os
4 | import datetime
5 | import subprocess
6 | 
7 | 
8 | class Logger(object):
9 |     def __init__(self, log_dir="./logs", dummy=False, prefix="", suffix="", full_log_dir=None, rank=0):
10 |         self.suffix = suffix
11 |         self.prefix = prefix
12 |         self.dummy = dummy
13 |         if self.dummy:
14 |             return
15 | 
16 |         self.iteration = 1
17 | 
18 |         if log_dir == "":
19 |             log_dir = "./logs"
20 |         if full_log_dir is None or log_dir == "":
21 |             now = datetime.datetime.now()
22 |             self.ts = now.strftime("%Y-%m-%d-%H-%M-%S")
23 |             log_path = os.path.join(log_dir, self.prefix + self.ts + self.suffix)
24 |         else:
25 |             log_path = full_log_dir
26 | 
27 |         self.log_path = log_path
28 |         # create the writer unconditionally -- SummaryWriter makes log_path if it does not exist
29 |         self.writer = SummaryWriter(log_dir=log_path)
30 |         self.kvs = {}
31 | 
32 |         print("Logging to", log_path)
33 | 
34 |     def log_args(self, args):
35 |         with open("%s/args.txt" % self.log_path, "w") as f:
36 |             for arg in vars(args):
37 |                 f.write("%s\t%s\n" % (arg, getattr(args, arg)))
38 | 
39 |     def advance_iteration(self):
40 |         self.iteration += 1
41 | 
42 |     def reset_iteration(self):
43 |         self.iteration = 0
44 | 
45 |     def log_scalar(self, name, value):
46 |         if self.dummy:
47 |             return
48 | 
49 |         if isinstance(value, list):
50 |             assert len(value) == 1, (name, len(value), value)
51 |             return self.log_scalar(name, value[0])
52 |         try:
53 |             self.writer.add_scalar(name, value, self.iteration)
54 |         except Exception as e:
55 |             print("Failed on", name, value, type(value))
56 |             raise
57 | 
58 |     def log_kvs(self, **kwargs):
59 |         if self.dummy:
60 |             return
61 | 
62 |         for k, v in kwargs.items():
63 |             assert isinstance(k, str)
64 |             self.kvs[k] = v
65 | 
66 |         kv_strings = ["%s=%s" % (k, v) for k, v in sorted(self.kvs.items())]
67 |         val = "\n".join(kv_strings)
68 |         self.writer.add_text("properties", val, global_step=self.iteration)
69 | 
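A minimal usage sketch of this Logger (illustrative paths; this mirrors how RunManager.parse_args wires it up below):

```
from common.logger import Logger

log = Logger(log_dir="./logs", prefix="design_")  # writes to ./logs/design_<timestamp>
log.log_scalar("sampler/rosetta_energy", -123.4)  # recorded at the current iteration
log.advance_iteration()                           # advance the global step
```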
".join(kv_strings) 68 | self.writer.add_text("properties", val, global_step=self.iteration) 69 | -------------------------------------------------------------------------------- /common/run_manager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import time 4 | import numpy as np 5 | from . import logger 6 | import torch 7 | import os 8 | 9 | 10 | class RunManager(object): 11 | def __init__(self): 12 | 13 | self.parser = argparse.ArgumentParser() 14 | 15 | self.parser.add_argument("--workers", type=int, help="number of data loading workers", default=0) 16 | self.parser.add_argument("--cuda", type=int, default=1, help="enables cuda") 17 | 18 | # training parameters 19 | self.parser.add_argument("--batchSize", type=int, default=64, help="input batch size") 20 | self.parser.add_argument("--ngpu", type=int, default=1, help="num gpus to parallelize over") 21 | 22 | self.parser.add_argument("--nf", type=int, default=64, help="base number of filters") 23 | self.parser.add_argument("--txt", type=str, default="txt/train_domains_s95.txt", help="default txt input file") 24 | 25 | self.parser.add_argument("--epochs", type=int, default=100, help="enables cuda") 26 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 27 | self.parser.add_argument("--reg", type=float, default=5e-6, help="L2 regularization") 28 | self.parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5") 29 | self.parser.add_argument("--momentum", type=float, default=0.01, help="momentum for batch norm") 30 | 31 | self.parser.add_argument( 32 | "--model", type=str, default="", help="path to saved pretrained model for resuming training", 33 | ) 34 | self.parser.add_argument("--optimizer", type=str, default="", help="path to saved optimizer params") 35 | self.parser.add_argument( 36 | "--validation_frequency", type=int, default=500, help="how often to validate during training", 37 | ) 38 | self.parser.add_argument("--save_frequency", type=int, default=2000, help="how often to save models") 39 | self.parser.add_argument("--sync_frequency", type=int, default=1000, help="how often to sync to GCP") 40 | 41 | self.parser.add_argument( 42 | "--num_return", type=int, default=400, help="number of nearest non-side-chain atmos to return per voxel", 43 | ) 44 | self.parser.add_argument("--chunk_size", type=int, default=10000, help="chunk size for saved coordinate tensors") 45 | 46 | self.parser.add_argument("--data_dir", type=str, default="/data/simdev_2tb/protein/sequence_design/data/coords") 47 | self.parser.add_argument("--pdb_dir", type=str, default="/data/drive2tb/protein/pdb") 48 | self.parser.add_argument("--save_dir", type=str, default="./coords") 49 | 50 | # design inputs 51 | self.parser.add_argument("--file_dir", type=str, default="run", help="folder to store files (must be specified") 52 | self.parser.add_argument("--pdb", type=str, default="pdbs/tim10.pdb", help="Input PDB") 53 | self.parser.add_argument( 54 | "--model_list", 55 | "--list", 56 | default=[ 57 | "models/conditional_model_0.pt", 58 | "models/conditional_model_1.pt", 59 | "models/conditional_model_2.pt", 60 | "models/conditional_model_3.pt", 61 | ], 62 | nargs="+", 63 | help="Paths to conditional models", 64 | ) 65 | self.parser.add_argument( 66 | "--init_model", type=str, default="models/baseline_model.pt", help="Path to baseline model (conditioned on backbone atoms only)", 67 | ) 68 | 69 | # saving / logging 70 | 
70 |         self.parser.add_argument(
71 |             "--log_dir", type=str, default="./logs", help="Path to desired output log folder for designed structures",
72 |         )
73 |         self.parser.add_argument("--seed", default=2, type=int, help="Random seed. Design runs are non-deterministic.")
74 |         self.parser.add_argument(
75 |             "--save_rate", type=int, default=10, help="How often to save intermediate designed structures",
76 |         )
77 | 
78 |         # design parameters
79 |         self.parser.add_argument(
80 |             "--no_init_model", type=int, default=0, choices=(0, 1), help="Do not use baseline model to initialize sequence/rotamers.",
81 |         )
82 |         self.parser.add_argument(
83 |             "--randomize",
84 |             type=int,
85 |             default=1,
86 |             choices=(0, 1),
87 |             help="Randomize starting sequence/rotamers for design. Toggle OFF to keep starting sequence and rotamers",
88 |         )
89 |         self.parser.add_argument(
90 |             "--repack_only", type=int, default=0, choices=(0, 1), help="Only run rotamer repacking (no design, keep sequence fixed)",
91 |         )
92 |         self.parser.add_argument(
93 |             "--use_rosetta_packer",
94 |             type=int,
95 |             default=0,
96 |             choices=(0, 1),
97 |             help="Use the Rosetta packer instead of the model for rotamer repacking during design",
98 |         )
99 |         self.parser.add_argument(
100 |             "--threshold",
101 |             type=float,
102 |             default=20,
103 |             help="Threshold in angstroms for defining conditionally independent residues for blocked sampling (should be greater than ~17.3)",
104 |         )
105 |         self.parser.add_argument("--symmetry", type=int, default=0, choices=(0, 1), help="Enforce symmetry during design")
106 |         self.parser.add_argument(
107 |             "--k", type=int, default=4, help="Enforce k-fold symmetry. Input pose length must be divisible by k. Requires --symmetry 1",
108 |         )
109 |         self.parser.add_argument(
110 |             "--ala", type=int, default=0, choices=(0, 1), help="Initialize sequence with poly-alanine",
111 |         )
112 |         self.parser.add_argument(
113 |             "--val", type=int, default=0, choices=(0, 1), help="Initialize sequence with poly-valine",
114 |         )
115 |         self.parser.add_argument(
116 |             "--restrict_gly", type=int, default=1, choices=(0, 1), help="Disallow glycines at non-loop backbone positions",
117 |         )
118 |         self.parser.add_argument("--no_cys", type=int, default=0, choices=(0, 1), help="Enforce no cysteines in design")
119 |         self.parser.add_argument("--no_met", type=int, default=0, choices=(0, 1), help="Enforce no methionines in design")
120 |         self.parser.add_argument(
121 |             "--pack_radius",
122 |             type=float,
123 |             default=5.0,
124 |             help="Rosetta packer radius for rotamer packing after residue mutation. Must set --use_rosetta_packer 1.",
125 |         )
126 |         self.parser.add_argument(
127 |             "--var_idx",
128 |             type=str,
129 |             default="",
130 |             help="Path to txt file listing pose indices that should be designed/packed, all other side-chains will remain fixed. 0-indexed",
131 |         )
132 |         self.parser.add_argument(
133 |             "--fixed_idx",
134 |             type=str,
135 |             default="",
136 |             help="Path to txt file listing pose indices that should NOT be designed/packed, all other side-chains will be designed/packed. 0-indexed",
137 |         )
138 | 
139 |         self.parser.add_argument("--resfile", type=str, default="", help="Specify path to a resfile to enforce constraints on particular residues")
140 | 
141 |         # optimization / sampling parameters
142 |         self.parser.add_argument(
143 |             "--anneal",
144 |             type=int,
145 |             default=1,
146 |             choices=(0, 1),
147 |             help="Option to do simulated annealing of average negative model pseudo-log-likelihood. 
Toggle OFF to do vanilla blocked sampling", 148 | ) 149 | self.parser.add_argument("--do_mcmc", type=int, default=0, help="Option to do Metropolis-Hastings") 150 | self.parser.add_argument( 151 | "--step_rate", type=float, default=0.995, help="Multiplicative step rate for simulated annealing", 152 | ) 153 | self.parser.add_argument( 154 | "--anneal_start_temp", type=float, default=1.0, help="Starting temperature for simulated annealing", 155 | ) 156 | self.parser.add_argument( 157 | "--anneal_final_temp", type=float, default=0.0, help="Final temperature for simulated annealing", 158 | ) 159 | self.parser.add_argument("--n_iters", type=int, default=2500, help="Total number of iterations") 160 | 161 | def add_argument(self, *args, **kwargs): 162 | self.parser.add_argument(*args, **kwargs) 163 | 164 | def parse_args(self): 165 | self.args = self.parser.parse_args() 166 | 167 | self.log = logger.Logger(log_dir=self.args.log_dir) 168 | self.log.log_kvs(**self.args.__dict__) 169 | self.log.log_args(self.args) 170 | 171 | np.random.seed(self.args.seed) 172 | random.seed(self.args.seed) 173 | torch.manual_seed(self.args.seed) 174 | torch.backends.cudnn.enabled = False 175 | torch.backends.cudnn.deterministic = True 176 | torch.backends.cudnn.benchmark = False 177 | 178 | return self.args 179 | -------------------------------------------------------------------------------- /imgs/ex2_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex2_results.png -------------------------------------------------------------------------------- /imgs/ex3_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex3_results.png -------------------------------------------------------------------------------- /imgs/ex5_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex5_results.png -------------------------------------------------------------------------------- /imgs/ex6_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex6_results.png -------------------------------------------------------------------------------- /imgs/tim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/tim.gif -------------------------------------------------------------------------------- /load_and_save_bb_coords.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from tqdm import tqdm 9 | import common.run_manager 10 | import glob 11 | import seq_des.util.canonicalize as canonicalize 12 | import pickle 13 | import seq_des.util.data as datasets 14 | from torch.utils import data 15 | 16 | 17 | import resource 18 | 19 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 20 | 
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
21 | 
22 | """ script to load PDB coords, canonicalize, save """
23 | 
24 | def main():
25 | 
26 |     manager = common.run_manager.RunManager()
27 | 
28 |     manager.parse_args()
29 |     args = manager.args
30 |     log = manager.log
31 | 
32 |     dataset = datasets.PDB_domain_spitter(txt_file=args.txt, pdb_path=args.pdb_dir, num_return=75, bb_only=1)
33 | 
34 |     dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args.workers)
35 | 
36 |     num_return = args.num_return
37 |     gen = iter(dataloader)
38 |     coords_out, data_out, ys, domain_ids, chis_out = [], [], [], [], []
39 | 
40 |     cs = args.chunk_size
41 |     n = 0
42 | 
43 |     for it in tqdm(range(len(dataloader)), desc="loading and saving coords"):
44 | 
45 |         out = next(gen)  # Python 3: use the next() built-in, not the removed gen.next()
46 |         if out is None or len(out) == 0:  # check for None before calling len()
47 |             print("out is none")
48 |             continue
49 |         atom_coords, atom_data, res_label, domain_id, chis = out
50 |         for i in range(len(atom_coords)):
51 |             coords_out.extend(atom_coords[i][0].cpu().data.numpy())
52 |             data_out.extend(atom_data[i][0].cpu().data.numpy())
53 |             ys.extend(res_label[i][0].cpu().data.numpy())
54 |             domain_ids.extend([domain_id[i][0]] * res_label[i][0].cpu().data.numpy().shape[0])
55 |             chis_out.extend(chis[i][0].cpu().data.numpy())
56 | 
57 |         assert len(coords_out) == len(ys)
58 |         assert len(coords_out) == len(data_out)
59 |         assert len(coords_out) == len(domain_ids), (len(coords_out), len(domain_ids))
60 |         assert len(coords_out) == len(chis_out)
61 | 
62 |         del atom_coords
63 |         del atom_data
64 |         del res_label
65 |         del domain_id
66 | 
67 |         # intermittently save data
68 |         if len(coords_out) > cs or it == len(dataloader) - 1:
69 |             # shuffle then save
70 |             print(n, len(coords_out))  # -- NOTE keep this
71 |             idx = np.arange(min(cs, len(coords_out)))
72 |             np.random.shuffle(idx)
73 |             print(n, len(idx))
74 | 
75 |             c, d, y, di, ch = map(lambda arr: np.array(arr[: len(idx)])[idx], [coords_out, data_out, ys, domain_ids, chis_out])
76 | 
77 |             print("saving", args.save_dir + "/" + "data_%0.4d.pt" % (n))
78 |             torch.save((c, d, y, di, ch), args.save_dir + "/" + "data_%0.4d.pt" % (n))
79 | 
80 |             print("Current num examples", (n) * cs + len(coords_out))
81 | 
82 |             n += 1
83 |             coords_out, data_out, ys, domain_ids, chis_out = map(lambda arr: arr[len(idx) :], [coords_out, data_out, ys, domain_ids, chis_out])
84 | 
85 | 
86 | if __name__ == "__main__":
87 |     main()
88 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | biopython
2 | matplotlib
3 | tqdm
4 | tensorboardX
5 | scipy
--------------------------------------------------------------------------------
/load_and_save_coords.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | import common.run_manager
10 | import glob
11 | import seq_des.util.canonicalize as canonicalize
12 | import pickle
13 | import seq_des.util.data as datasets
14 | from torch.utils import data
15 | 
16 | 
17 | import resource
18 | 
19 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
20 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
21 | 
22 | """ script to load PDB coords, canonicalize, save """
23 | 
24 | def main():
25 | 
26 |     manager = common.run_manager.RunManager()
27 | 
28 |     manager.parse_args()
29 |     args = manager.args
30 |     log = manager.log
31 | 
32 |     dataset = datasets.PDB_domain_spitter(txt_file=args.txt, pdb_path=args.pdb_dir)
33 | 
34 |     dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args.workers)
35 | 
36 |     num_return = args.num_return
37 |     gen = iter(dataloader)
38 |     coords_out, data_out, ys, domain_ids, chis_out = [], [], [], [], []
39 | 
40 |     cs = args.chunk_size
41 |     n = 0
42 | 
43 |     for it in tqdm(range(len(dataloader)), desc="loading and saving coords"):
44 | 
45 |         out = next(gen)  # Python 3: use the next() built-in, not the removed gen.next()
46 |         if out is None or len(out) == 0:  # check for None before calling len()
47 |             print("out is none")
48 |             continue
49 |         atom_coords, atom_data, res_label, domain_id, chis = out
50 |         for i in range(len(atom_coords)):
51 |             coords_out.extend(atom_coords[i][0].cpu().data.numpy())
52 |             data_out.extend(atom_data[i][0].cpu().data.numpy())
53 |             ys.extend(res_label[i][0].cpu().data.numpy())
54 |             domain_ids.extend([domain_id[i][0]] * res_label[i][0].cpu().data.numpy().shape[0])
55 |             chis_out.extend(chis[i][0].cpu().data.numpy())
56 | 
57 |         assert len(coords_out) == len(ys)
58 |         assert len(coords_out) == len(data_out)
59 |         assert len(coords_out) == len(domain_ids), (len(coords_out), len(domain_ids))
60 |         assert len(coords_out) == len(chis_out)
61 | 
62 |         del atom_coords
63 |         del atom_data
64 |         del res_label
65 |         del domain_id
66 | 
67 |         # intermittently save data
68 |         if len(coords_out) > cs or it == len(dataloader) - 1:
69 |             # shuffle then save
70 |             print(n, len(coords_out))  # -- NOTE keep this
71 |             idx = np.arange(min(cs, len(coords_out)))
72 |             np.random.shuffle(idx)
73 |             print(n, len(idx))
74 | 
75 |             c, d, y, di, ch = map(lambda arr: np.array(arr[: len(idx)])[idx], [coords_out, data_out, ys, domain_ids, chis_out])
76 | 
77 |             print("saving", args.save_dir + "/" + "data_%0.4d.pt" % (n))
78 |             torch.save((c, d, y, di, ch), args.save_dir + "/" + "data_%0.4d.pt" % (n))
79 | 
80 |             print("Current num examples", (n) * cs + len(coords_out))
81 | 
82 |             n += 1
83 |             coords_out, data_out, ys, domain_ids, chis_out = map(lambda arr: arr[len(idx) :], [coords_out, data_out, ys, domain_ids, chis_out])
84 | 
85 | 
86 | if __name__ == "__main__":
87 |     main()
88 | 
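An aside before run.py below: the trained networks can be loaded on their own for inspection using run.py's helpers. A sketch (note main() builds the conditional models with the atom-type channels plus extra residue/chi conditioning channels, hence the + 1 + 21):

```
from run import load_models
import common.atoms

classifiers = load_models(
    ["models/conditional_model_0.pt"],
    use_cuda=False,
    nic=len(common.atoms.atoms) + 1 + 21,  # input channels, as in run.py's main()
)
print(classifiers[0])
```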
log.log_scalar("{run}/{prefix}log_p".format(run=run, prefix=prefix), design_sampler.log_p_mean.item()) 34 | log.log_scalar("{run}/{prefix}chi_error".format(run=run, prefix=prefix), design_sampler.chi_error) 35 | log.log_scalar("{run}/{prefix}chi_rmsd".format(run=run, prefix=prefix), design_sampler.chi_rmsd) 36 | 37 | # log rosetta score terms 38 | for s in design_sampler.score_terms: 39 | log.log_scalar("{run}/z_{prefix}{s}".format(run=run, prefix=prefix, s=s), float(design_sampler.curr_score_terms[s].mean())) 40 | 41 | # log rosetta agnostic terms 42 | for n, s in design_sampler.filter_scores: 43 | log.log_scalar("{run}/y_{prefix}{n}".format(run=run, prefix=prefix, n=n), s) 44 | 45 | 46 | 47 | def load_model(model, use_cuda=True, nic=len(common.atoms.atoms)): 48 | classifier = models.seqPred(nic=nic) 49 | if use_cuda: 50 | classifier.cuda() 51 | if use_cuda: 52 | state = torch.load(model) 53 | else: 54 | state = torch.load(model, map_location="cpu") 55 | for k in state.keys(): 56 | if "module" in k: 57 | print("MODULE") 58 | classifier = nn.DataParallel(classifier) 59 | break 60 | if use_cuda: 61 | classifier.load_state_dict(torch.load(model)) 62 | else: 63 | classifier.load_state_dict(torch.load(model, map_location="cpu")) 64 | return classifier 65 | 66 | 67 | def load_models(model_list, use_cuda=True, nic=len(common.atoms.atoms)): 68 | classifiers = [] 69 | for model in model_list: 70 | classifier = load_model(model, use_cuda=use_cuda, nic=nic) 71 | classifiers.append(classifier) 72 | return classifiers 73 | 74 | 75 | def main(): 76 | 77 | manager = common.run_manager.RunManager() 78 | 79 | manager.parse_args() 80 | args = manager.args 81 | log = manager.log 82 | 83 | use_cuda = torch.cuda.is_available() 84 | 85 | # download pdb if not there already 86 | if not os.path.isfile(args.pdb): 87 | print("Downloading pdb to current directory...") 88 | os.system("wget -O {} https://files.rcsb.org/download/{}.pdb".format(args.pdb, args.pdb[:-4].upper())) 89 | 90 | assert os.path.isfile(args.pdb), "pdb not found" 91 | 92 | # load models 93 | if args.init_model != "": 94 | init_classifier = load_model(args.init_model, use_cuda=use_cuda, nic=len(common.atoms.atoms)) 95 | init_classifier.eval() 96 | init_classifiers = [init_classifier] 97 | else: 98 | assert not (args.ala and args.val), "must specify either poly-alanine or poly-valine" 99 | if args.randomize: 100 | if args.ala: 101 | init_scheme = "poly-alanine" 102 | elif args.val: 103 | init_scheme = "poly-valine" 104 | else: 105 | init_scheme = "random" 106 | else: init_scheme = 'using starting structure' 107 | print("No baseline model specified, initialization will be %s" % init_scheme) 108 | init_classifiers = None 109 | 110 | classifiers = load_models(args.model_list, use_cuda=use_cuda, nic=len(common.atoms.atoms) + 1 + 21) 111 | for classifier in classifiers: 112 | classifier.eval() 113 | 114 | # set up design_sampler 115 | design_sampler = sampler.Sampler(args, classifiers, init_classifiers, log=log, use_cuda=use_cuda) 116 | 117 | # initialize sampler 118 | design_sampler.init() 119 | 120 | # log metrics for gt seq/structure 121 | log_metrics(args=args, log=log, iteration=0, design_sampler=design_sampler, prefix="gt_") 122 | best_rosetta_energy = np.inf 123 | best_energy = np.inf 124 | 125 | # initialize design_sampler sequence with baseline model prediction or random/poly-alanine/poly-valine initial sequence, save initial model 126 | design_sampler.init_seq() 127 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + 
"curr_0.pdb") 128 | 129 | # save trajectories for logmeans and rosettas 130 | logmeans = np.zeros(int(args.n_iters)) 131 | rosettas = np.zeros(int(args.n_iters)) 132 | 133 | # run design 134 | with torch.no_grad(): 135 | for i in tqdm(range(1, int(args.n_iters)), desc='running design'): 136 | 137 | # step 138 | design_sampler.step() 139 | 140 | # logging 141 | log_metrics(args=args, log=log, iteration=i, design_sampler=design_sampler) 142 | 143 | # save log_p_means and rosettas 144 | logmeans[i] = design_sampler.log_p_mean 145 | rosettas[i] = design_sampler.rosetta_energy 146 | 147 | if design_sampler.log_p_mean < best_energy: 148 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_best_log_p_%s.pdb" % log.ts) 149 | best_energy = design_sampler.log_p_mean 150 | 151 | if design_sampler.rosetta_energy < best_rosetta_energy: 152 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_best_rosetta_energy_%s.pdb" % log.ts) 153 | best_rosetta_energy = design_sampler.rosetta_energy 154 | 155 | # save intermediate models -- comment out if desired 156 | if (i==1) or (i % args.save_rate == 0) or (i == args.n_iters - 1): 157 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_%s_%s.pdb" % (i, log.ts)) 158 | 159 | log.advance_iteration() 160 | 161 | # save final model 162 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_final.pdb") 163 | 164 | np.savetxt('{}/{}/logmeans.txt'.format(log.log_path, args.file_dir),logmeans, delimiter=',') 165 | np.savetxt('{}/{}/rosetta_energy.txt'.format(log.log_path, args.file_dir),rosettas, delimiter=',') 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /seq_des/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/seq_des/__init__.py -------------------------------------------------------------------------------- /seq_des/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import seq_des.util.data as data 4 | import common.atoms 5 | 6 | 7 | def init_ortho_weights(self): 8 | for module in self.modules(): 9 | if isinstance(module, nn.Conv2d): 10 | torch.nn.init.orthogonal_(module.weight) 11 | elif isinstance(module, nn.ConvTranspose2d): 12 | torch.nn.init.orthogonal_(module.weight) 13 | 14 | 15 | class seqPred(nn.Module): 16 | def __init__(self, nic, nf=64, momentum=0.01): 17 | super(seqPred, self).__init__() 18 | self.nic = nic 19 | self.model = nn.Sequential( 20 | # 20 -- 10 21 | nn.Conv3d(nic, nf, 4, 2, 1, bias=False), 22 | nn.BatchNorm3d(nf, momentum=momentum), 23 | nn.LeakyReLU(0.2, inplace=True), 24 | nn.Dropout(0.1), 25 | nn.Conv3d(nf, nf, 3, 1, 1, bias=False), 26 | nn.BatchNorm3d(nf, momentum=momentum), 27 | nn.LeakyReLU(0.2, inplace=True), 28 | nn.Dropout(0.1), 29 | nn.Conv3d(nf, nf, 3, 1, 1, bias=False), 30 | nn.BatchNorm3d(nf, momentum=momentum), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Dropout(0.1), 33 | # 10 -- 5 34 | nn.Conv3d(nf, nf * 2, 4, 2, 1, bias=False), 35 | nn.BatchNorm3d(nf * 2, momentum=momentum), 36 | nn.LeakyReLU(0.2, inplace=True), 37 | nn.Dropout(0.1), 38 | nn.Conv3d(nf * 2, nf * 2, 3, 1, 1, bias=False), 39 | nn.BatchNorm3d(nf * 2, momentum=momentum), 40 | nn.LeakyReLU(0.2, inplace=True), 41 | nn.Dropout(0.1), 
42 | nn.Conv3d(nf * 2, nf * 2, 3, 1, 1, bias=False), 43 | nn.BatchNorm3d(nf * 2, momentum=momentum), 44 | nn.LeakyReLU(0.2, inplace=True), 45 | nn.Dropout(0.1), 46 | # 5 -- 1 47 | nn.Conv3d(nf * 2, nf * 4, 5, 1, 0, bias=False), 48 | nn.BatchNorm3d(nf * 4, momentum=momentum), 49 | nn.LeakyReLU(0.2, inplace=True), 50 | nn.Dropout(0.1), 51 | nn.Conv3d(nf * 4, nf * 4, 3, 1, 1, bias=False), 52 | nn.BatchNorm3d(nf * 4, momentum=momentum), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | nn.Dropout(0.1), 55 | nn.Conv3d(nf * 4, nf * 4, 3, 1, 1, bias=False), 56 | nn.BatchNorm3d(nf * 4, momentum=momentum), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | nn.Dropout(0.1), 59 | ) 60 | 61 | # res pred 62 | self.out = nn.Sequential( 63 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 64 | nn.BatchNorm1d(nf * 4, momentum=momentum), 65 | nn.LeakyReLU(0.2, inplace=True), 66 | nn.Dropout(0.1), 67 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 68 | nn.BatchNorm1d(nf * 4, momentum=momentum), 69 | nn.LeakyReLU(0.2, inplace=True), 70 | nn.Dropout(0.1), 71 | nn.Conv1d(nf * 4, len(common.atoms.label_res_dict.keys()), 3, 1, 1, bias=False), 72 | ) 73 | 74 | # chi feat vec -- condition on residue and env feature vector 75 | self.chi_feat = nn.Sequential( 76 | nn.Conv1d(nf * 4 + 20, nf * 4, 3, 1, 1, bias=False), 77 | nn.BatchNorm1d(nf * 4, momentum=momentum), 78 | nn.LeakyReLU(0.2, inplace=True), 79 | nn.Dropout(0.1), 80 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 81 | nn.BatchNorm1d(nf * 4, momentum=momentum), 82 | nn.LeakyReLU(0.2, inplace=True), 83 | nn.Dropout(0.1), 84 | ) 85 | 86 | # chi 1 pred -- condition on chi feat vec 87 | self.chi_1_out = nn.Sequential( 88 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 89 | nn.BatchNorm1d(nf * 4, momentum=momentum), 90 | nn.LeakyReLU(0.2, inplace=True), 91 | nn.Dropout(0.1), 92 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 93 | nn.BatchNorm1d(nf * 4, momentum=momentum), 94 | nn.LeakyReLU(0.2, inplace=True), 95 | nn.Dropout(0.1), 96 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False), 97 | ) 98 | 99 | # chi 2 pred -- condition on chi 1 and chi feat vec 100 | self.chi_2_out = nn.Sequential( 101 | nn.Conv1d(nf * 4 + 1 * (len(data.CHI_BINS) - 1), nf * 4, 3, 1, 1, bias=False), 102 | nn.BatchNorm1d(nf * 4, momentum=momentum), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | nn.Dropout(0.1), 105 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 106 | nn.BatchNorm1d(nf * 4, momentum=momentum), 107 | nn.LeakyReLU(0.2, inplace=True), 108 | nn.Dropout(0.1), 109 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False), 110 | ) 111 | 112 | # chi 3 pred -- condition on chi 1, chi 2, and chi feat vec 113 | self.chi_3_out = nn.Sequential( 114 | nn.Conv1d(nf * 4 + 2 * (len(data.CHI_BINS) - 1), nf * 4, 3, 1, 1, bias=False), 115 | nn.BatchNorm1d(nf * 4, momentum=momentum), 116 | nn.LeakyReLU(0.2, inplace=True), 117 | nn.Dropout(0.1), 118 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 119 | nn.BatchNorm1d(nf * 4, momentum=momentum), 120 | nn.LeakyReLU(0.2, inplace=True), 121 | nn.Dropout(0.1), 122 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False), 123 | ) 124 | 125 | # chi 4 pred -- condition on chi 1, chi 2, chi 3, and chi feat vec 126 | self.chi_4_out = nn.Sequential( 127 | nn.Conv1d(nf * 4 + 3 * (len(data.CHI_BINS) - 1), nf * 4, 3, 1, 1, bias=False), 128 | nn.BatchNorm1d(nf * 4, momentum=momentum), 129 | nn.LeakyReLU(0.2, inplace=True), 130 | nn.Dropout(0.1), 131 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False), 132 | nn.BatchNorm1d(nf * 4, 
momentum=momentum), 133 | nn.LeakyReLU(0.2, inplace=True), 134 | nn.Dropout(0.1), 135 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False), 136 | ) 137 | 138 | def res_pred(self, input): 139 | bs = input.size()[0] 140 | feat = self.model(input).view(bs, -1, 1) 141 | res_pred = self.out(feat).view(bs, -1) 142 | return res_pred, feat 143 | 144 | def get_chi_init_feat(self, feat, res_onehot): 145 | chi_init = torch.cat([feat, res_onehot[..., None]], 1) 146 | chi_feat = self.chi_feat(chi_init) 147 | return chi_feat 148 | 149 | def get_chi_1(self, chi_feat): 150 | chi_1_pred = self.chi_1_out(chi_feat).view(chi_feat.size()[0], -1) 151 | return chi_1_pred 152 | 153 | def get_chi_2(self, chi_feat, chi_1_onehot): 154 | chi_2_pred = self.chi_2_out(torch.cat([chi_feat, chi_1_onehot[..., None]], 1)).view(chi_feat.size()[0], -1) 155 | return chi_2_pred 156 | 157 | def get_chi_3(self, chi_feat, chi_1_onehot, chi_2_onehot): 158 | chi_3_pred = self.chi_3_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None]], 1)).view(chi_feat.size()[0], -1) 159 | return chi_3_pred 160 | 161 | def get_chi_4(self, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot): 162 | chi_4_pred = self.chi_4_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None], chi_3_onehot[..., None]], 1)).view( 163 | chi_feat.size()[0], -1 164 | ) 165 | return chi_4_pred 166 | 167 | def get_feat(self, input, res_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot): 168 | bs = input.size()[0] 169 | feat = self.model(input).view(bs, -1, 1) 170 | res_pred = self.out(feat).view(bs, -1) 171 | 172 | # condition on res type and env feat 173 | chi_init = torch.cat([feat, res_onehot[..., None]], 1) 174 | chi_feat = self.chi_feat(chi_init) 175 | 176 | # condition on true residue type and previous ground-truth rotamer angles 177 | chi_1_pred = self.chi_1_out(chi_feat).view(bs, -1) 178 | chi_2_pred = self.chi_2_out(torch.cat([chi_feat, chi_1_onehot[..., None]], 1)).view(bs, -1) 179 | chi_3_pred = self.chi_3_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None]], 1)).view(bs, -1) 180 | chi_4_pred = self.chi_4_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None], chi_3_onehot[..., None]], 1)).view(bs, -1) 181 | return feat, res_pred, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred 182 | 183 | 184 | def forward(self, input, res_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot): 185 | bs = input.size()[0] 186 | feat = self.model(input).view(bs, -1, 1) 187 | res_pred = self.out(feat).view(bs, -1) 188 | 189 | # condition on res type and env feat 190 | chi_init = torch.cat([feat, res_onehot[..., None]], 1) 191 | chi_feat = self.chi_feat(chi_init) 192 | 193 | # condition on true residue type and previous ground-truth rotamer angles 194 | chi_1_pred = self.chi_1_out(chi_feat).view(bs, -1) 195 | chi_2_pred = self.chi_2_out(torch.cat([chi_feat, chi_1_onehot[..., None]], 1)).view(bs, -1) 196 | chi_3_pred = self.chi_3_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None]], 1)).view(bs, -1) 197 | chi_4_pred = self.chi_4_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None], chi_3_onehot[..., None]], 1)).view(bs, -1) 198 | 199 | return res_pred, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred 200 | -------------------------------------------------------------------------------- /seq_des/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import math 4 | import sys 5 | 
6 | import torch 7 | import torch.nn.functional as F 8 | from torch.distributions.categorical import Categorical 9 | 10 | import seq_des.util.pyrosetta_util as putil 11 | import seq_des.util.sampler_util as sampler_util 12 | import seq_des.util.canonicalize as canonicalize 13 | import seq_des.util.data as data 14 | import seq_des.util.resfile_util as resfile_util 15 | import common.atoms 16 | 17 | from pyrosetta.rosetta.protocols.simple_filters import ( 18 | BuriedUnsatHbondFilterCreator, 19 | PackStatFilterCreator, 20 | ) 21 | from pyrosetta.rosetta.protocols.denovo_design.filters import ExposedHydrophobicsFilterCreator 22 | from pyrosetta.rosetta.core.scoring import automorphic_rmsd 23 | 24 | 25 | class Sampler(object): 26 | def __init__(self, args, models, init_models=None, log=None, use_cuda=True): 27 | super(Sampler, self).__init__() 28 | self.models = models 29 | for model in self.models: 30 | model.eval() 31 | 32 | if init_models is not None: 33 | self.init_models = init_models 34 | for init_model in self.init_models: 35 | init_model.eval() 36 | else: 37 | self.init_models = None 38 | self.no_init_model = args.no_init_model 39 | 40 | self.pdb = args.pdb 41 | self.log = log 42 | self.use_cuda = use_cuda 43 | 44 | self.threshold = args.threshold 45 | self.pack_radius = args.pack_radius 46 | self.iteration = 0 47 | self.randomize = args.randomize 48 | self.rotamer_repack = args.repack_only 49 | self.use_rosetta_packer = args.use_rosetta_packer 50 | self.no_cys = args.no_cys 51 | self.no_met = args.no_met 52 | self.symmetry = args.symmetry 53 | self.k = args.k 54 | self.restrict_gly = args.restrict_gly 55 | self.ala = args.ala 56 | self.val = args.val 57 | assert not (self.ala and self.val), "only ala or val settings can be on for a given run" 58 | self.chi_mask = None 59 | 60 | self.anneal = args.anneal 61 | self.anneal_start_temp = args.anneal_start_temp 62 | self.anneal_final_temp = args.anneal_final_temp 63 | self.step_rate = args.step_rate 64 | self.accept_prob = 1 65 | 66 | # load fixed idx if applicable 67 | if args.fixed_idx != "": 68 | # assert not self.symmetry, 'fixed idx not supported in symmetry mode' 69 | self.fixed_idx = sampler_util.get_idx(args.fixed_idx) 70 | else: 71 | self.fixed_idx = [] 72 | 73 | # resfile restrictions handling (see util/resfile_util.py) 74 | self.resfile = args.resfile 75 | if self.resfile: 76 | # get resfile NATRO (used to skip designing/packing at all) 77 | self.fixed_idx = resfile_util.get_natro(self.resfile) 78 | # get resfile commands (used to restrict amino acid probability distribution) 79 | self.resfile = resfile_util.read_resfile(self.resfile) 80 | # get initial resfile sequence (used to initialize the sequence) 81 | self.init_seq_resfile = self.resfile[2] 82 | 83 | # the initial sequence must be randomized (avoid running the baseline model) 84 | if self.init_seq_resfile: 85 | self.randomize = 0 86 | 87 | # load var idx if applicable 88 | if args.var_idx != "": 89 | # assert not self.symmetry, 'var idx not supported in symmetry mode' 90 | self.var_idx = sampler_util.get_idx(args.var_idx) 91 | else: 92 | self.var_idx = [] 93 | 94 | assert not ((len(self.fixed_idx) > 0) and (len(self.var_idx) > 0)), "cannot specify both fixed and variable indices" 95 | 96 | if self.rotamer_repack: 97 | assert self.init_models is not None, "baseline model must be used to initialize rotamer repacking" 98 | 99 | if self.no_init_model: 100 | assert not self.rotamer_repack, "baseline model must be used for initializing rotamer repacking" 101 | 102 | if 
self.symmetry: 103 | assert len(self.fixed_idx) == 0, "specifying fixed idx not supported in symmetry model" 104 | assert len(self.var_idx) == 0, "specifying var idx not supported in symmetry model" 105 | 106 | def init(self): 107 | """ initialize sampler 108 | - initialize rosetta filters 109 | - score starting (ground-truth) sequence 110 | - set up constraints on glycines 111 | - set up symmetry 112 | - eval metrics on starting (ground-truth) sequence 113 | - get blocks for blocked sampling 114 | """ 115 | 116 | # initialize sampler 117 | self.init_rosetta_filters() 118 | # score starting (ground-truth) pdb, get gt energies 119 | self.gt_pose = putil.get_pose(self.pdb) 120 | self.gt_seq = self.gt_pose.sequence() 121 | ( 122 | _, 123 | self.log_p_per_res, 124 | self.log_p_mean, 125 | self.logits, 126 | self.chi_feat, 127 | self.gt_chi_angles, 128 | self.gt_chi_mask, 129 | self.gt_chi, 130 | ) = sampler_util.get_energy( 131 | self.models, pose=self.gt_pose, return_chi=1, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda, 132 | ) 133 | self.chi_error = 0 134 | self.re = putil.score_pose(self.gt_pose) 135 | self.gt_pose.dump_pdb(self.log.log_path + "/" + "gt_" + self.pdb) 136 | self.gt_score_terms = self.gt_pose.energies().residue_total_energies_array() 137 | self.score_terms = list(self.gt_score_terms.dtype.fields) 138 | 139 | # set no gly indices 140 | self.gt_pose.display_secstruct() 141 | ss = self.gt_pose.secstruct() 142 | self.no_gly_idx = [i for i in range(len(ss)) if ss[i] != "L"] 143 | self.n = self.gt_pose.residues.__len__() 144 | 145 | # handle symmetry 146 | if self.symmetry: 147 | if "tim" in self.pdb: 148 | # handle tim case 149 | self.n_k = math.ceil((self.n + 1) / self.k) if (self.n + 1) % 2 == 0 else math.ceil((self.n) / self.k) 150 | else: 151 | self.n_k = self.n // self.k 152 | assert self.n % self.k == 0, 'length of protein must be divisible by k for k-fold symm design' 153 | idx = [[i + j * (self.n_k) for j in range(self.k) if i + j * (self.n_k) < self.n] for i in range(self.n_k)] 154 | self.symmetry_idx = {} 155 | for idx_set in idx: 156 | for i in idx_set: 157 | self.symmetry_idx[i] = idx_set 158 | 159 | # updated fixed/var idx to reflect symmetry 160 | for i in self.fixed_idx: 161 | assert ( 162 | i in self.symmetry_idx.keys() 163 | ), "fixed idx must only be specified for first symmetric unit in symmetry mode (within first n_k residues)" 164 | for i in self.var_idx: 165 | assert ( 166 | i in self.symmetry_idx.keys() 167 | ), "var idx must only be specified for first symmetric unit in symmetry mode (within first n_k residues)" 168 | 169 | # get gt data -- monitor distance to initial sequence 170 | if '/' in self.pdb: 171 | pdb_idx = self.pdb.rfind("/") + 1 172 | pdb_dir = self.pdb[: self.pdb.rfind("/")] 173 | else: 174 | pdb_idx = 0 175 | pdb_dir = './' 176 | (self.gt_atom_coords, self.gt_atom_data, self.gt_residue_bb_index_list, res_data, self.gt_res_label, chis,) = data.get_pdb_data( 177 | self.pdb[pdb_idx : -4], data_dir=pdb_dir, assembly=0, skip_download=1, 178 | ) 179 | self.eval_metrics(self.gt_pose, self.gt_res_label) 180 | 181 | # get conditionally independent blocks via greedy k-colring of backbone 'graph' 182 | self.get_blocks() 183 | 184 | 185 | def init_seq(self): 186 | # initialize starting sequence 187 | 188 | # random/poly-alanine/poly-valine initialize sequence, pack rotamers 189 | self.pose = putil.get_pose(self.pdb) 190 | if self.randomize: 191 | if (not self.no_init_model) and not (self.ala or self.val): 192 | # get 
features --> BB only 193 | ( 194 | res_label, 195 | self.log_p_per_res_temp, 196 | self.log_p_mean_temp, 197 | self.logits_temp, 198 | self.chi_feat_temp, 199 | self.chi_angles_temp, 200 | self.chi_mask_temp, 201 | ) = sampler_util.get_energy( 202 | self.init_models, self.pose, bb_only=1, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda, 203 | ) 204 | 205 | # set sequence 206 | if not self.rotamer_repack: 207 | # sample res from logits 208 | if not self.symmetry: 209 | res, idx, res_label = self.sample(self.logits_temp, np.arange(len(res_label))) 210 | else: 211 | res, idx, res_label = self.sample(self.logits_temp, np.arange(self.n_k)) 212 | # mutate pose residues based on baseline prediction 213 | self.pose = putil.mutate_list(self.pose, idx, res, pack_radius=0, fixed_idx=self.fixed_idx, var_idx=self.var_idx) 214 | else: 215 | res = [i for i in self.gt_seq] 216 | if self.symmetry: 217 | res_label = res_label[: self.n_k] 218 | 219 | # sample and set rotamers 220 | if self.symmetry: 221 | if not self.rotamer_repack: 222 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer( 223 | np.arange(self.n_k), [res_label[i] for i in range(0, len(res_label), self.k)], self.chi_feat_temp, bb_only=1, 224 | ) 225 | else: 226 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer( 227 | np.arange(self.n_k), res_label, self.chi_feat_temp, bb_only=1, 228 | ) 229 | else: 230 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer( 231 | np.arange(len(res_label)), res_label, self.chi_feat_temp, bb_only=1, 232 | ) 233 | res = [common.atoms.label_res_single_dict[k] for k in res_idx] 234 | self.pose = self.set_rotamer(self.pose, res, idx, self.chi_1, self.chi_2, self.chi_3, self.chi_4, fixed_idx=self.fixed_idx, var_idx=self.var_idx) 235 | 236 | # Randomize sequence/rotamers 237 | else: 238 | if not self.rotamer_repack: 239 | random_seq = np.random.choice(20, size=len(self.pose)) 240 | if not self.ala and not self.val and self.symmetry: 241 | # random sequence must be symmetric 242 | random_seq = np.concatenate([random_seq[: self.n_k] for i in range(self.k)]) 243 | random_seq = random_seq[: len(self.pose)] 244 | self.pose, _ = putil.randomize_sequence( 245 | random_seq, 246 | self.pose, 247 | pack_radius=self.pack_radius, 248 | ala=self.ala, 249 | val=self.val, 250 | resfile_init_seq=self.init_seq_resfile, 251 | fixed_idx=self.fixed_idx, 252 | var_idx=self.var_idx, 253 | repack_rotamers=1,) 254 | else: 255 | assert False, "baseline model must be used for initializing rotamer repacking" 256 | 257 | # evaluate energy for starting structure/sequence 258 | (self.res_label, self.log_p_per_res, self.log_p_mean, self.logits, self.chi_feat, self.chi_angles, self.chi_mask,) = sampler_util.get_energy( 259 | self.models, self.pose, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda, 260 | ) 261 | if self.rotamer_repack: 262 | assert np.all(self.chi_mask == self.gt_chi_mask), "gt and current pose chi masks should be the same when doing rotamer repacking" 263 | 264 | if self.anneal: 265 | self.pose.dump_pdb(self.log.log_path + "/" + "curr_pose_%s.pdb" % self.log.ts) 266 | 267 | def init_rosetta_filters(self): 268 | # initialize pyrosetta filters 269 | hbond_filter_creator = BuriedUnsatHbondFilterCreator() 270 | hydro_filter_creator = ExposedHydrophobicsFilterCreator() 271 | ps_filter_creator = PackStatFilterCreator() 272 | self.packstat_filter = 
ps_filter_creator.create_filter() 273 | self.exposed_hydrophobics_filter = hydro_filter_creator.create_filter() 274 | self.sc_buried_unsats_filter = hbond_filter_creator.create_filter() 275 | self.bb_buried_unsats_filter = hbond_filter_creator.create_filter() 276 | self.bb_buried_unsats_filter.set_report_bb_heavy_atom_unsats(True) 277 | self.sc_buried_unsats_filter.set_report_sc_heavy_atom_unsats(True) 278 | self.filters = [ 279 | ("packstat", self.packstat_filter), 280 | ("exposed_hydrophobics", self.exposed_hydrophobics_filter), 281 | ("sc_buried_unsats", self.sc_buried_unsats_filter), 282 | ("bb_buried_unsats", self.bb_buried_unsats_filter), 283 | ] 284 | 285 | def get_blocks(self, single_res=False): 286 | # get node blocks for blocked sampling 287 | D = sampler_util.get_CB_distance(self.gt_atom_coords, self.gt_residue_bb_index_list) 288 | if single_res: # no blocked gibbs -- sampling one res at a time 289 | self.blocks = [[i] for i in np.arange(D.shape[0])] 290 | self.n_blocks = len(self.blocks) 291 | else: 292 | A = sampler_util.get_graph_from_D(D, self.threshold) 293 | # if symmetry is on --> collapse graph s.t. the neighbors of node i include the neighbors of all its symmetric copies i + j * n_k 294 | if self.symmetry: 295 | for i in range(self.n_k): 296 | A[i] = np.sum(np.concatenate([A[i + j * self.n_k][None] for j in range(self.k) if i + j * self.n_k < self.n]), axis=0,) 297 | for i in range(self.n_k): 298 | A[:, i] = np.sum(np.concatenate([A[:, i + j * self.n_k][None] for j in range(self.k) if i + j * self.n_k < self.n]), axis=0,) 299 | A[A > 1] = 1 300 | A = A[: self.n_k, : self.n_k] 301 | 302 | self.graph = {i: np.where(A[i, :] == 1)[0] for i in range(A.shape[0])} 303 | # min k-color of graph by greedy search 304 | nodes = np.arange(A.shape[0]) 305 | np.random.shuffle(nodes) 306 | # eliminate fixed indices from list 307 | if self.symmetry: 308 | nodes = [n for n in range(self.n_k)] 309 | if len(self.fixed_idx) > 0: 310 | nodes = [n for n in nodes if n not in self.fixed_idx] 311 | elif len(self.var_idx) > 0: 312 | nodes = [n for n in nodes if n in self.var_idx] 313 | self.colors = sampler_util.color_nodes(self.graph, nodes) 314 | self.n_blocks = 0 315 | if self.colors: # check if there are any colored nodes to get n_blocks (might be empty if running NATRO on all residues in resfile) 316 | self.n_blocks = sorted(list(set(self.colors.values())))[-1] + 1 317 | self.blocks = {} 318 | for k in self.colors.keys(): 319 | if self.colors[k] not in self.blocks.keys(): 320 | self.blocks[self.colors[k]] = [] 321 | self.blocks[self.colors[k]].append(k) 322 | 323 | self.reset_block_rate = self.n_blocks 324 | 325 | def eval_metrics(self, pose, res_label): 326 | self.rosetta_energy = putil.score_pose(pose) 327 | self.curr_score_terms = pose.energies().residue_total_energies_array() 328 | self.seq_overlap = (res_label == self.gt_res_label).sum() 329 | self.filter_scores = [] 330 | for n, filter in self.filters: 331 | self.filter_scores.append((n, filter.score(pose))) 332 | if self.rotamer_repack: 333 | self.chi_rmsd = sum([automorphic_rmsd(self.gt_pose.residue(i + 1), pose.residue(i + 1), True) for i in range(len(pose))]) / len(pose) 334 | else: 335 | self.chi_rmsd = 0 336 | self.seq = pose.sequence()
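# the per-chi error below is a chord distance on the unit circle: sqrt((sin a - sin b)^2 + (cos a - cos b)^2) = 2|sin((a - b)/2)|, so angles near -pi and +pi compare as close; the sum is normalized by the number of chi angles defined in both poses (chi_mask)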
337 | if self.chi_mask is not None and self.rotamer_repack: 338 | chi_error = self.chi_mask * np.sqrt( 339 | (np.sin(self.chi_angles) - np.sin(self.gt_chi_angles)) ** 2 + (np.cos(self.chi_angles) - np.cos(self.gt_chi_angles)) ** 2 340 | ) 341 | self.chi_error = np.sum(chi_error) / np.sum(self.chi_mask) 342 | else: 343 | self.chi_error = 0 344 | 345 | def enforce_resfile(self, logits, idx): 346 | """ 347 | enforces resfile constraints by setting restricted logits to a large negative value (-99999); the Categorical distribution renormalizes, so restricted amino acids get ~zero probability 348 | 349 | logits - tensor where rows are residue positions and columns are amino acid logits 350 | idx - residue ids 351 | """ 352 | constraints, header = self.resfile[0], self.resfile[1] 353 | # iterate over all residues and check if they're to be constrained 354 | for i in idx: 355 | if i in constraints.keys(): 356 | # set of amino acids to restrict in the tensor 357 | aa_to_restrict = constraints[i] 358 | for aa in aa_to_restrict: 359 | logits[i, common.atoms.aa_map_inv[aa]] = -99999 360 | elif header: # if not in the constraints, apply header (see util/resfile_util.py) 361 | aa_to_restrict = header["DEFAULT"] 362 | for aa in aa_to_restrict: 363 | logits[i, common.atoms.aa_map_inv[aa]] = -99999 364 | return logits 365 | 366 | def enforce_constraints(self, logits, idx): 367 | if self.resfile: 368 | logits = self.enforce_resfile(logits, idx) 369 | # enforce idx-wise constraints 370 | if self.no_cys: 371 | logits = logits[..., :-1] 372 | no_gly_idx = [i for i in idx if i in self.no_gly_idx] 373 | # note -- there are other, more careful ways to enforce the met/gly constraints 374 | for i in idx: 375 | if self.restrict_gly: 376 | if i in self.no_gly_idx: 377 | logits[i, 18] = torch.min(logits[i]) 378 | if self.no_met: 379 | logits[i, 13] = torch.min(logits[i]) 380 | if self.no_cys: 381 | logits[i, 19] = torch.min(logits[i]) 382 | if self.symmetry: 383 | # average logits across all symmetry positions 384 | for i in idx: 385 | logits[i] = torch.cat([logits[j][None] for j in self.symmetry_idx[i] if j < self.n], 0).mean(0) 386 | return logits 387 | 388 | def sample_rotamer(self, idx, res_idx, feat, bb_only=0): 389 | # idx --> (block) residue indices (on chain) 390 | # res_idx --> idx of residue *type* (AA type) 391 | # feat --> initial env features from conv net 392 | assert len(idx) == len(res_idx), (len(idx), len(res_idx)) 393 | if bb_only: 394 | curr_models = self.init_models 395 | else: 396 | curr_models = self.models 397 | 398 | if not self.symmetry: 399 | 400 | # get residue onehot vector 401 | res_idx_long = torch.LongTensor(res_idx) 402 | res_onehot = sampler_util.make_onehot(res_idx_long.size()[0], 20, res_idx_long[:, None], use_cuda=self.use_cuda,) 403 | 404 | # get chi feat 405 | chi_feat = sampler_util.get_chi_init_feat(curr_models, feat[idx], res_onehot) 406 | # predict and sample chi angles 407 | chi_1_pred_out = sampler_util.get_chi_1_logits(curr_models, chi_feat) 408 | chi_1, chi_1_real, chi_1_onehot = sampler_util.sample_chi(chi_1_pred_out, use_cuda=self.use_cuda) 409 | chi_2_pred_out = sampler_util.get_chi_2_logits(curr_models, chi_feat, chi_1_onehot) 410 | chi_2, chi_2_real, chi_2_onehot = sampler_util.sample_chi(chi_2_pred_out, use_cuda=self.use_cuda) 411 | chi_3_pred_out = sampler_util.get_chi_3_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot) 412 | chi_3, chi_3_real, chi_3_onehot = sampler_util.sample_chi(chi_3_pred_out, use_cuda=self.use_cuda) 413 | chi_4_pred_out = sampler_util.get_chi_4_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot) 414 | chi_4, chi_4_real, chi_4_onehot = sampler_util.sample_chi(chi_4_pred_out, use_cuda=self.use_cuda) 415 | 416 | return chi_1_real, chi_2_real, chi_3_real, chi_4_real, idx, res_idx 417 | 418 | else: 419 | 420 | # symmetric rotamer sampling 421 | 422 | # get 
symmetry indices 423 | symm_idx = [] 424 | for i in idx: 425 | symm_idx.extend([j for j in self.symmetry_idx[i]]) 426 | 427 | res_idx_symm = [] 428 | for i, idx_i in enumerate(idx): 429 | res_idx_symm.extend([res_idx[i] for j in self.symmetry_idx[idx_i]]) 430 | 431 | # get residue onehot vector 432 | res_idx_long = torch.LongTensor(res_idx_symm) 433 | res_onehot = sampler_util.make_onehot(res_idx_long.size()[0], 20, res_idx_long[:, None], use_cuda=self.use_cuda,) 434 | 435 | symm_idx_ptr = [] 436 | count = 0 437 | for i, idx_i in enumerate(idx): 438 | symm_idx_ptr.append([count + j for j in range(len(self.symmetry_idx[idx_i]))]) 439 | count = count + len(self.symmetry_idx[idx_i]) 440 | 441 | # get chi feature vector 442 | chi_feat = sampler_util.get_chi_init_feat(curr_models, feat[symm_idx], res_onehot) 443 | 444 | # predict and sample chi for each symmetry position 445 | chi_1_pred_out = sampler_util.get_chi_1_logits(curr_models, chi_feat) 446 | chi_1_real, chi_1_onehot = sampler_util.get_symm_chi(chi_1_pred_out, symm_idx_ptr, use_cuda=self.use_cuda) 447 | 448 | chi_2_pred_out = sampler_util.get_chi_2_logits(curr_models, chi_feat, chi_1_onehot) 449 | # set debug=True below to reproduce biorxiv results. Sample uniformly 2x from predicted rotamer bin. Small bug for TIM-barrel symmetry experiments for chi_2. 450 | chi_2_real, chi_2_onehot = sampler_util.get_symm_chi(chi_2_pred_out, symm_idx_ptr, use_cuda=self.use_cuda, debug=True) 451 | 452 | chi_3_pred_out = sampler_util.get_chi_3_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot) 453 | chi_3_real, chi_3_onehot = sampler_util.get_symm_chi(chi_3_pred_out, symm_idx_ptr, use_cuda=self.use_cuda) 454 | 455 | chi_4_pred_out = sampler_util.get_chi_4_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot) 456 | chi_4_real, chi_4_onehot = sampler_util.get_symm_chi(chi_4_pred_out, symm_idx_ptr, use_cuda=self.use_cuda) 457 | 458 | return ( 459 | chi_1_real, 460 | chi_2_real, 461 | chi_3_real, 462 | chi_4_real, 463 | symm_idx, 464 | res_idx_symm, 465 | ) 466 | 467 | def set_rotamer(self, pose, res, idx, chi_1, chi_2, chi_3, chi_4, fixed_idx=[], var_idx=[]): 468 | # res -- residue type ID 469 | # idx -- residue index on BB (0-indexed) 470 | assert len(res) == len(idx) 471 | assert len(idx) == len(chi_1), (len(idx), len(chi_1)) 472 | for i, r_idx in enumerate(idx): 473 | if len(fixed_idx) > 0 and r_idx in fixed_idx: 474 | continue 475 | elif len(var_idx) > 0 and r_idx not in var_idx: 476 | continue 477 | res_i = res[i] 478 | chi_i = common.atoms.chi_dict[common.atoms.aa_inv[res_i]] 479 | if "chi_1" in chi_i.keys(): 480 | pose.set_chi(1, r_idx + 1, chi_1[i] * (180 / np.pi)) 481 | assert np.abs(pose.chi(1, r_idx + 1) - chi_1[i] * (180 / np.pi)) <= 1e-5, (pose.chi(1, r_idx + 1), chi_1[i] * (180 / np.pi)) 482 | if "chi_2" in chi_i.keys(): 483 | pose.set_chi(2, r_idx + 1, chi_2[i] * (180 / np.pi)) 484 | assert np.abs(pose.chi(2, r_idx + 1) - chi_2[i] * (180 / np.pi)) <= 1e-5, (pose.chi(2, r_idx + 1), chi_2[i] * (180 / np.pi)) 485 | if "chi_3" in chi_i.keys(): 486 | pose.set_chi(3, r_idx + 1, chi_3[i] * (180 / np.pi)) 487 | assert np.abs(pose.chi(3, r_idx + 1) - chi_3[i] * (180 / np.pi)) <= 1e-5, (pose.chi(3, r_idx + 1), chi_3[i] * (180 / np.pi)) 488 | if "chi_4" in chi_i.keys(): 489 | pose.set_chi(4, r_idx + 1, chi_4[i] * (180 / np.pi)) 490 | assert np.abs(pose.chi(4, r_idx + 1) - chi_4[i] * (180 / np.pi)) <= 1e-5, (pose.chi(4, r_idx + 1), chi_4[i] * (180 / np.pi)) 491 | 492 | return pose 493 | 494 | def sample(self, logits, idx): 495 | 
# sample residue from model conditional prob distribution at idx with current logits 496 | logits = self.enforce_constraints(logits, idx) 497 | dist = Categorical(logits=logits[idx]) 498 | res_idx = dist.sample().cpu().data.numpy() 499 | idx_out = [] 500 | res = [] 501 | assert len(res_idx) == len(idx), (len(idx), len(res_idx)) 502 | 503 | for k in list(res_idx): 504 | res.append(common.atoms.label_res_single_dict[k]) 505 | 506 | if self.symmetry: 507 | idx_out = [] 508 | for i in idx: 509 | idx_out.extend([j for j in self.symmetry_idx[i] if j < self.n]) 510 | res_out = [] 511 | for i, idx_i in enumerate(idx): 512 | res_out.extend([res[i] for j in self.symmetry_idx[idx_i] if j < self.n]) 513 | res_idx_out = [] 514 | for i, idx_i in enumerate(idx): 515 | res_idx_out.extend([res_idx[i] for j in self.symmetry_idx[idx_i] if j < self.n]) 516 | 517 | assert len(idx_out) == len(res_out), (len(idx_out), len(res_out)) 518 | assert len(idx_out) == len(res_idx_out), (len(idx_out), len(res_idx_out)) 519 | 520 | return res_out, idx_out, res_idx_out 521 | 522 | return res, idx, res_idx 523 | 524 | def sim_anneal_step(self, e, e_old): 525 | delta_e = e - e_old 526 | if delta_e < 0: 527 | accept_prob = 1.0 528 | else: 529 | if self.anneal_start_temp == 0: 530 | accept_prob = 0 531 | else: 532 | accept_prob = torch.exp(-(delta_e) / self.anneal_start_temp).item() 533 | return accept_prob 534 | 535 | def step_T(self): 536 | # anneal temperature 537 | self.anneal_start_temp = max(self.anneal_start_temp * self.step_rate, self.anneal_final_temp) 538 | 539 | def step(self): 540 | # no blocks to sample (NATRO for all residues) 541 | if self.n_blocks == 0: 542 | self.step_anneal() 543 | return 544 | 545 | # random idx selection, draw sample 546 | idx = self.blocks[np.random.choice(self.n_blocks)] 547 | 548 | if not self.rotamer_repack: 549 | # sample new residue indices/ residues 550 | res, idx, res_idx = self.sample(self.logits, idx) 551 | else: 552 | # residue idx is fixed (identity fixed) for rotamer repacking 553 | res = [self.gt_seq[i] for i in idx] 554 | res_idx = [common.atoms.aa_map_inv[self.gt_seq[i]] for i in idx] 555 | 556 | # sample rotamer using precomputed chi_feat vector 557 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer(idx, res_idx, self.chi_feat) 558 | if self.anneal: 559 | self.pose = putil.get_pose(self.log.log_path + "/" + "curr_pose_%s.pdb" % self.log.ts) 560 | 561 | # mutate residues, set rotamers 562 | res = [common.atoms.label_res_single_dict[k] for k in res_idx] 563 | 564 | if not self.use_rosetta_packer: 565 | # mutate center residue 566 | if not self.rotamer_repack: 567 | self.pose_temp = putil.mutate_list(self.pose, idx, res, pack_radius=0, fixed_idx=self.fixed_idx, var_idx=self.var_idx) 568 | 569 | else: 570 | self.pose_temp = self.pose 571 | 572 | # sample and set center residue rotamer 573 | self.pose_temp = self.set_rotamer(self.pose_temp, res, idx, self.chi_1, self.chi_2, self.chi_3, self.chi_4, fixed_idx=self.fixed_idx, var_idx=self.var_idx) 574 | 575 | else: 576 | # Pyrosetta mutate and rotamer repacking 577 | self.pose_temp = putil.mutate_list( 578 | self.pose, idx, res, pack_radius=self.pack_radius, fixed_idx=self.fixed_idx, var_idx=self.var_idx, repack_rotamers=1 579 | ) 580 | 581 | # get log prob under model 582 | ( 583 | self.res_label_temp, 584 | self.log_p_per_res_temp, 585 | self.log_p_mean_temp, 586 | self.logits_temp, 587 | self.chi_feat_temp, 588 | self.chi_angles_temp, 589 | self.chi_mask_temp, 590 | ) = 
sampler_util.get_energy(self.models, self.pose_temp, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda,) 591 | if self.anneal: 592 | # simulated annealing accept/reject step 593 | self.accept_prob = self.sim_anneal_step(self.log_p_mean_temp, self.log_p_mean) 594 | r = np.random.uniform(0, 1) 595 | else: 596 | # vanilla sampling step 597 | self.accept_prob = 1 598 | r = 0 599 | 600 | if r < self.accept_prob: 601 | if self.anneal: 602 | self.pose_temp.dump_pdb(self.log.log_path + "/" + "curr_pose_%s.pdb" % self.log.ts) 603 | # update pose 604 | self.pose = self.pose_temp 605 | (self.log_p_mean, self.log_p_per_res, self.logits, self.chi_feat, self.res_label,) = ( 606 | self.log_p_mean_temp, 607 | self.log_p_per_res_temp, 608 | self.logits_temp, 609 | self.chi_feat_temp, 610 | self.res_label_temp, 611 | ) 612 | self.chi_angles, self.chi_mask = self.chi_angles_temp, self.chi_mask_temp 613 | 614 | # eval all metrics 615 | self.eval_metrics(self.pose, self.res_label) 616 | 617 | self.step_anneal() 618 | 619 | def step_anneal(self): 620 | # shared ending for step() -- anneal temperature, advance iteration, periodically reset blocks 621 | if self.anneal: 622 | self.step_T() 623 | 624 | self.iteration += 1 625 | 626 | # reset blocks 627 | if self.reset_block_rate != 0 and (self.iteration % self.reset_block_rate == 0): 628 | self.get_blocks() 629 | -------------------------------------------------------------------------------- /seq_des/util/README.md: -------------------------------------------------------------------------------- 1 | # Resfile Interface 2 | 3 | Authors: Damir Temir, Christian Choe 4 | 5 | ## Overview 6 | 7 | The resfile interface controls the amino acid distributions produced by the baseline and conditional models. 8 | It can be used to specify particular amino acids at particular residue positions, 9 | thus guiding the Protein Sequence Design algorithm toward desired structures. 10 | 11 | Example of a resfile: 12 | 13 | ALLAA # set a default command for all residues not listed below 14 | START 15 | 34 ALLAAwc # allow all amino acids at residue #34 16 | 65 POLAR # allow only polar amino acids at residue #65 17 | 36 - 38 ALLAAxc # allow all amino acids except cysteine at residues #36 to #38 (inclusive) 18 | 34 TPIKAA C # set the initial sequence position at residue #34 to cysteine 19 | 55 - 58 NOTAA EHKNRQDST # disallow the listed amino acids at residues #55 to #58 20 | 20 NATRO # do not design residue #20 at all 21 | 22 | ## Using resfile 23 | 24 | To use a resfile, create a new `.txt` file specifying the desired commands.
Then run: 25 | 26 | python3 run.py --pdb pdbs/3mx7_gt.pdb --resfile txt/resfiles/.txt 27 | 28 | ## List of Commands 29 | 30 | ### Body 31 | 32 | This is a **complete list of the commands that can be specified in the body** for particular residue ids: 33 | 34 | | Command | Description | 35 | | ------ | ----- | 36 | |ALLAA|Allows all amino acids| 37 | |ALLAAwc|Allows all amino acids (including cysteine)| 38 | |ALLAAxc|Allows all amino acids (excluding cysteine)| 39 | |POLAR|Allows only polar amino acids (DEHKNQRST)| 40 | |APOLAR|Allows only non-polar amino acids (ACFGILMPVWY)| 41 | |PIKAA|Allows only the specified amino acids| 42 | |NOTAA|Allows all amino acids other than those specified| 43 | |NATRO|Disallows designing that residue| 44 | |TPIKAA|Sets the specified amino acid in the initial sequence| 45 | |TNOTAA|Sets an amino acid other than those specified in the initial sequence| 46 | 47 | ### Header 48 | 49 | The header _can take_ these commands to limit **all residues not specified in the body**: 50 | 51 | | Command | Description | 52 | | ------ | ----- | 53 | |ALLAA|Allows all amino acids| 54 | |ALLAAwc|Allows all amino acids (including cysteine)| 55 | |ALLAAxc|Allows all amino acids (excluding cysteine)| 56 | |POLAR|Allows only polar amino acids (DEHKNQRST)| 57 | |APOLAR|Allows only non-polar amino acids (ACFGILMPVWY)| 58 | |PIKAA|Allows only the specified amino acids| 59 | |NOTAA|Allows all amino acids other than those specified| 60 | 61 | **NOTE**: The header command must be followed by the keyword **START** on a new line. 62 | 63 | The header _cannot take_ these commands for the following reasons: 64 | 65 | | Command | Reason | 66 | | ---- | ----- | 67 | |NATRO|Extracting the residues the algorithm should not design is a separate process. Please specify the range of residues to preserve in the body instead `ex. 1 - 90 NATRO`| 68 | |TPIKAA|Setting particular residues in the initial sequence is a separate process. Please specify the amino acid for each residue in the body instead `ex. 5 TPIKAA C`| 69 | |TNOTAA|For the same reason as above. Please specify all amino acids to avoid at initialization for each residue instead `ex. 5 TNOTAA HKRDESTNQAVLIMFYWPG`| 70 | 71 | ### Ranges 72 | 73 | You can specify ranges of residues to which a command should apply. For example: 74 | 75 | 1 - 90 NATRO # will preserve all residues from residue #1 to #90 (inclusive) 76 | 77 | Ranges can be specified for _all_ body commands, but **cannot be used in the header section**. 78 | 79 | ### Initializing the Sequence 80 | 81 | With the `TPIKAA` and `TNOTAA` commands we can initialize the sequence with particular amino acids: 82 | 83 | 1 TPIKAA C 84 | 2 TPIKAA T 85 | 3 TPIKAA Y 86 | 4 TNOTAA ACFGILMPVWYDEHKNQRS # will set res #4 to T since it's the only amino acid not restricted 87 | ... 88 | 89 | This results in the initial sequence `CTYT...` 90 | 91 | **NOTE**: you can still specify other commands for those residues; they will restrict the subsequent design steps, which use the conditional model rather than the baseline model.
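Under the hood, design-time restrictions such as `PIKAA` and `NOTAA` are enforced by masking the model's per-residue logits before sampling from the resulting categorical distribution (see `enforce_resfile` in `seq_des/sampler.py`). A minimal sketch of the idea, assuming a hypothetical `aa_map_inv` mapping one-letter codes to logit columns and 0-indexed residue positions:

    import torch
    from torch.distributions.categorical import Categorical

    aa_map_inv = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}  # illustrative ordering only

    def mask_disallowed(logits, res_idx, disallowed_aas):
        # push disallowed amino acids to a large negative logit;
        # Categorical renormalizes, so their probability becomes ~0
        for aa in disallowed_aas:
            logits[res_idx, aa_map_inv[aa]] = -99999.0
        return logits

    logits = torch.zeros(100, 20)                      # 100 residues x 20 amino acids
    logits = mask_disallowed(logits, 54, "EHKNRQDST")  # e.g. 55 NOTAA EHKNRQDST
    aa_idx = Categorical(logits=logits[54]).sample()   # a masked amino acid is effectively never drawn

The actual column ordering of the 20 amino acids is defined in `common.atoms`; the alphabet above is illustrative only.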
92 | 93 | ## Results 94 | 95 | An example of a designed all-beta structure using the **backbone [3mx7_gt.pdb](../../pdbs/3mx7_gt.pdb)** with the **[resfile](../../txt/resfiles/resfile_3mx7_gt_ex1.txt)**: 96 | 97 | Before | After | 98 | :------:|:------| 99 | ![Example of a usual result where all Hydrogen Bonding Networks are external to the core](../../imgs/ex3_results.png)|![Example of a hydrogen bonding network in the core](../../imgs/ex2_results.png) 100 | 101 | An example of a designed all-alpha structure using the **backbone [1bkr_gt.pdb](../../pdbs/1bkr_gt.pdb)** with the **[resfile](../../txt/resfiles/resfile_1bkr_gt_ex6.txt)**: 102 | 103 | Before | After | 104 | :------:|:------| 105 | ![Example of a usual result where all Hydrogen Bonding Networks are external to the core](../../imgs/ex5_results.png)|![Example of a hydrogen bonding network in the core](../../imgs/ex6_results.png) 106 | 107 | 108 | -------------------------------------------------------------------------------- /seq_des/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/seq_des/util/__init__.py -------------------------------------------------------------------------------- /seq_des/util/acc_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ accuracy eval fns """ 4 | 5 | label_coarse = {0:0,1:0,2:0, 3:1,4:1, 5:2,6:2, 7:3,8:3, 9:4,10:4,11:4,12:4,13:4, 14:5,15:5,16:5, 17:6,18:7, 19:8} 6 | label_res_single_dict_coarse= {0:'(+)', 1:'(-)', 2: 'ST', 3: 'NQ', 4: 'SH', 5:'LH', 6:'P', 7:'G', 8:'C'} 7 | 8 | label_polar = {0:0, 1:0,2:0,3:0,4:0,5:0,6:0, 7:0,8:0, 9:1,10:1,11:1,12:1, 13:2, 14:1, 15:2,16:2, 17:3,18:4, 19:0} 9 | label_res_single_dict_polar={0:'polar', 1: 'nonpolar', 2: 'amphipathic', 3: 'proline', 4:'glycine'} 10 | 11 | def get_acc(logits, label, cm=None, label_dict=None, ignore_idx=None): 12 | 13 | pred = torch.argmax(logits, 1) 14 | 15 | if label_dict is not None: 16 | pred = torch.LongTensor([label_dict[p] for p in pred.cpu().data.numpy()]) 17 | label = torch.LongTensor([label_dict[l] for l in label.cpu().data.numpy()]) 18 | 19 | if ignore_idx is None: 20 | acc = float((pred == label).sum(-1)) / label.size()[0] 21 | else: 22 | if len(label[label != ignore_idx]) == 0: 23 | # case when all data in a batch is to be ignored 24 | acc = 0.0 25 | else: 26 | acc = float((pred[label != ignore_idx ] == label[label != ignore_idx]).sum(-1)) / len(label[label != ignore_idx]) 27 | 28 | if cm is not None: 29 | if ignore_idx is None: 30 | for i in range(pred.size()[0]): 31 | # NOTE -- do not try to un-for loop this... errors 32 | cm[label[i], pred[i]] += 1 33 | 34 | else: 35 | for i in range(pred.size()[0]): 36 | # NOTE -- do not try to un-for loop this... 
errors 37 | if label[i] != ignore_idx: 38 | cm[label[i], pred[i]] += 1 39 | 40 | return acc, cm 41 | 42 | 43 | def get_chi_acc(logits, label, res_label, cm_dict=None, label_dict=None, ignore_idx=None): 44 | 45 | pred = torch.argmax(logits, 1) 46 | 47 | if label_dict is not None: 48 | pred = torch.LongTensor([label_dict[p] for p in pred.cpu().data.numpy()]) 49 | label = torch.LongTensor([label_dict[l] for l in label.cpu().data.numpy()]) 50 | 51 | if ignore_idx is None: 52 | acc = float((pred == label).sum(-1)) / label.size()[0] 53 | else: 54 | if len(label[label != ignore_idx]) == 0: 55 | # case when all data in a batch is to be ignored 56 | acc = 0.0 57 | else: 58 | acc = float((pred[label != ignore_idx ] == label[label != ignore_idx]).sum(-1)) / len(label[label != ignore_idx]) 59 | 60 | if cm_dict is not None: 61 | if ignore_idx is None: 62 | for i in range(pred.size()[0]): 63 | # NOTE -- do not try to un-for loop this... errors 64 | cm_dict[res_label[i].item()][label[i], pred[i]] += 1 65 | 66 | else: 67 | for i in range(pred.size()[0]): 68 | # NOTE -- do not try to un-for loop this... errors 69 | if label[i] != ignore_idx: 70 | cm_dict[res_label[i].item()][label[i], pred[i]] += 1 71 | 72 | return acc, cm_dict 73 | 74 | 75 | def get_chi_EV(probs, label, res_label, cm_dict=None, label_dict=None, ignore_idx=None): 76 | 77 | 78 | if cm_dict is not None: 79 | if ignore_idx is None: 80 | for i in range(probs.shape[0]): #ize()[0]): 81 | # NOTE -- do not try to un-for loop this... errors 82 | cm_dict[res_label[i].item()]['ev'] += probs[i] 83 | cm_dict[res_label[i].item()]['n']+= 1 84 | 85 | else: 86 | for i in range(probs.shape[0]): #ize()[0]): 87 | # NOTE -- do not try to un-for loop this... errors 88 | if label[i] != ignore_idx: 89 | cm_dict[res_label[i].item()]['ev'] += probs[i] 90 | cm_dict[res_label[i].item()]['n']+= 1 91 | 92 | return cm_dict 93 | 94 | 95 | # from pytorch ... 
96 | def get_top_k_acc(output, target, k=3, ignore_idx=None): 97 | """Computes the accuracy over the k top predictions for the specified values of k""" 98 | with torch.no_grad(): 99 | batch_size = target.size(0) 100 | 101 | _, pred = output.topk(k, 1, True, True) 102 | pred = pred.t() 103 | if ignore_idx is not None: 104 | pred = pred[target != ignore_idx] 105 | target = target[target != ignore_idx] 106 | 107 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 108 | 109 | res = [] 110 | correct = correct.contiguous() 111 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 112 | #res.append(correct_k.mul_(100.0 / batch_size)) 113 | return correct_k.mul_(1.0 / batch_size).item() 114 | 115 | ### 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /seq_des/util/canonicalize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import glob 4 | import pickle 5 | 6 | 7 | gly_CB_mu = np.array([-0.5311191 , -0.75842446, 1.2198311 ]) #pickle.load(open("pkl/CB_mu.pkl", "rb")) 8 | 9 | 10 | def get_len(v): 11 | return np.sqrt(np.sum(v ** 2, -1)) 12 | 13 | 14 | def get_unit_normal(ab, bc): 15 | n = np.cross(ab, bc, -1) 16 | length = get_len(n) 17 | if len(n.shape) > 2: 18 | length = length[..., None] 19 | return n / length 20 | 21 | 22 | def get_angle(v1, v2): 23 | # get in plane angle between v1, v2 -- cos^-1(v1.v2 / (||v1|| ||v2||)) 24 | return np.arccos(np.sum(v1 * v2, -1) / (get_len(v1) * get_len(v2))) 25 | 26 | 27 | def bdot(a, b): 28 | return np.matmul(a, b) 29 | 30 | 31 | def return_align_f(axis, theta): 32 | c_theta = np.cos(theta)[..., None] 33 | s_theta = np.sin(theta)[..., None] 34 | f_rot = lambda v: c_theta * v + s_theta * np.cross(axis, v, axis=-1) + (1 - c_theta) * bdot(axis, v.transpose(0, 2, 1)) * axis 35 | return f_rot 36 | 37 | 38 | def return_batch_align_f(axis, theta, n): 39 | # n is total number of atoms 40 | c_theta = np.cos(theta) 41 | s_theta = np.sin(theta) 42 | axis = np.repeat(axis, n, axis=1)[:, :, None] 43 | c_theta = np.repeat(c_theta, n, axis=1)[:, :, None, None] 44 | s_theta = np.repeat(s_theta, n, axis=1)[:, :, None, None] 45 | 46 | f_rot = lambda v: c_theta * v + s_theta * np.cross(axis, v, axis=-1) + (1 - c_theta) * bdot(axis, v.transpose(0, 1, 3, 2)) * axis 47 | return f_rot 48 | 49 | 50 | def get_batch_N_CA_C_align(normal, r, n): 51 | # get fn to align n to positive z_hat, via rotation about x axis (assume N-CA already along x_hat) 52 | # r is number of residues 53 | z = np.repeat(np.array([[0, 0, 1]]), r, 0)[:, None] 54 | theta = get_angle(normal, z) 55 | axis = get_unit_normal(normal, z) 56 | return return_align_f(axis, theta), return_batch_align_f(axis, theta, n=n) 57 | 58 | 59 | def get_batch_N_CA_align(v, r, n): 60 | # assuming ca is at (0,0,0), return fn to batch align CA--N to positive x axis 61 | # v = n - ca 62 | x = np.repeat(np.array([[1, 0, 0]])[None], r, 0) 63 | axis = get_unit_normal(v, x) 64 | theta = get_angle(v, x) 65 | return return_align_f(axis, theta), return_batch_align_f(axis, theta, n=n) 66 | 67 | 68 | def batch_canonicalize_coords(atom_coords, atom_data, residue_bb_index_list, res_idx=None, num_return=400, bb_only=0): 69 | """Function to batch-canonicalize atoms about all residues in a structure and mask out the residue of interest. 
70 | 71 | Args: 72 | atom_coords (np.array): num_atoms x 3 coordinates of all retained atoms in structure 73 | atom_data (np.array): num_atoms x 4 data for atoms -- [residue idx, BB ind, atom type, res type] 74 | residue_bb_index_list (np.array): num_res x 4 mapping from residue idx to atom indices for backbone atoms (N, CA, C, CB) used for canonicalization 75 | res_idx (np.array): num_output_res x 1 -- residue indices for subsampling residues ahead of canonicalization 76 | num_return (int): number of atoms to preserve about residue in environment 77 | Returns: 78 | x_coords (np.array): num_output_res x num_return x 1 x 3 -- canonicalized coordinates about each residue with center residue masked 79 | x_data (np.array): num_output_res x num_return x 1 x 4 -- metadata for canonicalized atoms for each environment 80 | """ 81 | 82 | n_atoms = atom_coords.shape[0] 83 | 84 | # subsampling residues to canonicalize 85 | if res_idx is not None: 86 | residue_bb_index_list = residue_bb_index_list[res_idx] 87 | n_res = len(res_idx) 88 | else: 89 | n_res = residue_bb_index_list.shape[0] 90 | 91 | num_return = min(num_return, n_atoms - 15) 92 | 93 | idx_N, idx_CA, idx_C, idx_CB = residue_bb_index_list[:, 0], residue_bb_index_list[:, 1], residue_bb_index_list[:, 2], residue_bb_index_list[:, 3] 94 | x = atom_coords.copy() 95 | 96 | center = x[idx_CA].copy() 97 | x_idxN, x_idxC, x_idxCA, x_idxCB = x[idx_N] - center, x[idx_C] - center, x[idx_CA] - center, x[idx_CB] - center 98 | x_data = atom_data.copy() 99 | 100 | x = np.repeat(x[None], n_res, axis=0) 101 | x_data = np.repeat(x_data[None], n_res, axis=0) 102 | 103 | # center coordinates at CA position 104 | x = x - center[:, None] 105 | 106 | # for each residue, eliminate side chain residue coordinates here -- 107 | bs, _, _, x_dim = x.shape 108 | x_data_dim = x_data.shape[-1] 109 | 110 | if res_idx is None: 111 | res_idx = np.arange(n_res) 112 | 113 | res_idx = np.tile(res_idx[:, None], (1, n_atoms)).reshape(-1) 114 | x = x.reshape(-1, x_dim) 115 | x_data = x_data.reshape(-1, x_data_dim) 116 | # get res_idx, indicator of bb atom 117 | x_res, x_bb, x_res_type = x_data[..., 0], x_data[..., 1], x_data[..., -1] 118 | assert len(x_res) == len(res_idx) 119 | 120 | if not bb_only: 121 | # exclude atoms on residue of interest that are not BB atoms 122 | exclude_idx = np.where((x_res == res_idx) & (x_bb != 1))[0] 123 | else: 124 | # exclude all side-chain atoms (bb only) 125 | exclude_idx = np.where((x_bb != 1))[0] 126 | 127 | # mask res type for all current residue atoms (no cheating!) 
128 | res_type_exclude_idx = np.where((x_res == res_idx))[0] 129 | x_res_type[res_type_exclude_idx] = 21 # set to idx higher than highest -- 130 | 131 | # move coordinates for non-include residues well out of frame of reference -- will be omitted in next step or voxelize 132 | x[exclude_idx] = x[exclude_idx] + np.array([-1000.0, -1000.0, -1000.0]) 133 | x = x.reshape(bs, n_atoms, x_dim)[:, :, None] 134 | 135 | x_data = x_data.reshape(bs, n_atoms, x_data_dim)[:, :, None] 136 | 137 | # select num_return nearest atoms to env center 138 | d_x_out = np.sqrt(np.sum(x ** 2, -1)) 139 | idx = np.argpartition(d_x_out, kth=num_return, axis=1) 140 | idx = idx[:, :num_return] 141 | 142 | x = np.take_along_axis(x, idx[..., None], axis=1) 143 | x_data = np.take_along_axis(x_data, idx[..., None], axis=1) 144 | 145 | n = num_return 146 | 147 | # align N-CA along x axis 148 | f_R, f_bR = get_batch_N_CA_align(x_idxN - x_idxCA, r=n_res, n=n) # um_return) 149 | x = f_bR(x) 150 | x_idxN, x_idxC, x_idxCA, x_idxCB = f_R(x_idxN), f_R(x_idxC), f_R(x_idxCA), f_R(x_idxCB) 151 | 152 | # rotate so that normal of N-CA-C plane aligns to positive z_hat 153 | normal = get_unit_normal(x_idxN, x_idxC) 154 | f_R, f_bR = get_batch_N_CA_C_align(normal, r=n_res, n=n) # um_return) 155 | x_idxN, x_idxC, x_idxCA, x_idxCB = f_R(x_idxN), f_R(x_idxC), f_R(x_idxCA), f_R(x_idxCB) 156 | x = f_bR(x) 157 | 158 | # recenter at CB 159 | fixed_CB = np.ones((x_idxCB.shape[0], 1, 3)) * gly_CB_mu 160 | x = x - fixed_CB[:, None] 161 | 162 | return x, x_data 163 | -------------------------------------------------------------------------------- /seq_des/util/data.py: -------------------------------------------------------------------------------- 1 | import Bio.PDB 2 | import Bio.PDB.vectors 3 | 4 | import torch 5 | from torch.utils import data 6 | import torch.nn.functional as F 7 | 8 | import json 9 | import numpy as np 10 | import os 11 | import re 12 | import glob 13 | 14 | import common.atoms 15 | import seq_des.util.canonicalize as canonicalize 16 | import seq_des.util.voxelize as voxelize 17 | 18 | 19 | CHI_BINS = np.linspace(-np.pi, np.pi, num=25) 20 | 21 | def read_domain_ids_per_chain_from_txt(txt_file): 22 | pdbs = [] 23 | ids_chains = {} 24 | with open(txt_file, 'r') as f: 25 | for line in f: 26 | line = line.strip('\n').split() 27 | pdbs.append(line[0][:4]) 28 | ids_chains[line[0][:4]] = [] 29 | with open(txt_file, 'r') as f: 30 | for line in f: 31 | line = line.strip('\n').split() 32 | if len(line) == 6: # no icodes 33 | line.extend([' ', ' ']) 34 | elif len(line) == 7: 35 | line.extend([' ']) 36 | pdb = line[0][:4] 37 | ids_chains[pdb].append(tuple(line)) #line[:4], line[4:])) 38 | return [(k, ids_chains[k]) for k in ids_chains.keys()] 39 | 40 | 41 | def map_to_bins(chi): 42 | # map rotamer angles to discretized bins 43 | binned_pwd = np.digitize(chi, CHI_BINS) 44 | if len(binned_pwd[binned_pwd == 0]) > 0: 45 | binned_pwd[binned_pwd == 0] = 1 # in case chi == -np.pi 46 | return binned_pwd 47 | 48 | 49 | def download_pdb(pdb, data_dir, assembly=1): 50 | 51 | """Function to download pdb -- either biological assembly or if that 52 | is not available/specified -- download default pdb structure 53 | Uses biological assembly as default, otherwise gets default pdb. 54 | 55 | Args: 56 | pdb (str): pdb ID. 
57 | data_dir (str): path to pdb directory 58 | 59 | Returns: 60 | f (str): path to downloaded pdb 61 | 62 | """ 63 | 64 | if assembly: 65 | f = data_dir + "/" + pdb + ".pdb1" 66 | if not os.path.isfile(f): 67 | try: 68 | os.system("wget -O {}.gz https://files.rcsb.org/download/{}.pdb1.gz".format(f, pdb.upper())) 69 | os.system("gunzip {}.gz".format(f)) 70 | 71 | except: 72 | f = data_dir + "/" + pdb + ".pdb" 73 | if not os.path.isfile(f): 74 | os.system("wget -O {} https://files.rcsb.org/download/{}.pdb".format(f, pdb.upper())) 75 | else: 76 | f = data_dir + "/" + pdb + ".pdb" 77 | 78 | if not os.path.isfile(f): 79 | os.system("wget -O {} https://files.rcsb.org/download/{}.pdb".format(f, pdb.upper())) 80 | 81 | return f 82 | 83 | 84 | def get_pdb_chains(pdb, data_dir, assembly=1, skip_download=0): 85 | 86 | """Function to load pdb structure via Biopython and extract all chains. 87 | Uses biological assembly as default, otherwise gets default pdb. 88 | 89 | Args: 90 | pdb (str): pdb ID. 91 | data_dir (str): path to pdb directory 92 | 93 | Returns: 94 | chains (list of (chain, chain_id)): all pdb chains 95 | 96 | """ 97 | if not skip_download: 98 | f = download_pdb(pdb, data_dir, assembly=assembly) 99 | 100 | if assembly: 101 | f = data_dir + "/" + pdb + ".pdb1" 102 | if not os.path.isfile(f): 103 | f = data_dir + "/" + pdb + ".pdb" 104 | else: 105 | f = data_dir + "/" + pdb + ".pdb" 106 | 107 | assert os.path.isfile(f) 108 | structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb, f) 109 | 110 | assert len(structure) > 0, pdb 111 | 112 | # for assemblies -- sometimes chains are represented as different structures 113 | if len(structure) > 1: 114 | model = structure[0] 115 | count = 0 116 | for i in range(len(structure)): 117 | for c in structure[i].get_chains(): 118 | try: 119 | c.id = common.atoms.rename_chains[count] 120 | except: 121 | continue 122 | count += 1 123 | try: 124 | model.add(c) 125 | except Bio.PDB.PDBExceptions.PDBConstructionException: 126 | continue 127 | else: 128 | model = structure[0] 129 | 130 | # special hard-coded case with very large assembly -- not necessary to train on all 131 | if "2y26" in pdb: 132 | return [(c, c.id) for c in model.get_chains() if c.id in ["B", "A", "E", "C", "D"]] 133 | 134 | return [(c, c.id) for c in model.get_chains()] 135 | 136 | 137 | def get_pdb_data(pdb, data_dir="", assembly=1, skip_download=0): 138 | 139 | """Function to get atom coordinates and atom/residue metadata from pdb structures. 
140 | 141 | Args: 142 | pdb (str): pdb ID 143 | data_dir (str): path to pdb directory 144 | assembly (int): 0/1 indicator of whether to use biological assembly or default pdb 145 | skip_download (int): 0/1 indicator of whether to skip attempt to download pdb from remote server 146 | 147 | Returns: 148 | atom_coords (np.array): num_atoms x 3 coordinates of all retained atoms in structure 149 | atom_data (np.array): num_atoms x 4 data for atoms -- [residue idx, BB ind, atom type, res type] 150 | residue_bb_index_list (np.array): num_res x 4 mapping from residue idx to atom indices for backbone atoms (N, CA, C, CB) used for canonicalization 151 | res_data (dict of list of lists): dictionary {chain ID: [ [residue ID, residue icode, residue index, residue type], ...]} 152 | res_label (np.array): num_res x 1 residue type labels (amino acid type) for all residues (to be included in training) 153 | 154 | """ 155 | 156 | # get pdb chain data 157 | pdb_chains = get_pdb_chains(pdb, data_dir, assembly=assembly, skip_download=skip_download) 158 | 159 | res_idx = 0 160 | res_data = {} 161 | atom_coords = [] 162 | atom_data = [] 163 | residue_bb_index = {} 164 | residue_bb_index_list = [] 165 | res_label = [] 166 | chis = [] 167 | # iterate over chains 168 | for pdb_chain, chain_id in pdb_chains: 169 | # iterate over residues 170 | res_data[chain_id] = [] 171 | for res in pdb_chain.get_residues(): 172 | skip_res = False # whether to skip training directly on this residue 173 | 174 | res_name = res.get_resname() 175 | het, res_id, res_icode = res.id 176 | 177 | # skip waters, metal ions, pre-specified ligands, unknown ligands 178 | if res_name in common.atoms.skip_res_list: 179 | continue 180 | 181 | res_atoms = [atom for atom in res.get_atoms()] 182 | 183 | # skip training on residues where all BB atoms are not present -- this will break canonicalization 184 | if res_name in common.atoms.res_label_dict.keys() and len(res_atoms) < 4: 185 | skip_res = True 186 | 187 | # if residue is an amino acid, add to label and save residue ID 188 | if (not skip_res) and (res_name in common.atoms.res_label_dict.keys()): 189 | res_type = common.atoms.res_label_dict[res_name] 190 | res_data[chain_id].append((res_id, res_icode, res_idx, res_type)) 191 | res_label.append(res_type) 192 | residue_bb_index[res_idx] = {} 193 | 194 | # iterate over atoms -- get coordinate data 195 | for atom in res.get_atoms(): 196 | 197 | if atom.element in common.atoms.skip_atoms: 198 | continue 199 | elif atom.element not in common.atoms.atoms: 200 | if res_name == "MSE" and atom.element == "SE": 201 | elem_name = "S" # swap MET for MSE 202 | else: 203 | elem_name = "other" # all other atoms are labeled 'other' 204 | else: 205 | elem_name = atom.element 206 | 207 | # get atomic coordinate 208 | c = np.array(list(atom.get_coord()))[None].astype(np.float32) 209 | 210 | # get atom type index 211 | assert elem_name in common.atoms.atoms 212 | atom_type = common.atoms.atoms.index(elem_name) 213 | 214 | # get whether atom is a BB atom 215 | bb = int(res_name in common.atoms.res_label_dict.keys() and atom.name in ["N", "CA", "C", "O", "OXT"]) 216 | 217 | if res_name in common.atoms.res_label_dict.keys(): 218 | res_type_idx = common.atoms.res_label_dict[res_name] 219 | else: 220 | res_type_idx = 20 # 'other' type (ligand, ion) 221 | 222 | # index -- residue idx, bb?, atom index, residue type (AA) 223 | index = np.array([res_idx, bb, atom_type, res_type_idx]) 224 | atom_coords.append(c) 225 | atom_data.append(index[None]) 226 | # if atom is BB atom, 
add to residue_bb_index dictionary 227 | if (not skip_res) and ((res_name in common.atoms.res_label_dict.keys())): 228 | # map from residue index to atom coordinate 229 | residue_bb_index[res_idx][atom.name] = len(atom_coords) - 1 230 | 231 | # get rotamer chi angles 232 | if (not skip_res) and (res_name in common.atoms.res_label_dict.keys()): 233 | if res_name == "GLY" or res_name == "ALA": 234 | chi = [0, 0, 0, 0] 235 | mask = [0, 0, 0, 0] 236 | 237 | else: 238 | chi = [] 239 | mask = [] 240 | if "N" in residue_bb_index[res_idx].keys() and "CA" in residue_bb_index[res_idx].keys(): 241 | n = Bio.PDB.vectors.Vector(list(atom_coords[residue_bb_index[res_idx]["N"]][0])) 242 | ca = Bio.PDB.vectors.Vector(list(atom_coords[residue_bb_index[res_idx]["CA"]][0])) 243 | if ( 244 | "chi_1" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys() 245 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_1"] in residue_bb_index[res_idx].keys() 246 | and "CB" in residue_bb_index[res_idx].keys() 247 | ): 248 | cb = Bio.PDB.vectors.Vector(list(atom_coords[residue_bb_index[res_idx]["CB"]][0])) 249 | cg = Bio.PDB.vectors.Vector( 250 | atom_coords[residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_1"]]][0] 251 | ) 252 | chi_1 = Bio.PDB.vectors.calc_dihedral(n, ca, cb, cg) 253 | chi.append(chi_1) 254 | mask.append(1) 255 | 256 | if ( 257 | "chi_2" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys() 258 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_2"] in residue_bb_index[res_idx].keys() 259 | ): 260 | cd = Bio.PDB.vectors.Vector( 261 | atom_coords[residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_2"]]][0] 262 | ) 263 | chi_2 = Bio.PDB.vectors.calc_dihedral(ca, cb, cg, cd) 264 | chi.append(chi_2) 265 | mask.append(1) 266 | 267 | if ( 268 | "chi_3" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys() 269 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_3"] in residue_bb_index[res_idx].keys() 270 | ): 271 | ce = Bio.PDB.vectors.Vector( 272 | atom_coords[residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_3"]]][ 273 | 0 274 | ] 275 | ) 276 | chi_3 = Bio.PDB.vectors.calc_dihedral(cb, cg, cd, ce) 277 | chi.append(chi_3) 278 | mask.append(1) 279 | 280 | if ( 281 | "chi_4" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys() 282 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_4"] in residue_bb_index[res_idx].keys() 283 | ): 284 | cz = Bio.PDB.vectors.Vector( 285 | atom_coords[ 286 | residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_4"]] 287 | ][0] 288 | ) 289 | chi_4 = Bio.PDB.vectors.calc_dihedral(cg, cd, ce, cz) 290 | chi.append(chi_4) 291 | mask.append(1) 292 | else: 293 | chi.append(0) 294 | mask.append(0) 295 | else: 296 | chi.extend([0, 0]) 297 | mask.extend([0, 0]) 298 | 299 | else: 300 | chi.extend([0, 0, 0]) 301 | mask.extend([0, 0, 0]) 302 | else: 303 | chi = [0, 0, 0, 0] 304 | mask = [0, 0, 0, 0] 305 | else: 306 | chi = [0, 0, 0, 0] 307 | mask = [0, 0, 0, 0] 308 | chi = np.array(chi) 309 | mask = np.array(mask) 310 | chis.append(np.concatenate([chi[None], mask[None]], axis=0)) 311 | 312 | # add bb atom indices in residue_list to residue_bb_index dict 313 | if (not skip_res) and res_name in common.atoms.res_label_dict.keys(): 314 | residue_bb_index[res_idx]["list"] = [] 315 | for atom in 
["N", "CA", "C", "CB"]: 316 | if atom in residue_bb_index[res_idx]: 317 | residue_bb_index[res_idx]["list"].append(residue_bb_index[res_idx][atom]) 318 | else: 319 | # GLY handling for CB 320 | residue_bb_index[res_idx]["list"].append(-1) 321 | 322 | residue_bb_index_list.append(residue_bb_index[res_idx]["list"]) 323 | if not skip_res and (res_name in common.atoms.res_label_dict.keys()): 324 | res_idx += 1 325 | 326 | assert len(atom_coords) == len(atom_data) 327 | assert len(residue_bb_index_list) == len(res_label) 328 | assert len(chis) == len(residue_bb_index_list) 329 | 330 | return np.array(atom_coords), np.array(atom_data), np.array(residue_bb_index_list), res_data, np.array(res_label), np.array(chis) 331 | 332 | 333 | 334 | def get_domain_envs(pdb_id, domains_list, pdb_dir="/data/drive2tb/protein/pdb", num_return=400, bb_only=0): 335 | """ Get domain specific residues and local environments by first getting full biological assembly for 336 | pdb of interest -- selecting domain specific residues. 337 | 338 | Args: 339 | pdb_id (str): pdb structure ID 340 | domains_list (list of list of tuples of str): for each domain within pdb of interest -- list of domain start, stop residue IDs and icodes 341 | 342 | Returns: 343 | atom_coords_canonicalized (np.array): n_res x n_atoms x 3 array with canonicalized local 344 | environment atom coordinates 345 | atom_data_canonicalized (np.array): n_res x n_atoms x 4 with metadata for local env atoms 346 | [residue idx, BB ind, atom type, res type] 347 | res_data (dict of list of lists): dictionary with residue metadata -- {chain ID: [ [residue ID, residue icode, residue index, residue type], ...]} 348 | res_label (np.array): num_res x 1 residue type labels (amino acid type) for all residues (to be included in training) 349 | 350 | """ 351 | 352 | atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis = get_pdb_data(pdb_id, data_dir=pdb_dir) 353 | atom_coords_def = None 354 | 355 | assert len(res_label) > 0 356 | 357 | ind_assembly = [] 358 | res_idx_list_domains = [] 359 | # iterate over domains for PDB of interest 360 | for domain_split in domains_list: 361 | domain_id = domain_split[0] 362 | domain_split = domain_split[-1] 363 | chain_id, domains = get_domain(domain_split) 364 | res_idx_list = [] 365 | # iterate over start/end cutpoints for domain 366 | if chain_id in res_data.keys(): 367 | ind_assembly.append(1) 368 | else: 369 | if atom_coords_def is None: 370 | atom_coords_def, atom_data_def, residue_bb_index_list_def, res_data_def, res_label_def, chis_def = get_pdb_data(pdb_id, data_dir=pdb_dir, assembly=0) 371 | 372 | if chain_id in res_data_def.keys(): 373 | ind_assembly.append(0) 374 | if atom_coords_def is None: 375 | atom_coords_def, atom_data_def, residue_bb_index_list_def, res_data_def, res_label_def, chis_def = get_pdb_data(pdb_id, data_dir=pdb_dir, assembly=0) 376 | else: 377 | print("chain not found", chain_id, res_data.keys(), res_data_def.keys()) 378 | continue 379 | for ds, de in domains: 380 | start = False 381 | end = False 382 | if chain_id not in res_data.keys(): 383 | for res_id, res_icode, res_idx, res_type in res_data_def[chain_id]: 384 | assert res_idx < len(res_label_def) 385 | if (res_id != ds) and not start: 386 | continue 387 | elif res_id == ds: 388 | start = True 389 | if res_id == de: 390 | end = True 391 | if start and not end: 392 | res_idx_list.append(res_idx) 393 | if end: 394 | break 395 | else: 396 | # parse chain_res_data to get res_idx for domain of interest 397 | for res_id, res_icode, 
res_idx, res_type in res_data[chain_id]: 398 | assert res_idx < len(res_label) 399 | if (res_id != ds) and not start: 400 | continue 401 | elif res_id == ds: 402 | start = True 403 | if res_id == de: 404 | end = True 405 | if start and not end: 406 | res_idx_list.append(res_idx) 407 | if end: 408 | break 409 | res_idx_list_domains.append(res_idx_list) 410 | 411 | assert len(res_idx_list_domains) == len(ind_assembly) 412 | 413 | atom_coords_out = [] 414 | atom_data_out = [] 415 | res_label_out = [] 416 | domain_ids_out = [] 417 | chis_out = [] 418 | 419 | for i in range(len(res_idx_list_domains)): 420 | # canonicalize -- subset of residues 421 | if len(res_idx_list_domains[i]) == 0: 422 | continue 423 | if ind_assembly[i] == 1: 424 | # pull data from biological assembly 425 | atom_coords_canonicalized, atom_data_canonicalized = canonicalize.batch_canonicalize_coords(atom_coords, atom_data, residue_bb_index_list, res_idx=np.array(res_idx_list_domains[i]), num_return=num_return, bb_only=bb_only) 426 | else: 427 | # pull data from default structure 428 | atom_coords_canonicalized, atom_data_canonicalized = canonicalize.batch_canonicalize_coords( 429 | atom_coords_def, atom_data_def, residue_bb_index_list_def, res_idx=np.array(res_idx_list_domains[i]), num_return=num_return, bb_only=bb_only 430 | ) 431 | 432 | atom_coords_out.append(atom_coords_canonicalized) 433 | atom_data_out.append(atom_data_canonicalized) 434 | if ind_assembly[i] == 1: 435 | res_label_out.append(res_label[res_idx_list_domains[i]]) 436 | assert len(atom_coords_canonicalized) == len(res_label[res_idx_list_domains[i]]) 437 | chis_out.append(chis[res_idx_list_domains[i]]) 438 | else: 439 | res_label_out.append(res_label_def[res_idx_list_domains[i]]) 440 | assert len(atom_coords_canonicalized) == len(res_label_def[res_idx_list_domains[i]]) 441 | chis_out.append(chis_def[res_idx_list_domains[i]]) 442 | domain_ids_out.append(domains_list[i][0]) 443 | 444 | return atom_coords_out, atom_data_out, res_label_out, domain_ids_out, chis_out 445 | 446 | 447 | def get_domain(domain_split): 448 | # function to parse CATH domain info from txt -- returns chain and domain residue IDs 449 | chain = domain_split[-1] 450 | 451 | domains = domain_split.split(",") 452 | domains = [d[: d.rfind(":")] for d in domains] 453 | 454 | domains = [(d[: d.rfind("-")], d[d.rfind("-") + 1 :]) for d in domains] 455 | domains = [(int(re.findall("\D*\d+", ds)[0]), int(re.findall("\D*\d+", de)[0])) for ds, de in domains] 456 | 457 | return chain, np.array(domains) 458 | 459 | 460 | 461 | class PDB_domain_spitter(data.Dataset): 462 | def __init__(self, txt_file="data/052320_cath-b-newest-all.txt", pdb_path="/data/drive2tb/protein/pdb", num_return=400, bb_only=0): 463 | self.domains = read_domain_ids_per_chain_from_txt(txt_file) 464 | self.pdb_path = pdb_path 465 | self.num_return = num_return 466 | self.bb_only = bb_only 467 | 468 | def __len__(self): 469 | return len(self.domains) 470 | 471 | def __getitem__(self, index): 472 | pdb_id, domain_list = self.domains[index] 473 | return self.get_data(pdb_id, domain_list) 474 | 475 | def get_and_download_pdb(self, index): 476 | pdb_id, domain_list = self.domains[index] 477 | f = download_pdb(pdb_id, data_dir=self.pdb_path) 478 | return f 479 | 480 | def get_data(self, pdb, domain_list): 481 | try: 482 | atom_coords, atom_data, res_label, domain_id, chis = get_domain_envs(pdb, domain_list, pdb_dir=self.pdb_path, num_return=self.num_return, bb_only=self.bb_only) 483 | return atom_coords, atom_data, res_label, domain_id, 
chis 484 | except: 485 | return [] 486 | 487 | 488 | class PDB_data_spitter(data.Dataset): 489 | def __init__(self, data_dir="/data/simdev_2tb/protein/sequence_design/data/coords/test_s95_chi/", n=20, dist=10, datalen=1000): 490 | self.files = glob.glob("%s/data*pt" % (data_dir)) 491 | self.cached_pt = -1 492 | self.chunk_size = 10000 # args.chunk_size #50000i #NOTE -- CAUTION 493 | self.datalen = datalen 494 | self.data_dir = data_dir 495 | self.n = n 496 | self.dist = dist 497 | self.c = len(common.atoms.atoms) 498 | self.len = 0 499 | 500 | def __len__(self): 501 | if self.len == 0: 502 | return len(self.files) * self.chunk_size 503 | else: 504 | return self.len 505 | 506 | def get_data(self, index): 507 | if self.cached_pt != index // self.chunk_size: 508 | self.cached_pt = int(index // self.chunk_size) 509 | self.xs, self.x_data, self.ys, self.domain_ids, self.chis = torch.load("%s/data_%0.4d.pt" % (self.data_dir, self.cached_pt)) 510 | 511 | index = index % self.chunk_size 512 | x, x_data, y, domain_id, chis = self.xs[index], self.x_data[index], self.ys[index], self.domain_ids[index], self.chis[index] 513 | return x, x_data, y, domain_id, chis 514 | 515 | def __getitem__(self, index): # index): 516 | x, x_data, y, domain_id, chis = self.get_data(index) 517 | ## voxelize coordinates and atom metadata 518 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type = voxelize.get_voxel_idx(x[None], x_data[None], n=self.n, c=self.c, dist=self.dist) 519 | # map chi angles to bins 520 | chi_angles = chis[0] 521 | chi_mask = chis[1] 522 | chi_angles_binned = map_to_bins(chi_angles) 523 | chi_angles_binned[chi_mask == 0] = 0 # ignore index 524 | 525 | # return domain_id, x, x_data, y, chi_angles, chi_angles_binned 526 | return bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles, chi_angles_binned 527 | 528 | 529 | def collate_wrapper(data, crop=True): 530 | max_n = 0 531 | for i in range(len(data)): 532 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles, chi_angles_binned = data[i][0], data[i][1], data[i][2], data[i][3], data[i][4], data[i][5], data[i][6], data[i][7], data[i][8], data[i][9] 533 | # print(bs_idx.shape, x_atom.shape, x_bb.shape, x_b.shape, y_b.shape, z_b.shape, x_res_type.shape)# if pwd is greater than CROP_SIZE -- random crop 534 | n_i = x_atom.shape[-1] 535 | # print(n_i, min_n) 536 | if n_i > max_n: 537 | max_n = n_i 538 | 539 | # pad pwd data, coords 540 | out_bs_idx = [] 541 | out_y = [] 542 | out_x_atom = [] 543 | out_x_bb = [] 544 | out_x_b = [] 545 | out_y_b = [] 546 | out_z_b = [] 547 | out_x_res_type = [] 548 | out_chi_angles = [] 549 | out_chi_angles_binned = [] 550 | padding = False 551 | for i in range(len(data)): 552 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles, chi_angles_binned = data[i][0], data[i][1], data[i][2], data[i][3], data[i][4], data[i][5], data[i][6], data[i][7], data[i][8], data[i][9] 553 | n_i = x_atom.shape[-1] 554 | 555 | if n_i < max_n: 556 | padding = True 557 | # zero pad all --> x, y, z indexing will be omitted 558 | x_atom = np.pad(x_atom, ((0, max_n - n_i)), mode='constant') 559 | x_b = np.pad(x_b, ((0, max_n - n_i)), mode='constant') 560 | y_b = np.pad(y_b, ((0, max_n - n_i)), mode='constant') 561 | z_b = np.pad(z_b, ((0, max_n - n_i)), mode='constant') 562 | x_bb = np.pad(x_bb, ((0, max_n - n_i)), mode='constant') 563 | x_res_type = np.pad(x_res_type, ((0, max_n - n_i)), mode='constant') 564 | 565 | # handle batch indexing correctly 566 | out_bs_idx.append(torch.Tensor([i for j in 
range(len(x_b))])[None]) 567 | out_y.append(torch.Tensor([y])) # [None]) 568 | out_x_atom.append(torch.Tensor(x_atom)[None]) 569 | out_x_bb.append(torch.Tensor(x_bb)[None]) 570 | out_x_b.append(torch.Tensor(x_b)[None]) 571 | out_y_b.append(torch.Tensor(y_b)[None]) 572 | out_z_b.append(torch.Tensor(z_b)[None]) 573 | out_x_res_type.append(torch.Tensor(x_res_type)[None]) 574 | out_chi_angles.append(torch.Tensor(chi_angles)[None]) 575 | out_chi_angles_binned.append(torch.Tensor(chi_angles_binned)[None]) 576 | 577 | out_bs_idx = torch.cat(out_bs_idx, 0) 578 | out_y = torch.cat(out_y, 0) 579 | out_x_atom = torch.cat(out_x_atom, 0) 580 | out_x_bb = torch.cat(out_x_bb, 0) 581 | out_x_b = torch.cat(out_x_b, 0) 582 | out_y_b = torch.cat(out_y_b, 0) 583 | out_z_b = torch.cat(out_z_b, 0) 584 | out_x_res_type = torch.cat(out_x_res_type, 0) 585 | out_chi_angles = torch.cat(out_chi_angles, 0) 586 | out_chi_angles_binned = torch.cat(out_chi_angles_binned, 0) 587 | return out_bs_idx.long(), out_x_atom.long(), out_x_bb.long(), out_x_b.long(), out_y_b.long(), out_z_b.long(), out_x_res_type.long(), out_y.long(), out_chi_angles, out_chi_angles_binned.long() 588 | 589 | 590 | -------------------------------------------------------------------------------- /seq_des/util/pyrosetta_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import common.atoms 3 | 4 | from rosetta import * 5 | from pyrosetta import * 6 | init("-mute basic -mute core -mute protocols -ex1 -ex2 -constant_seed") 7 | 8 | #from pyrosetta.toolbox import pose_from_rcsb, cleanATOM # , mutate_residue 9 | from pyrosetta.rosetta.protocols.simple_moves import MutateResidue 10 | 11 | from pyrosetta.rosetta.core import conformation 12 | from pyrosetta.rosetta.core import chemical 13 | from pyrosetta.rosetta.protocols.minimization_packing import PackRotamersMover 14 | 15 | score_manager = pyrosetta.rosetta.core.scoring.ScoreTypeManager() 16 | scorefxn = get_fa_scorefxn() 17 | from pyrosetta.rosetta.core.chemical import aa_from_oneletter_code 18 | 19 | 20 | def get_seq_delta(s1, s2): 21 | count = 0 22 | for i in range(len(s1)): 23 | if s1[i] != s2[i]: 24 | count += 1 25 | return count 26 | 27 | def score_pose(pose): 28 | return scorefxn(pose) 29 | 30 | def randomize_sequence(new_seq, pose, pack_radius=5.0, fixed_idx=[], var_idx=[], ala=False, val=False, resfile_init_seq=False, enforce=False, repack_rotamers=0): 31 | for idx in range(pose.residues.__len__()): 32 | # do not mutate fixed indices / only mutate var indices 33 | if idx in fixed_idx: 34 | continue 35 | elif len(var_idx) > 0 and idx not in var_idx: 36 | continue 37 | 38 | res = pose.residue(idx + 1) 39 | ref_res_name = res.name() 40 | 41 | if ":" in ref_res_name: 42 | ref_res_name = ref_res_name[: ref_res_name.find(":")] 43 | if "_" in ref_res_name: 44 | ref_res_name = ref_res_name[: ref_res_name.find("_")] 45 | 46 | if ref_res_name not in common.atoms.res_label_dict.keys(): 47 | continue 48 | 49 | if ala: 50 | r = common.atoms.res_label_dict["ALA"] 51 | elif val: 52 | r = common.atoms.res_label_dict["VAL"] 53 | else: 54 | r = new_seq[idx] 55 | 56 | res_aa = common.atoms.aa_map[r] 57 | 58 | # resfile handling: ex. 5 TPIKAA C means set the initial sequence at residue 5 to 'C'
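# e.g. (hypothetical values, for illustration): a parsed resfile_init_seq of {4: "C"}
# would force residue 5 (0-indexed 4) to start as cysteine, overriding new_seq at that position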
59 | if resfile_init_seq and idx in resfile_init_seq: 60 | res_aa = resfile_init_seq[idx] 61 | 62 | pose = handle_disulfide(pose, idx) 63 | mutate_residue(pose, idx + 1, res_aa, pack_radius=pack_radius, repack_rotamers=repack_rotamers) 64 | 65 | return pose, pose.residues.__len__() 66 | 67 | 68 | # from https://github.com/barricklab/mutant-protein-stability/blob/master/PyRosetta_TACC_MPI.py 69 | def handle_disulfide(pose, idx): 70 | res = pose.residue(idx + 1) 71 | if (res.name() == "CYS:disulfide") or (res.name() == "CYD"): 72 | disulfide_partner = None 73 | try: 74 | disulfide_partner = res.residue_connection_partner(res.n_residue_connections()) 75 | except AttributeError: 76 | disulfide_partner = res.residue_connection_partner(res.n_current_residue_connections()) 77 | temp_pose = pyrosetta.Pose() 78 | temp_pose.assign(pose) 79 | # (Packing causes seg fault if current CYS residue is not 80 | # also converted before mutating.) 81 | conformation.change_cys_state(idx + 1, "CYS", temp_pose.conformation()) 82 | conformation.change_cys_state(disulfide_partner, "CYS", temp_pose.conformation()) 83 | pose = temp_pose 84 | return pose 85 | 86 | 87 | def mutate(pose, idx, res, pack_radius=5.0, fixed_idx=[], var_idx=[], repack_rotamers=0): 88 | if idx in fixed_idx: 89 | return pose 90 | elif len(var_idx) > 0 and idx not in var_idx: 91 | return pose 92 | pose = handle_disulfide(pose, idx) 93 | pose = mutate_residue(pose, idx + 1, res, pack_radius=pack_radius, repack_rotamers=repack_rotamers) 94 | return pose 95 | 96 | 97 | def mutate_list(pose, idx_list, res_list, pack_radius=5.0, fixed_idx=[], var_idx=[], repack_rotamers=0): 98 | assert len(idx_list) == len(res_list), (len(idx_list), len(res_list)) 99 | for i in range(len(idx_list)): 100 | idx, res = idx_list[i], res_list[i] 101 | if len(fixed_idx) > 0 and idx in fixed_idx: 102 | continue 103 | if len(var_idx) > 0 and idx not in var_idx: 104 | continue 105 | sequence = pose.sequence() 106 | pose = mutate(pose, idx, res, pack_radius=pack_radius, fixed_idx=fixed_idx, var_idx=var_idx, repack_rotamers=repack_rotamers) 107 | new_sequence = pose.sequence() 108 | assert get_seq_delta(sequence, new_sequence) <= 1, get_seq_delta(sequence, new_sequence) 109 | assert res == pose.sequence()[idx], (res, pose.sequence()[idx]) 110 | return pose 111 | 112 | 113 | def get_pose(pdb): 114 | return pose_from_pdb(pdb) 115 | 116 | 117 | # from PyRosetta toolbox 118 | def restrict_non_nbrs_from_repacking(pose, res, task, pack_radius, repack_rotamers=0): 119 | """Configure a `PackerTask` to only repack neighboring residues and 120 | return the task. 121 | 122 | Args: 123 | pose (pyrosetta.Pose): The `Pose` to operate on. 124 | res (int): Pose-numbered residue position to exclude. 125 | task (pyrosetta.rosetta.core.pack.task.PackerTask): `PackerTask` to modify. 126 | pack_radius (float): Radius used to define neighboring residues. repack_rotamers (int): if nonzero, allow neighbors within pack_radius to repack; otherwise they are kept fixed. 127 | 128 | Returns: 129 | pyrosetta.rosetta.core.pack.task.PackerTask: Configured `PackerTask`.
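Example (a minimal usage sketch, assuming an initialized PyRosetta pose; the residue index and radius are illustrative):
    task = pyrosetta.standard_packer_task(pose)
    task = restrict_non_nbrs_from_repacking(pose, 10, task, pack_radius=5.0, repack_rotamers=1)
    # residue 10 is skipped, neighbors within 5 A of its nbr_atom may repack,
    # and every other residue is prevented from repacking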
130 | """ 131 | 132 | if not repack_rotamers: 133 | assert pack_radius == 0, "pack radius must be 0 if you don't want to repack rotamers" 134 | 135 | def representative_coordinate(resNo): 136 | return pose.residue(resNo).xyz(pose.residue(resNo).nbr_atom()) 137 | 138 | center = representative_coordinate(res) 139 | for i in range(1, len(pose.residues) + 1): 140 | # only pack the mutating residue and any within the pack_radius 141 | if i == res: 142 | # comment out this block to reproduce biorxiv results 143 | #if not repack_rotamers: 144 | # task.nonconst_residue_task(i).prevent_repacking() 145 | continue 146 | if center.distance(representative_coordinate(i)) > pack_radius: 147 | task.nonconst_residue_task(i).prevent_repacking() 148 | else: 149 | if repack_rotamers: 150 | task.nonconst_residue_task(i).restrict_to_repacking() 151 | else: 152 | task.nonconst_residue_task(i).prevent_repacking() 153 | 154 | return task 155 | 156 | 157 | # modified from PyRosetta toolbox 158 | def mutate_residue(pose, mutant_position, mutant_aa, pack_radius=0.0, pack_scorefxn=None, repack_rotamers=0): 159 | """Replace the residue at a single position in a Pose with a new amino acid 160 | and repack any residues within a user-defined radius of the selected residue's 161 | center. 162 | 163 | Args: 164 | pose (pyrosetta.rosetta.core.pose.Pose): The `Pose` to operate on. 165 | mutant_position (int): Pose-numbered position of the residue to mutate. 166 | mutant_aa (str): The single letter name for the desired amino acid. 167 | pack_radius (float): Radius used to define neighboring residues. 168 | pack_scorefxn (pyrosetta.ScoreFunction): `ScoreFunction` to use when repacking the `Pose`. 169 | Defaults to the standard `ScoreFunction`. 170 | """ 171 | 172 | wpose = pose 173 | 174 | if not wpose.is_fullatom(): 175 | raise IOError("mutate_residue only works with fullatom poses") 176 | 177 | # create a standard scorefxn by default 178 | if not pack_scorefxn: 179 | pack_scorefxn = pyrosetta.get_score_function() 180 | 181 | # forces mutation 182 | mut = MutateResidue(mutant_position, common.atoms.aa_inv[mutant_aa]) 183 | mut.apply(wpose) 184 | 185 | # the numbers 1-20 correspond individually to the 20 proteinogenic amino acids 186 | mutant_aa = int(aa_from_oneletter_code(mutant_aa)) 187 | aa_bool = pyrosetta.Vector1([aa == mutant_aa for aa in range(1, 21)]) 188 | # mutation is performed by using a PackerTask with only the mutant 189 | # amino acid available during design 190 | 191 | task = pyrosetta.standard_packer_task(wpose) 192 | task.nonconst_residue_task(mutant_position).restrict_absent_canonical_aas(aa_bool) 193 | 194 | # prevent residues from packing by setting the per-residue "options" of the PackerTask 195 | task = restrict_non_nbrs_from_repacking(wpose, mutant_position, task, pack_radius, repack_rotamers=repack_rotamers) 196 | 197 | # apply the mutation and pack nearby residues 198 | 199 | packer = PackRotamersMover(pack_scorefxn, task) 200 | packer.apply(wpose) 201 | # return pack_or_pose 202 | return wpose 203 | -------------------------------------------------------------------------------- /seq_des/util/resfile_util.py: -------------------------------------------------------------------------------- 1 | # developed by Damir Temir | github.com/dtemir | as part of the RosettaCommons Summer Internship 2 | 3 | import common 4 | import re 5 | 6 | def read_resfile(filename): 7 | """ 8 | read a resfile and return a dictionary of constraints for each residue id 9 | 10 | the constraints object is a dictionary whose keys are residue ids and whose values
are the amino acids to restrict 11 | (residue ids from the resfile are decremented by 1 because PDB numbering starts at 1, 12 | while tensor indexing starts at 0) 13 | 14 | example: 15 | 65 ALLAA # allow all amino acids at residue id 65 (64 in the tensor) 16 | 54 ALLAAxc # allow all amino acids except cysteine at residue id 54 (53 in the tensor) 17 | 30 POLAR # allow only polar amino acids at residue id 30 (29 in the tensor) 18 | 31 - 33 NOTAA CFYG # disallow the specified amino acids at residue ids 31 to 33 (30 to 32 in the tensor) 19 | 43 TPIKAA C # allow only cysteine when initializing the sequence (same logic for TNOTAA) 20 | 21 | this results in a dictionary of restricted amino acids: 22 | {64: {}, 53: {'C'}, 29: {'A', 'C', 'F', 'G', 'I', 'L', 'M', 'P', 'V', 'W', 'Y'}, 23 | 30: {'C', 'F', 'Y', 'G'}, 31: {'C', 'F', 'Y', 'G'}, 32: {'C', 'F', 'Y', 'G'}} 24 | 25 | it also returns a header from check_for_header(): 26 | {"DEFAULT": {}} 27 | 28 | and a dictionary with the amino acids for the initial sequence (NOTE: for TNOTAA, the listed amino acids will NOT be used to initialize the sequence) 29 | {42: 'C'} 30 | """ 31 | def place_constraints(constraint, init_seq): 32 | """ 33 | places the constraints in the appropriate dicts 34 | -initial_seq for building the initial sequence with TPIKAA and TNOTAA 35 | -constraints for restricting the conditional model with PIKAA, NOTAA, ALLAA, POLAR, etc. 36 | """ 37 | if not init_seq: 38 | constraints[res_id] = constraint 39 | else: 40 | initial_seq[res_id] = constraint 41 | 42 | constraints = dict() # amino acids to restrict in the design 43 | header, start_id = check_for_header(filename) # amino acids to use as default for those not specified in constraints 44 | initial_seq = dict() # amino acids to use when initializing the sequence 45 | 46 | with open(filename, "r") as f: 47 | # iterate over the lines and extract arguments (residue id, command) 48 | lines = f.readlines() 49 | for line in lines[start_id + 1:]: 50 | args = [arg.strip() for arg in line.split(" ")] 51 | is_integer(args[0]) # the res id needs to be an integer 52 | assert isinstance(args[1], str), "the resfile command needs to be a string" 53 | 54 | res_id = int(args[0]) - 1 55 | if args[1] == "-": # if given a range of residue ids (ex. 31 - 33 NOTAA) 56 | is_integer(args[2]) # the end of the range also needs to be an integer 57 | for res_id in range(res_id, int(args[2])): 58 | constraint, init_seq = check_for_commands(args, 3, 4) 59 | place_constraints(constraint, init_seq) 60 | else: # if not given a range (ex. 31 NOTAA CA) 61 | constraint, init_seq = check_for_commands(args, 1, 2) 62 | place_constraints(constraint, init_seq) 63 | 64 | # update the initial seq dictionary to only have one element per residue id (arbitrary choice via set.pop) 65 | initial_seq = {res_id : (common.atoms.resfile_commands["ALLAAwc"] - restricted_aa).pop() for res_id, restricted_aa in initial_seq.items()} 66 | 67 | return constraints, header, initial_seq 68 | 69 | def check_for_header(filename): 70 | """ 71 | read a resfile and return the header if present 72 | 73 | the header holds commands that should be applied by default 74 | to all residues that are not specified after the 'start' keyword 75 | 76 | example of a header: 77 | ALLAA # allows all amino acids for residues that are not specified in the body 78 | START # divides the body and header 79 | # ...
the body starts here, see read_resfile() 80 | """ 81 | header = {} 82 | start_id = -1 83 | with open(filename, "r") as f: 84 | start = re.compile(r"\b(?:START|start)\b") 85 | # if the file has the keyword start, extract header 86 | if bool(start.search(f.read())): 87 | f.seek(0) # set the cursor back to the beginning 88 | lines = f.readlines() 89 | for i, line in enumerate(lines): 90 | if start.match(line): 91 | start_id = i # the line number where start is used (divides header and body) 92 | break 93 | args = line.split() 94 | args.insert(0, "") # check_for_commands only handles the second argument (first is usually res_id) 95 | header['DEFAULT'] = check_for_commands(args, 1, 2) 96 | 97 | return header, start_id 98 | 99 | 100 | def check_for_commands(args, command_id, list_id): 101 | """ 102 | converts given commands into sets of amino acids to restrict in the logits 103 | 104 | so far, it handles these commands: ALLAA, ALLAAxc, POLAR, APOLAR, NOTAA, PIKAA 105 | 106 | command_id - the index where the command is within the args 107 | list_id - the index where the possible list of AA is within the args (only for NOTAA and PIKAA) 108 | """ 109 | constraint = set() 110 | command = args[command_id].upper() 111 | init_seq = False # reflects whether the command is TPIKAA or TNOTAA 112 | if command in common.atoms.resfile_commands.keys(): 113 | constraint = common.atoms.resfile_commands["ALLAAwc"] - common.atoms.resfile_commands[command] 114 | elif "PIKAA" in command: # allow only the specified amino acids 115 | constraint = common.atoms.resfile_commands["ALLAAwc"] - set(args[list_id].strip()) 116 | elif "NOTAA" in command: # disallow only the specified amino acids 117 | constraint = set(args[list_id].strip()) 118 | 119 | if command == "TPIKAA" or command == "TNOTAA": 120 | init_seq = True 121 | 122 | return constraint, init_seq 123 | 124 | def get_natro(filename): 125 | """ 126 | provides a list of indices whose input rotamers and identities need to be preserved (Native Rotamer - NATRO) 127 | 128 | overrides sampler.py's self.fixed_idx attribute with a list of the NATRO residues to be skipped in the 129 | self.get_blocks() function that picks sampling blocks 130 | 131 | if ALL residues in the resfile are NATRO, sampler.py's self.step() skips running the neural network for 132 | amino acid prediction AND rotamer prediction 133 | """ 134 | fixed_idx = set() 135 | with open(filename, "r") as f: 136 | lines = f.readlines() 137 | for line in lines: 138 | args = [arg.strip().upper() for arg in line.split(" ")] 139 | if "NATRO" in args: 140 | is_integer(args[0]) 141 | if args[1] == "-": # provided a range of NATRO residues 142 | is_integer(args[2]) 143 | fixed_idx.update(range(int(args[0]) - 1, int(args[2]))) 144 | else: # provided a single NATRO residue 145 | fixed_idx.add(int(args[0]) - 1) 146 | 147 | return list(fixed_idx) 148 | 149 | def is_integer(n): 150 | try: 151 | int(n) 152 | except ValueError: 153 | raise ValueError("Incorrect residue index in the resfile: %s" % n) 154 | -------------------------------------------------------------------------------- /seq_des/util/sampler_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import seq_des.util.data as data 5 | import seq_des.util.canonicalize as canonicalize 6 | import seq_des.util.voxelize as voxelize 7 | import common.atoms 8 | 9 | import torch.nn.functional as F 10 | from torch.distributions.categorical import Categorical 11 | 12 | 13 | def get_idx(filename): 14 | # 
get variable or fixed indices from list 15 | with open(filename, "r") as f: 16 | lines = list(f) 17 | idx = [int(line.strip("\n").split()[0]) for line in lines] 18 | return idx 19 | 20 | 21 | def get_CB_distance(x, x_data): 22 | # get CB-CB pairwise distances 23 | A = [] 24 | for k in range(x_data.shape[0]): 25 | idx_CA, idx_CB = x_data[k, 1], x_data[k, -1] 26 | if idx_CB >= 0: 27 | A.append(x[idx_CB]) 28 | else: 29 | A.append(x[idx_CA]) 30 | A = np.array(A)[:, 0, :3] 31 | D = np.sqrt(np.sum((A[:, None].repeat(len(A), axis=1) - A[None].repeat(len(A), axis=0)) ** 2, -1)) 32 | return D 33 | 34 | 35 | def get_graph_from_D(D, threshold): 36 | A = np.zeros_like(D) 37 | A[D < threshold] = 1 38 | return A 39 | 40 | 41 | def make_onehot(bs, dim, scatter_tensor, use_cuda=1): 42 | onehot = torch.FloatTensor(bs, dim) 43 | onehot.zero_() 44 | onehot.scatter_(1, scatter_tensor, 1) 45 | if use_cuda: 46 | return onehot.cuda() 47 | else: 48 | return onehot 49 | 50 | 51 | def get_energy_from_logits(logits, res_idx, mask=None, baseline=0): 52 | # get negative log prob from logits 53 | log_p = -F.log_softmax(logits, -1).gather(1, res_idx[:, None]) 54 | if mask is not None: 55 | log_p[mask == 1] = baseline 56 | log_p_mean = log_p.mean() 57 | return log_p, log_p_mean 58 | 59 | 60 | def get_conv_feat( 61 | curr_models, atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis, bb_only=0, return_chi=0, use_cuda=1 62 | ): 63 | atom_coords_canonicalized, atom_data_canonicalized = canonicalize.batch_canonicalize_coords( 64 | atom_coords, atom_data, residue_bb_index_list, bb_only=bb_only 65 | ) 66 | 67 | x = atom_coords_canonicalized 68 | y = res_label 69 | x_data = atom_data_canonicalized 70 | 71 | voxels = voxelize.voxelize(x, x_data, n=20, c=len(common.atoms.atoms), dist=10, bb_only=bb_only) 72 | voxels = torch.FloatTensor(voxels) 73 | bs_i = voxels.size()[0] 74 | if use_cuda: 75 | voxels = voxels.cuda() 76 | 77 | # map chi angles to bins 78 | chi_angles = chis[:, 0] 79 | chi_mask = chis[:, 1] 80 | chi_angles_binned = data.map_to_bins(chi_angles) 81 | chi_angles_binned[chi_mask == 0] = 0 82 | chi_angles_binned = torch.LongTensor(chi_angles_binned) 83 | 84 | chi_1 = chi_angles_binned[..., 0] 85 | chi_2 = chi_angles_binned[..., 1] 86 | chi_3 = chi_angles_binned[..., 2] 87 | chi_4 = chi_angles_binned[..., 3] 88 | 89 | # get chi onehot vectors -- NOTE can make this faster by precomputing, saving zero tensors 90 | chi_1_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_1[:, None], use_cuda=use_cuda) 91 | chi_2_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_2[:, None], use_cuda=use_cuda) 92 | chi_3_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_3[:, None], use_cuda=use_cuda) 93 | chi_4_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_4[:, None], use_cuda=use_cuda) 94 | 95 | y = torch.LongTensor(y) 96 | y_onehot = make_onehot(bs_i, 20, y[:, None], use_cuda=use_cuda) 97 | if use_cuda: 98 | y = y.cuda() 99 | 100 | # ensemble prediction over all models -- average logits 101 | logits_out = [] 102 | chi_feat_out = [] 103 | chi_1_out = [] 104 | chi_2_out = [] 105 | chi_3_out = [] 106 | chi_4_out = [] 107 | 108 | with torch.no_grad(): 109 | for model in curr_models: 110 | feat, res_pred_logits, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model.get_feat( 111 | voxels, y_onehot, chi_1_onehot[:, 1:], chi_2_onehot[:, 1:], chi_3_onehot[:, 1:] 112 | ) 113 | logits_out.append(res_pred_logits[None]) 114 | chi_feat_out.append(feat[None]) 115 | chi_1_out.append(chi_1_pred[None]) 116 | 
chi_2_out.append(chi_2_pred[None]) 117 | chi_3_out.append(chi_3_pred[None]) 118 | chi_4_out.append(chi_4_pred[None]) 119 | 120 | logits_out = torch.cat(logits_out, 0).mean(0) 121 | chi_feat_out = torch.cat(chi_feat_out, 0).mean(0) 122 | chi_1_logits = torch.cat(chi_1_out, 0).mean(0) 123 | chi_2_logits = torch.cat(chi_2_out, 0).mean(0) 124 | chi_3_logits = torch.cat(chi_3_out, 0).mean(0) 125 | chi_4_logits = torch.cat(chi_4_out, 0).mean(0) 126 | 127 | chi_1 = chi_1 - 1 128 | chi_2 = chi_2 - 1 129 | chi_3 = chi_3 - 1 130 | chi_4 = chi_4 - 1 131 | 132 | if use_cuda: 133 | chi_1 = (chi_1).cuda() 134 | chi_2 = (chi_2).cuda() 135 | chi_3 = (chi_3).cuda() 136 | chi_4 = (chi_4).cuda() 137 | 138 | return ( 139 | logits_out, 140 | chi_feat_out, 141 | y, 142 | chi_1_logits, 143 | chi_2_logits, 144 | chi_3_logits, 145 | chi_4_logits, 146 | chi_1, 147 | chi_2, 148 | chi_3, 149 | chi_4, 150 | chi_angles, 151 | chi_mask, 152 | ) 153 | 154 | 155 | def get_energy_from_feat( 156 | models, 157 | logits, 158 | chi_feat, 159 | y, 160 | chi_1_logits, 161 | chi_2_logits, 162 | chi_3_logits, 163 | chi_4_logits, 164 | chi_1, 165 | chi_2, 166 | chi_3, 167 | chi_4, 168 | chi_angles, 169 | chi_mask, 170 | include_rotamer_probs=0, 171 | return_log_ps=0, 172 | use_cuda=True, 173 | ): 174 | # get residue log probs 175 | # energy, energy_per_res, 176 | log_p_per_res, log_p_mean = get_energy_from_logits(logits, y) 177 | 178 | # get rotamer log_probs 179 | chi_1_mask = torch.zeros_like(chi_1) 180 | chi_2_mask = torch.zeros_like(chi_2) 181 | chi_3_mask = torch.zeros_like(chi_3) 182 | chi_4_mask = torch.zeros_like(chi_4) 183 | 184 | if use_cuda: 185 | chi_1_mask = chi_1_mask.cuda() 186 | chi_2_mask = chi_2_mask.cuda() 187 | chi_3_mask = chi_3_mask.cuda() 188 | chi_4_mask = chi_4_mask.cuda() 189 | 190 | chi_1_mask[chi_1 < 0] = 1 191 | chi_2_mask[chi_2 < 0] = 1 192 | chi_3_mask[chi_3 < 0] = 1 193 | chi_4_mask[chi_4 < 0] = 1 194 | 195 | chi_1[chi_1 < 0] = 0 196 | chi_2[chi_2 < 0] = 0 197 | chi_3[chi_3 < 0] = 0 198 | chi_4[chi_4 < 0] = 0 199 | 200 | log_p_per_res_chi_1, log_p_per_res_chi_1_mean = get_energy_from_logits(chi_1_logits, chi_1, mask=chi_1_mask, baseline=1.3183412514892) 201 | log_p_per_res_chi_2, log_p_per_res_chi_2_mean = get_energy_from_logits(chi_2_logits, chi_2, mask=chi_2_mask, baseline=1.5970909799808386) 202 | log_p_per_res_chi_3, log_p_per_res_chi_3_mean = get_energy_from_logits(chi_3_logits, chi_3, mask=chi_3_mask, baseline=2.231545756901711) 203 | log_p_per_res_chi_4, log_p_per_res_chi_4_mean = get_energy_from_logits(chi_4_logits, chi_4, mask=chi_4_mask, baseline=2.084356748355477) 204 | 205 | if return_log_ps: 206 | return log_p_mean, log_p_per_res_chi_1_mean, log_p_per_res_chi_2_mean, log_p_per_res_chi_3_mean, log_p_per_res_chi_4_mean 207 | 208 | if include_rotamer_probs: 209 | # get per residue log probs (autoregressive) 210 | log_p_per_res = log_p_per_res + log_p_per_res_chi_1 + log_p_per_res_chi_2 + log_p_per_res_chi_3 + log_p_per_res_chi_4 211 | # optimize mean log prob across residues 212 | log_p_mean = log_p_per_res.mean() 213 | 214 | return log_p_per_res, log_p_mean 215 | 216 | 217 | def get_energy(models, pose=None, pdb=None, chain="A", bb_only=0, return_chi=0, use_cuda=1, log_path="./", include_rotamer_probs=0): 218 | if pdb is not None: 219 | atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis = data.get_pdb_data( 220 | pdb[pdb.rfind("/") + 1 : -4], data_dir=pdb[: pdb.rfind("/")], skip_download=1, assembly=0 221 | ) 222 | else: 223 | assert pose is not None, "need to specify 
pose to calc energy" 224 | pose.dump_pdb(log_path + "/" + "curr.pdb") 225 | atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis = data.get_pdb_data( 226 | "curr", data_dir=log_path, skip_download=1, assembly=0 227 | ) 228 | 229 | # get residue and rotamer logits 230 | logits, chi_feat, y, chi_1_logits, chi_2_logits, chi_3_logits, chi_4_logits, chi_1, chi_2, chi_3, chi_4, chi_angles, chi_mask = get_conv_feat( 231 | models, atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis, bb_only=bb_only, return_chi=return_chi, use_cuda=use_cuda 232 | ) 233 | 234 | # get model negative log probs (model energy) 235 | log_p_per_res, log_p_mean = get_energy_from_feat( 236 | models, 237 | logits, 238 | chi_feat, 239 | y, 240 | chi_1_logits, 241 | chi_2_logits, 242 | chi_3_logits, 243 | chi_4_logits, 244 | chi_1, 245 | chi_2, 246 | chi_3, 247 | chi_4, 248 | chi_angles, 249 | chi_mask, 250 | include_rotamer_probs=include_rotamer_probs, 251 | use_cuda=use_cuda, 252 | ) 253 | 254 | if return_chi: 255 | return res_label, log_p_per_res, log_p_mean, logits, chi_feat, chi_angles, chi_mask, [chi_1, chi_2, chi_3, chi_4] 256 | return res_label, log_p_per_res, log_p_mean, logits, chi_feat, chi_angles, chi_mask 257 | 258 | 259 | def get_chi_init_feat(curr_models, feat, res_onehot): 260 | chi_feat_out = [] 261 | with torch.no_grad(): 262 | for model in curr_models: 263 | chi_feat = model.get_chi_init_feat(feat, res_onehot) 264 | chi_feat_out.append(chi_feat[None]) 265 | chi_feat = torch.cat(chi_feat_out, 0).mean(0) 266 | return chi_feat 267 | 268 | 269 | def get_chi_1_logits(curr_models, chi_feat): 270 | chi_1_pred_out = [] 271 | with torch.no_grad(): 272 | for model in curr_models: 273 | chi_1_pred = model.get_chi_1(chi_feat) 274 | chi_1_pred_out.append(chi_1_pred[None]) 275 | chi_1_pred_out = torch.cat(chi_1_pred_out, 0).mean(0) 276 | return chi_1_pred_out 277 | 278 | 279 | def get_chi_2_logits(curr_models, chi_feat, chi_1_onehot): 280 | chi_2_pred_out = [] 281 | with torch.no_grad(): 282 | for model in curr_models: 283 | chi_2_pred = model.get_chi_2(chi_feat, chi_1_onehot) 284 | chi_2_pred_out.append(chi_2_pred[None]) 285 | chi_2_pred_out = torch.cat(chi_2_pred_out, 0).mean(0) 286 | return chi_2_pred_out 287 | 288 | 289 | def get_chi_3_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot): 290 | chi_3_pred_out = [] 291 | with torch.no_grad(): 292 | for model in curr_models: 293 | chi_3_pred = model.get_chi_3(chi_feat, chi_1_onehot, chi_2_onehot) 294 | chi_3_pred_out.append(chi_3_pred[None]) 295 | chi_3_pred_out = torch.cat(chi_3_pred_out, 0).mean(0) 296 | return chi_3_pred_out 297 | 298 | 299 | def get_chi_4_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot): 300 | chi_4_pred_out = [] 301 | with torch.no_grad(): 302 | for model in curr_models: 303 | chi_4_pred = model.get_chi_4(chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot) 304 | chi_4_pred_out.append(chi_4_pred[None]) 305 | chi_4_pred_out = torch.cat(chi_4_pred_out, 0).mean(0) 306 | return chi_4_pred_out 307 | 308 | 309 | def sample_chi(chi_logits, use_cuda=True): 310 | # sample chi bin from predicted distribution 311 | chi_dist = Categorical(logits=chi_logits) 312 | chi_idx = chi_dist.sample().cpu().data.numpy() 313 | chi = torch.LongTensor(chi_idx) 314 | # get one-hot encoding of sampled bin for autoregressive unroll 315 | chi_onehot = make_onehot(chi_logits.size()[0], len(data.CHI_BINS) - 1, chi[:, None], use_cuda=use_cuda) 316 | # sample chi angle (real) uniformly within bin 317 | chi_real = 
np.random.uniform(low=data.CHI_BINS[chi_idx], high=data.CHI_BINS[chi_idx + 1]) 318 | return chi, chi_real, chi_onehot 319 | 320 | 321 | def get_symm_chi(chi_pred_out, symm_idx_ptr, use_cuda=True, debug=False): 322 | chi_pred_out_symm = [] 323 | for i, ptr in enumerate(symm_idx_ptr): 324 | chi_pred_out_symm.append(chi_pred_out[ptr].mean(0)[None]) 325 | chi_pred_out = torch.cat(chi_pred_out_symm, 0) 326 | chi, chi_real, chi_onehot = sample_chi(chi_pred_out, use_cuda=use_cuda) 327 | if debug: 328 | # sample uniformly again from predicted bin. small bug for TIM-barrel symmetry experiments. ¯\_(ツ)_/¯ 329 | chi, chi_real, chi_onehot = sample_chi(chi_pred_out, use_cuda=use_cuda) 330 | 331 | chi_real_out = [] 332 | for i, ptr in enumerate(symm_idx_ptr): 333 | chi_real_out.append([chi_real[i][None] for j in range(len(ptr))]) # , 0)) 334 | chi_real = np.concatenate(chi_real_out, axis=0) 335 | 336 | chi_onehot_out = [] 337 | for i, ptr in enumerate(symm_idx_ptr): 338 | chi_onehot_out.append(torch.cat([chi_onehot[i][None] for j in range(len(ptr))], 0)) 339 | chi_onehot = torch.cat(chi_onehot_out, 0) 340 | return chi_real, chi_onehot 341 | 342 | 343 | # from https://codereview.stackexchange.com/questions/203319/greedy-graph-coloring-in-python 344 | def color_nodes(graph, nodes): 345 | color_map = {} 346 | # Consider nodes in descending degree 347 | for node in nodes: # sorted(graph, key=lambda x: len(graph[x]), reverse=True): 348 | neighbor_colors = set(color_map.get(neigh) for neigh in graph[node]) 349 | color_map[node] = next(color for color in range(len(graph)) if color not in neighbor_colors) 350 | return color_map 351 | 352 | 353 | ################################ 354 | -------------------------------------------------------------------------------- /seq_des/util/voxelize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def voxelize(x, x_data, n=20, c=13, dist=10, plot=False, bb_only=0): 5 | """Function to voxelize atom coordinate data ahead of training. 
Could be sped up on GPU 6 | 7 | Args: 8 | x_coords (np.array): num_res x num_return x 1 x 3 -- canonicalized coordinates about each residue with center residue masked 9 | x_data (np.array): num_res x num_return x 1 x 4 -- metadata for canonicalized atoms for each environment 10 | Returns: 11 | output (np.array): num_res x c x n x n x n -- 3D environments centered at each residue with atom type in channel dimensions 12 | """ 13 | 14 | bins = np.linspace(-dist, dist, n + 1) 15 | bs, nres, _, x_dim = x.shape 16 | x_data_dim = x_data.shape[-1] 17 | x = x.reshape(bs * nres, -1, x_dim) 18 | x_data = x_data.reshape(bs * nres, -1, x_data_dim) 19 | x_atom = x_data[..., 2].astype(np.int64) 20 | x_res_type = x_data[..., -1].astype(np.int64) 21 | x_bb = x_data[..., 1].astype(np.int64) 22 | 23 | bs_idx = np.tile(np.arange(bs)[:, None], (1, nres)).reshape(-1) 24 | # coordinates to voxels 25 | x_b = np.digitize(x[..., 0], bins) # [:, 0] 26 | y_b = np.digitize(x[..., 1], bins) # [:, 0] 27 | z_b = np.digitize(x[..., 2], bins) # [:, 0] 28 | 29 | # eliminate 'other' atoms 30 | x_atom[x_atom > c - 1] = c # force any non-listed atoms into 'other' category 31 | 32 | # this step can possibly be moved to GPU 33 | output_atom = np.zeros((bs, c + 1, n + 2, n + 2, n + 2)) 34 | output_atom[bs_idx, x_atom[:, 0], x_b[:, 0], y_b[:, 0], z_b[:, 0]] = 1 # atom type 35 | if not bb_only: 36 | output_bb = np.zeros((bs, 2, n + 2, n + 2, n + 2)) 37 | output_bb[bs_idx, x_bb[:, 0], x_b[:, 0], y_b[:, 0], z_b[:, 0]] = 1 # BB indicator 38 | output_res = np.zeros((bs, 22, n + 2, n + 2, n + 2)) 39 | output_res[bs_idx, x_res_type[:, 0], x_b[:, 0], y_b[:, 0], z_b[:, 0]] = 1 # res type for each atom 40 | # eliminate last channel for output_atom ('other' atom type), output_bb, and output_res (res type for current side chain) 41 | output = np.concatenate([output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1) 42 | else: 43 | output = output_atom[:, :c] 44 | 45 | output = output[:, :, 1:-1, 1:-1, 1:-1] 46 | 47 | return output 48 | 49 | 50 | def get_voxel_idx(x, x_data, n=20, c=13, dist=10, plot=False): 51 | """Function to get indices for voxelized atom coordinate data ahead of training. 
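Unlike voxelize() above, this returns flat scatter indices rather than a dense grid, so the one-hot voxel tensor can be assembled on the GPU at train time (see the scatter steps in the train scripts).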
52 | 53 | Args: 54 | x_coords (np.array): num_res x num_return x 1 x 3 -- canonicalized coordinates about each residue with center residue masked 55 | x_data (np.array): num_res x num_return x 1 x 4 -- metadata for canonicalized atoms for each environment 56 | Returns: 57 | #NOTE -- FIX THIS 58 | output (np.array): num_res x c x n x n x n -- 3D environments centered at each residue with atom type in channel dimensions 59 | """ 60 | 61 | bins = np.linspace(-dist, dist, n + 1) 62 | bs, nres, _, x_dim = x.shape 63 | x_data_dim = x_data.shape[-1] 64 | x = x.reshape(bs * nres, -1, x_dim) 65 | x_data = x_data.reshape(bs * nres, -1, x_data_dim) 66 | x_atom = x_data[..., 2].astype(np.int64) 67 | x_res_type = x_data[..., -1].astype(np.int64) # not used for now 68 | x_bb = x_data[..., 1].astype(np.int64) 69 | 70 | bs_idx = np.tile(np.arange(bs)[:, None], (1, nres)).reshape(-1) 71 | 72 | # coordinates to voxels 73 | x_b = np.digitize(x[..., 0], bins) # [:, 0] 74 | y_b = np.digitize(x[..., 1], bins) # [:, 0] 75 | z_b = np.digitize(x[..., 2], bins) # [:, 0] 76 | 77 | # eliminate 'other' atoms 78 | x_atom[x_atom > c - 1] = c # force any non-listed atoms into 'other' category 79 | # print(x_atom.shape, x_res_type.shape, x_bb.shape) 80 | 81 | return bs_idx, x_atom[..., 0], x_bb[..., 0], x_b[..., 0], y_b[..., 0], z_b[..., 0], x_res_type[..., 0] 82 | 83 | 84 | -------------------------------------------------------------------------------- /seq_des_info.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/seq_des_info.pdf -------------------------------------------------------------------------------- /train_autoreg_chi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from tqdm import tqdm 9 | import common.run_manager 10 | import seq_des.models as models 11 | import seq_des.util.voxelize as voxelize 12 | import glob 13 | import seq_des.util.canonicalize as canonicalize 14 | import pickle 15 | import seq_des.util.data as datasets 16 | from torch.utils import data 17 | import common.atoms 18 | import seq_des.util.acc_util as acc_util 19 | import subprocess as sp 20 | import time 21 | import torch.nn.functional as F 22 | 23 | """ script to train 3D CNN on local residue-centered environments -- with autoregressive rotamer chi angle prediction""" 24 | 25 | dist = 10 26 | n = 20 27 | c = len(common.atoms.atoms) 28 | 29 | 30 | def test(model, gen, dataloader, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, max_it=1e6, desc="test", batch_size=64, n_iters=500, k=3, return_cm=False, use_cuda=True): 31 | n_iters = min(max_it, n_iters) 32 | model = model.eval() 33 | gen = iter(dataloader) 34 | losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss = ([] for i in range(13)) 35 | 36 | with torch.no_grad(): 37 | 38 | for i in tqdm(range(n_iters), desc=desc): 39 | try: 40 | out = gen.next() 41 | except StopIteration: 42 | gen = iter(dataloader) 43 | out = gen.next() 44 | 45 | out = step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, use_cuda=use_cuda) 46 | 47 | if out is None: 48 | continue 49 | loss, 
chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc = out 50 | 51 | # append losses, accs to lists 52 | for x, y in zip( 53 | [losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss], 54 | [loss.item(), acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item()], 55 | ): 56 | x.append(y) 57 | 58 | del loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc 59 | 60 | print("\nloss", np.mean(losses), "acc", np.mean(avg_acc), "top3", np.mean(avg_top_k_acc), "coarse", np.mean(avg_coarse_acc), "polar", np.mean(avg_polar_acc)) 61 | 62 | return ( 63 | gen, 64 | np.mean(losses), 65 | np.mean(avg_chi_1_loss), 66 | np.mean(avg_chi_2_loss), 67 | np.mean(avg_chi_3_loss), 68 | np.mean(avg_chi_4_loss), 69 | np.mean(avg_acc), 70 | np.mean(avg_top_k_acc), 71 | np.mean(avg_coarse_acc), 72 | np.mean(avg_polar_acc), 73 | np.mean(avg_chi_1_acc), 74 | np.mean(avg_chi_2_acc), 75 | np.mean(avg_chi_3_acc), 76 | np.mean(avg_chi_4_acc), 77 | ) 78 | 79 | 80 | def step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, k=3, use_cuda=True): 81 | 82 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles = out 83 | 84 | bs = len(bs_idx) 85 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2)) 86 | output_bb = torch.zeros((bs, 2, n + 2, n + 2, n + 2)) 87 | output_res = torch.zeros((bs, 22, n + 2, n + 2, n + 2)) 88 | 89 | if use_cuda: 90 | output_atom, output_bb, output_res = map(lambda x: x.cuda(), [output_atom, output_bb, output_res]) 91 | 92 | 93 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 # atom type 94 | output_bb.zero_() 95 | output_bb[bs_idx, x_bb, x_b, y_b, z_b] = 1 # BB indicator 96 | output_res.zero_() 97 | output_res[bs_idx, x_res_type, x_b, y_b, z_b] = 1 # res type 98 | output = torch.cat([output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1) 99 | X = output[:, :, 1:-1, 1:-1, 1:-1] 100 | 101 | X, y = X.float(), y.long() 102 | chi_angles = chi_angles.long() 103 | 104 | chi_1 = chi_angles[:, 0] 105 | chi_2 = chi_angles[:, 1] 106 | chi_3 = chi_angles[:, 2] 107 | chi_4 = chi_angles[:, 3] 108 | 109 | y_onehot = torch.FloatTensor(y.size()[0], 20) 110 | y_onehot.zero_() 111 | y_onehot.scatter_(1, y[:, None], 1) 112 | 113 | chi_1_onehot = torch.FloatTensor(chi_1.size()[0], len(datasets.CHI_BINS)) 114 | chi_1_onehot.zero_() 115 | chi_1_onehot.scatter_(1, chi_1[:, None], 1) 116 | 117 | chi_2_onehot = torch.FloatTensor(chi_2.size()[0], len(datasets.CHI_BINS)) 118 | chi_2_onehot.zero_() 119 | chi_2_onehot.scatter_(1, chi_2[:, None], 1) 120 | 121 | chi_3_onehot = torch.FloatTensor(chi_3.size()[0], len(datasets.CHI_BINS)) 122 | chi_3_onehot.zero_() 123 | chi_3_onehot.scatter_(1, chi_3[:, None], 1) 124 | 125 | if use_cuda: 126 | X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4 = map(lambda x: x.cuda(), [X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4]) 127 | 128 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X, y_onehot, chi_1_onehot[:, 1:], chi_2_onehot[:, 1:], chi_3_onehot[:, 1:]) 129 | # loss 130 | loss = criterion(out, y) 131 | chi_1_loss = 
chi_1_criterion(chi_1_pred, chi_1 - 1) # [:, 1:]) 132 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2 - 1) # [:, 1:]) 133 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3 - 1) # [:, 1:]) 134 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4 - 1) # [:, 1:]) 135 | 136 | # acc 137 | acc, _ = acc_util.get_acc(out, y) 138 | top_k_acc = acc_util.get_top_k_acc(out, y, k=k) 139 | coarse_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_coarse) 140 | polar_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_polar) 141 | chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1 - 1, ignore_idx=-1) 142 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2 - 1, ignore_idx=-1) 143 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3 - 1, ignore_idx=-1) 144 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4 - 1, ignore_idx=-1) 145 | 146 | return loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc 147 | 148 | 149 | def step_iter(gen, dataloader): 150 | try: 151 | out = gen.next() 152 | except StopIteration: 153 | gen = iter(dataloader) 154 | out = gen.next() 155 | return gen, out 156 | 157 | 158 | def main(): 159 | 160 | manager = common.run_manager.RunManager() 161 | 162 | manager.parse_args() 163 | args = manager.args 164 | log = manager.log 165 | 166 | use_cuda = torch.cuda.is_available() and args.cuda 167 | 168 | # set up model 169 | model = models.seqPred(nic=len(common.atoms.atoms) + 1 + 21, nf=args.nf, momentum=0.01) 170 | model.apply(models.init_ortho_weights) 171 | 172 | if use_cuda: 173 | model.cuda() 174 | else: 175 | print("Training model on CPU") 176 | 177 | if args.model != "": 178 | # load pretrained model 179 | model.load_state_dict(torch.load(args.model)) 180 | print("loaded pretrained model") 181 | 182 | # parallelize over available GPUs 183 | if torch.cuda.device_count() > 1 and args.cuda: 184 | print("using", torch.cuda.device_count(), "GPUs") 185 | model = nn.DataParallel(model) 186 | 187 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999), weight_decay=args.reg) 188 | 189 | if args.optimizer != "": 190 | # load pretrained optimizer 191 | optimizer.load_state_dict(torch.load(args.optimizer)) 192 | print("loaded pretrained optimizer") 193 | 194 | # load pretrained model weights / optimizer state 195 | 196 | chi_1_criterion = nn.CrossEntropyLoss(ignore_index=-1) 197 | chi_2_criterion = nn.CrossEntropyLoss(ignore_index=-1) 198 | chi_3_criterion = nn.CrossEntropyLoss(ignore_index=-1) 199 | chi_4_criterion = nn.CrossEntropyLoss(ignore_index=-1) 200 | criterion = nn.CrossEntropyLoss() 201 | if use_cuda: 202 | criterion.cuda() 203 | chi_1_criterion.cuda() 204 | chi_2_criterion.cuda() 205 | chi_3_criterion.cuda() 206 | chi_4_criterion.cuda() 207 | 208 | train_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/train_s95_chi") 209 | train_dataset.len = 8145448 # NOTE -- need to update this if underlying data changes 210 | 211 | test_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/test_s95_chi") 212 | test_dataset.len = 574267 # NOTE -- need to update this if underlying data changes 213 | 214 | train_dataloader = data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper) 215 | test_dataloader = data.DataLoader(test_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper) 216 
| 217 | # training params 218 | validation_frequency = args.validation_frequency 219 | save_frequency = args.save_frequency 220 | 221 | """ TRAIN """ 222 | 223 | model.train() 224 | gen = iter(train_dataloader) 225 | test_gen = iter(test_dataloader) 226 | bs = args.batchSize 227 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2)) 228 | output_bb = torch.zeros((bs, 2, n + 2, n + 2, n + 2)) 229 | output_res = torch.zeros((bs, 22, n + 2, n + 2, n + 2)) 230 | y_onehot = torch.FloatTensor(bs, 20) 231 | chi_1_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS)) 232 | chi_2_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS)) 233 | chi_3_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS)) 234 | 235 | if use_cuda: 236 | output_atom, output_bb, output_res, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot = map(lambda x: x.cuda(), [output_atom, output_bb, output_res, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot]) 237 | for epoch in range(args.epochs): 238 | for it in tqdm(range(len(train_dataloader)), desc="training epoch %0.2d" % epoch): 239 | 240 | gen, out = step_iter(gen, train_dataloader) 241 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles = out 242 | bs_i = len(bs_idx) 243 | output_atom.zero_() 244 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 # atom type 245 | output_bb.zero_() 246 | output_bb[bs_idx, x_bb, x_b, y_b, z_b] = 1 # BB indicator 247 | output_res.zero_() 248 | output_res[bs_idx, x_res_type, x_b, y_b, z_b] = 1 # res type 249 | output = torch.cat([output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1) 250 | 251 | X = output[:, :, 1:-1, 1:-1, 1:-1] 252 | 253 | X, y = X.float(), y.long() 254 | chi_angles = chi_angles.long() 255 | 256 | chi_1 = chi_angles[:, 0] 257 | chi_2 = chi_angles[:, 1] 258 | chi_3 = chi_angles[:, 2] 259 | chi_4 = chi_angles[:, 3] 260 | 261 | if use_cuda: 262 | y, y_onehot, chi_1, chi_2, chi_3, chi_4 = map(lambda x: x.cuda(), [y, y_onehot, chi_1, chi_2, chi_3, chi_4]) 263 | 264 | if bs_i < bs: 265 | y = F.pad(y, (0, bs - bs_i)) 266 | chi_1 = F.pad(chi_1, (0, bs - bs_i)) 267 | chi_2 = F.pad(chi_2, (0, bs - bs_i)) 268 | chi_3 = F.pad(chi_3, (0, bs - bs_i)) 269 | 270 | y_onehot.zero_() 271 | y_onehot.scatter_(1, y[:, None], 1) 272 | 273 | chi_1_onehot.zero_() 274 | chi_1_onehot.scatter_(1, chi_1[:, None], 1) 275 | 276 | chi_2_onehot.zero_() 277 | chi_2_onehot.scatter_(1, chi_2[:, None], 1) 278 | 279 | chi_3_onehot.zero_() 280 | chi_3_onehot.scatter_(1, chi_3[:, None], 1) 281 | 282 | # 0 index for chi indicates that chi is masked 283 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X[:bs_i], y_onehot[:bs_i], chi_1_onehot[:bs_i, 1:], chi_2_onehot[:bs_i, 1:], chi_3_onehot[:bs_i, 1:]) 284 | res_loss = criterion(out, y[:bs_i]) 285 | chi_1_loss = chi_1_criterion(chi_1_pred, chi_1[:bs_i] - 1) # , 1:]) 286 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2[:bs_i] - 1) # , 1:]) 287 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3[:bs_i] - 1) # , 1:]) 288 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4[:bs_i] - 1) # , 1:]) 289 | 290 | train_loss = res_loss + chi_1_loss + chi_2_loss + chi_3_loss + chi_4_loss 291 | train_loss.backward() 292 | optimizer.step() 293 | 294 | # acc 295 | train_acc, _ = acc_util.get_acc(out, y[:bs_i], cm=None) 296 | train_top_k_acc = acc_util.get_top_k_acc(out, y[:bs_i], k=3) 297 | train_coarse_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_coarse) 298 | train_polar_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_polar) 299 | 300 | 
chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1[:bs_i] - 1, ignore_idx=-1) 301 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2[:bs_i] - 1, ignore_idx=-1) 302 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3[:bs_i] - 1, ignore_idx=-1) 303 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4[:bs_i] - 1, ignore_idx=-1) 304 | 305 | # tensorboard logging 306 | map( 307 | lambda x: log.log_scalar("seq_chi_pred/%s" % x[0], x[1]), 308 | zip( 309 | ["res_loss", "chi_1_loss", "chi_2_loss", "chi_3_loss", "chi_4_loss", "train_acc", "chi_1_acc", "chi_2_acc", "chi_3_acc", "chi_4_acc", "train_top3_acc", "train_coarse_acc", "train_polar_acc"], 310 | [res_loss.item(), chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item(), train_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, train_top_k_acc, train_coarse_acc, train_polar_acc], 311 | ), 312 | ) 313 | 314 | if it % validation_frequency == 0 or it == len(train_dataloader) - 1: 315 | 316 | if it > 0: 317 | if torch.cuda.device_count() > 1 and args.cuda: 318 | torch.save(model.module.state_dict(), log.log_path + "/seq_chi_pred_curr_weights.pt") 319 | else: 320 | torch.save(model.state_dict(), log.log_path + "/seq_chi_pred_curr_weights.pt") 321 | torch.save(optimizer.state_dict(), log.log_path + "/seq_chi_pred_curr_optimizer.pt") 322 | 323 | # NOTE -- saving models for each validation step 324 | if it > 0 and (it % save_frequency == 0 or it == len(train_dataloader) - 1): 325 | if torch.cuda.device_count() > 1 and args.cuda: 326 | torch.save(model.module.state_dict(), log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" % (epoch, it)) 327 | else: 328 | torch.save(model.state_dict(), log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" % (epoch, it)) 329 | 330 | torch.save(optimizer.state_dict(), log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_optimizer.pt" % (epoch, it)) 331 | 332 | ##NOTE -- turning back on model.eval() 333 | model.eval() 334 | # eval on the test set 335 | test_gen, curr_test_loss, test_chi_1_loss, test_chi_2_loss, test_chi_3_loss, test_chi_4_loss, curr_test_acc, curr_test_top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc = test( 336 | model, 337 | test_gen, 338 | test_dataloader, 339 | criterion, 340 | chi_1_criterion, 341 | chi_2_criterion, 342 | chi_3_criterion, 343 | chi_4_criterion, 344 | max_it=len(test_dataloader), 345 | n_iters=min(10, len(test_dataloader)), 346 | desc="test", 347 | batch_size=args.batchSize, 348 | use_cuda=use_cuda, 349 | ) 350 | 351 | map( 352 | lambda x: log.log_scalar("seq_chi_pred/%s" % x[0], x[1]), 353 | zip( 354 | [ 355 | "test_loss", 356 | "test_chi_1_loss", 357 | "test_chi_2_loss", 358 | "test_chi_3_loss", 359 | "test_chi_4_loss", 360 | "test_acc", 361 | "test_chi_1_acc", 362 | "test_chi_2_acc", 363 | "test_chi_3_acc", 364 | "test_chi_4_acc", 365 | "test_acc_top3", 366 | "test_coarse_acc", 367 | "test_polar_acc", 368 | ], 369 | [ 370 | curr_test_loss.item(), 371 | chi_1_loss.item(), 372 | chi_2_loss.item(), 373 | chi_3_loss.item(), 374 | chi_4_loss.item(), 375 | curr_test_acc.item(), 376 | chi_1_acc.item(), 377 | chi_2_acc.item(), 378 | chi_3_acc.item(), 379 | chi_4_acc.item(), 380 | curr_test_top_k_acc.item(), 381 | coarse_acc.item(), 382 | polar_acc.item(), 383 | ], 384 | ), 385 | ) 386 | 387 | model.train() 388 | 389 | log.advance_iteration() 390 | 391 | 392 | if __name__ == "__main__": 393 | main() 394 | -------------------------------------------------------------------------------- /train_autoreg_chi_baseline.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from tqdm import tqdm 9 | import common.run_manager 10 | import seq_des.models as models 11 | import seq_des.util.voxelize as voxelize 12 | import glob 13 | import seq_des.util.canonicalize as canonicalize 14 | import pickle 15 | import seq_des.util.data as datasets 16 | from torch.utils import data 17 | import common.atoms 18 | import seq_des.util.acc_util as acc_util 19 | import subprocess as sp 20 | import time 21 | import torch.nn.functional as F 22 | 23 | """ script to train 3D CNN on local residue-centered environments -- BB only -- with autoregressive rotamer chi angle prediction""" 24 | 25 | dist = 10 26 | n = 20 27 | c = len(common.atoms.atoms) 28 | 29 | 30 | def test( 31 | model, gen, dataloader, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, max_it=1e6, desc="test", batch_size=64, n_iters=500, k=3, use_cuda=True, 32 | ): 33 | n_iters = min(max_it, n_iters) 34 | model = model.eval() 35 | gen = iter(dataloader) 36 | (losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss,) = ([] for i in range(13)) 37 | with torch.no_grad(): 38 | 39 | for i in tqdm(range(n_iters), desc=desc): 40 | try: 41 | out = gen.next() 42 | except StopIteration: 43 | gen = iter(dataloader) 44 | out = gen.next() 45 | 46 | out = step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, use_cuda=use_cuda) 47 | 48 | if out is None: 49 | continue 50 | (loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc,) = out 51 | 52 | # append losses, accs to lists 53 | for x, y in zip( 54 | [losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss,], 55 | [loss.item(), acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item(),], 56 | ): 57 | x.append(y) 58 | 59 | del ( 60 | loss, 61 | chi_1_loss, 62 | chi_2_loss, 63 | chi_3_loss, 64 | chi_4_loss, 65 | out, 66 | y, 67 | acc, 68 | top_k_acc, 69 | coarse_acc, 70 | polar_acc, 71 | chi_1_acc, 72 | chi_2_acc, 73 | chi_3_acc, 74 | chi_4_acc, 75 | ) 76 | 77 | print( 78 | "\nloss", np.mean(losses), "acc", np.mean(avg_acc), "top3", np.mean(avg_top_k_acc), "coarse", np.mean(avg_coarse_acc), "polar", np.mean(avg_polar_acc), 79 | ) 80 | 81 | return ( 82 | gen, 83 | np.mean(losses), 84 | np.mean(avg_chi_1_loss), 85 | np.mean(avg_chi_2_loss), 86 | np.mean(avg_chi_3_loss), 87 | np.mean(avg_chi_4_loss), 88 | np.mean(avg_acc), 89 | np.mean(avg_top_k_acc), 90 | np.mean(avg_coarse_acc), 91 | np.mean(avg_polar_acc), 92 | np.mean(avg_chi_1_acc), 93 | np.mean(avg_chi_2_acc), 94 | np.mean(avg_chi_3_acc), 95 | np.mean(avg_chi_4_acc), 96 | ) 97 | 98 | 99 | def step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, k=3, use_cuda=True): 100 | 101 | (bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles,) = out 102 | 103 | bs = len(bs_idx) 104 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2)) 
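    # the scatter below uses advanced indexing: for atom k of batch element b it sets
    # output_atom[b, x_atom[b, k], x_b[b, k], y_b[b, k], z_b[b, k]] = 1, marking the atom-type
    # channel at that voxel. the spatial dims are n + 2 because np.digitize maps out-of-range
    # coordinates to bins 0 and n + 1; those padding slices (and the trailing 'other'-atom
    # channel) are cropped away when X is sliced to [:, :c, 1:-1, 1:-1, 1:-1] below.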
105 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 106 | 107 | if use_cuda: 108 | output_atom = output_atom.cuda() 109 | 110 | X = output_atom[:, :c, 1:-1, 1:-1, 1:-1] 111 | 112 | if X is None: 113 | return None 114 | 115 | X, y = X.float(), y.long() 116 | chi_angles = chi_angles.long() 117 | 118 | chi_1 = chi_angles[:, 0] 119 | chi_2 = chi_angles[:, 1] 120 | chi_3 = chi_angles[:, 2] 121 | chi_4 = chi_angles[:, 3] 122 | 123 | y_onehot = torch.FloatTensor(y.size()[0], 20) 124 | y_onehot.zero_() 125 | y_onehot.scatter_(1, y[:, None], 1) 126 | 127 | chi_1_onehot = torch.FloatTensor(chi_1.size()[0], len(datasets.CHI_BINS)) 128 | chi_1_onehot.zero_() 129 | chi_1_onehot.scatter_(1, chi_1[:, None], 1) 130 | 131 | chi_2_onehot = torch.FloatTensor(chi_2.size()[0], len(datasets.CHI_BINS)) 132 | chi_2_onehot.zero_() 133 | chi_2_onehot.scatter_(1, chi_2[:, None], 1) 134 | 135 | chi_3_onehot = torch.FloatTensor(chi_3.size()[0], len(datasets.CHI_BINS)) 136 | chi_3_onehot.zero_() 137 | chi_3_onehot.scatter_(1, chi_3[:, None], 1) 138 | 139 | if use_cuda: 140 | (X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4,) = map(lambda x: x.cuda(), [X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4,],) 141 | 142 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X, y_onehot, chi_1_onehot[:, 1:], chi_2_onehot[:, 1:], chi_3_onehot[:, 1:]) 143 | # loss 144 | loss = criterion(out, y) 145 | chi_1_loss = chi_1_criterion(chi_1_pred, chi_1 - 1) 146 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2 - 1) 147 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3 - 1) 148 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4 - 1) 149 | 150 | # acc 151 | acc, _ = acc_util.get_acc(out, y) 152 | top_k_acc = acc_util.get_top_k_acc(out, y, k=k) 153 | coarse_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_coarse) 154 | polar_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_polar) 155 | chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1 - 1, ignore_idx=-1) 156 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2 - 1, ignore_idx=-1) 157 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3 - 1, ignore_idx=-1) 158 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4 - 1, ignore_idx=-1) 159 | 160 | return ( 161 | loss, 162 | chi_1_loss, 163 | chi_2_loss, 164 | chi_3_loss, 165 | chi_4_loss, 166 | out, 167 | y, 168 | acc, 169 | top_k_acc, 170 | coarse_acc, 171 | polar_acc, 172 | chi_1_acc, 173 | chi_2_acc, 174 | chi_3_acc, 175 | chi_4_acc, 176 | ) 177 | 178 | 179 | def step_iter(gen, dataloader): 180 | try: 181 | out = gen.next() 182 | except StopIteration: 183 | gen = iter(dataloader) 184 | out = gen.next() 185 | return gen, out 186 | 187 | 188 | def main(): 189 | 190 | manager = common.run_manager.RunManager() 191 | 192 | manager.parse_args() 193 | args = manager.args 194 | log = manager.log 195 | 196 | use_cuda = torch.cuda.is_available() and args.cuda 197 | 198 | # set up model 199 | model = models.seqPred(nic=len(common.atoms.atoms), nf=args.nf, momentum=args.momentum) 200 | model.apply(models.init_ortho_weights) 201 | if use_cuda: 202 | model.cuda() 203 | else: 204 | print("Training model on CPU") 205 | 206 | # parallelize over available GPUs 207 | if torch.cuda.device_count() > 1 and args.cuda: 208 | print("using", torch.cuda.device_count(), "GPUs") 209 | model = nn.DataParallel(model) 210 | 211 | if args.model != "": 212 | # load pretrained model 213 | model.load_state_dict(torch.load(args.model)) 214 | print("loaded pretrained model") 215 | 
216 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999), weight_decay=args.reg) 217 | 218 | if args.optimizer != "": 219 | # load pretrained optimizer 220 | optimizer.load_state_dict(torch.load(args.optimizer)) 221 | print("loaded pretrained optimizer") 222 | 223 | # load pretrained model weights / optimizer state 224 | 225 | chi_1_criterion = nn.CrossEntropyLoss(ignore_index=-1) 226 | chi_2_criterion = nn.CrossEntropyLoss(ignore_index=-1) 227 | chi_3_criterion = nn.CrossEntropyLoss(ignore_index=-1) 228 | chi_4_criterion = nn.CrossEntropyLoss(ignore_index=-1) 229 | criterion = nn.CrossEntropyLoss() 230 | if use_cuda: 231 | criterion.cuda() 232 | chi_1_criterion.cuda() 233 | chi_2_criterion.cuda() 234 | chi_3_criterion.cuda() 235 | chi_4_criterion.cuda() 236 | 237 | train_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/train_s95_chi_bb") 238 | train_dataset.len = 8145448 # NOTE -- need to update this if underlying data changes 239 | 240 | test_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/test_s95_chi_bb") 241 | test_dataset.len = 574267 # NOTE -- need to update this if underlying data changes 242 | 243 | train_dataloader = data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper,) 244 | test_dataloader = data.DataLoader(test_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper,) 245 | 246 | # training params 247 | validation_frequency = args.validation_frequency 248 | save_frequency = args.save_frequency 249 | 250 | """ TRAIN """ 251 | 252 | model.train() 253 | gen = iter(train_dataloader) 254 | test_gen = iter(test_dataloader) 255 | bs = args.batchSize 256 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2)) 257 | y_onehot = torch.FloatTensor(bs, 20) 258 | chi_1_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS)) 259 | chi_2_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS)) 260 | chi_3_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS)) 261 | 262 | if use_cuda: 263 | output_atom, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot = map(lambda x: x.cuda(), [output_atom, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot]) 264 | 265 | for epoch in range(args.epochs): 266 | for it in tqdm(range(len(train_dataloader)), desc="training epoch %0.2d" % epoch): 267 | 268 | gen, out = step_iter(gen, train_dataloader) 269 | (bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles,) = out 270 | bs_i = len(bs_idx) 271 | output_atom.zero_() 272 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 # atom type 273 | X = output_atom[:, :c, 1:-1, 1:-1, 1:-1] 274 | 275 | X, y = X.float(), y.long() 276 | chi_angles = chi_angles.long() 277 | 278 | chi_1 = chi_angles[:, 0] 279 | chi_2 = chi_angles[:, 1] 280 | chi_3 = chi_angles[:, 2] 281 | chi_4 = chi_angles[:, 3] 282 | 283 | if use_cuda: 284 | y, y_onehot, chi_1, chi_2, chi_3, chi_4 = map(lambda x: x.cuda(), [y, y_onehot, chi_1, chi_2, chi_3, chi_4]) 285 | 286 | if bs_i < bs: 287 | y = F.pad(y, (0, bs - bs_i)) 288 | chi_1 = F.pad(chi_1, (0, bs - bs_i)) 289 | chi_2 = F.pad(chi_2, (0, bs - bs_i)) 290 | chi_3 = F.pad(chi_3, (0, bs - bs_i)) 291 | 292 | y_onehot.zero_() 293 | y_onehot.scatter_(1, y[:, None], 1) 294 | 295 | chi_1_onehot.zero_() 296 | chi_1_onehot.scatter_(1, chi_1[:, None], 1) 297 | 298 | chi_2_onehot.zero_() 299 | chi_2_onehot.scatter_(1, chi_2[:, None], 1) 300 | 301 | 
            chi_3_onehot.zero_()
            chi_3_onehot.scatter_(1, chi_3[:, None], 1)

            out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X[:bs_i], y_onehot[:bs_i], chi_1_onehot[:bs_i, 1:], chi_2_onehot[:bs_i, 1:], chi_3_onehot[:bs_i, 1:])
            res_loss = criterion(out, y[:bs_i])
            chi_1_loss = chi_1_criterion(chi_1_pred, chi_1[:bs_i] - 1)
            chi_2_loss = chi_2_criterion(chi_2_pred, chi_2[:bs_i] - 1)
            chi_3_loss = chi_3_criterion(chi_3_pred, chi_3[:bs_i] - 1)
            chi_4_loss = chi_4_criterion(chi_4_pred, chi_4[:bs_i] - 1)

            optimizer.zero_grad()  # clear gradients from the previous iteration
            train_loss = res_loss + chi_1_loss + chi_2_loss + chi_3_loss + chi_4_loss
            train_loss.backward()
            optimizer.step()

            # acc
            train_acc, _ = acc_util.get_acc(out, y[:bs_i], cm=None)
            train_top_k_acc = acc_util.get_top_k_acc(out, y[:bs_i], k=3)
            train_coarse_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_coarse)
            train_polar_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_polar)

            chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1[:bs_i] - 1, ignore_idx=-1)
            chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2[:bs_i] - 1, ignore_idx=-1)
            chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3[:bs_i] - 1, ignore_idx=-1)
            chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4[:bs_i] - 1, ignore_idx=-1)

            # tensorboard logging
            for name, val in zip(
                ["res_loss", "chi_1_loss", "chi_2_loss", "chi_3_loss", "chi_4_loss", "train_acc", "chi_1_acc", "chi_2_acc", "chi_3_acc", "chi_4_acc", "train_top3_acc", "train_coarse_acc", "train_polar_acc"],
                [res_loss.item(), chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item(), train_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, train_top_k_acc, train_coarse_acc, train_polar_acc],
            ):
                log.log_scalar("seq_chi_pred/%s" % name, val)

            if it % validation_frequency == 0 or it == len(train_dataloader) - 1:

                # overwrite the rolling "current" checkpoint at every validation step
                if it > 0:
                    if torch.cuda.device_count() > 1 and args.cuda:
                        torch.save(model.module.state_dict(), log.log_path + "/seq_chi_pred_baseline_curr_weights.pt")
                    else:
                        torch.save(model.state_dict(), log.log_path + "/seq_chi_pred_baseline_curr_weights.pt")
                    torch.save(optimizer.state_dict(), log.log_path + "/seq_chi_pred_baseline_curr_optimizer.pt")

                # NOTE -- additionally keep a permanent snapshot every save_frequency iterations and at the end of each epoch
                if it > 0 and (it % save_frequency == 0 or it == len(train_dataloader) - 1):
                    if torch.cuda.device_count() > 1 and args.cuda:
                        torch.save(model.module.state_dict(), log.log_path + "/seq_chi_pred_baseline_epoch_%0.3d_%s_weights.pt" % (epoch, it))
                    else:
                        torch.save(model.state_dict(), log.log_path + "/seq_chi_pred_baseline_epoch_%0.3d_%s_weights.pt" % (epoch, it))

                    torch.save(optimizer.state_dict(), log.log_path + "/seq_chi_pred_baseline_epoch_%0.3d_%s_optimizer.pt" % (epoch, it))

                # switch to eval mode for the test-set pass
                model.eval()
                # eval on the test set
                (test_gen, curr_test_loss, test_chi_1_loss, test_chi_2_loss, test_chi_3_loss, test_chi_4_loss, curr_test_acc, curr_test_top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc) = test(
                    model,
                    test_gen,
                    test_dataloader,
                    criterion,
                    chi_1_criterion,
                    chi_2_criterion,
                    chi_3_criterion,
                    chi_4_criterion,
                    max_it=len(test_dataloader),
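                    # n_iters (next argument) presumably caps how many test batches are
                    # averaged per validation pass -- at most 10 here -- while max_it above
                    # bounds the underlying test iterator; both are parameters of test()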
                    n_iters=min(10, len(test_dataloader)),
                    desc="test",
                    batch_size=args.batchSize,
                    use_cuda=use_cuda,
                )

                # tensorboard logging of the test metrics
                for name, val in zip(
                    [
                        "test_loss",
                        "test_chi_1_loss",
                        "test_chi_2_loss",
                        "test_chi_3_loss",
                        "test_chi_4_loss",
                        "test_acc",
                        "test_chi_1_acc",
                        "test_chi_2_acc",
                        "test_chi_3_acc",
                        "test_chi_4_acc",
                        "test_acc_top3",
                        "test_coarse_acc",
                        "test_polar_acc",
                    ],
                    [
                        curr_test_loss.item(),
                        test_chi_1_loss.item(),
                        test_chi_2_loss.item(),
                        test_chi_3_loss.item(),
                        test_chi_4_loss.item(),
                        curr_test_acc.item(),
                        chi_1_acc.item(),
                        chi_2_acc.item(),
                        chi_3_acc.item(),
                        chi_4_acc.item(),
                        curr_test_top_k_acc.item(),
                        coarse_acc.item(),
                        polar_acc.item(),
                    ],
                ):
                    log.log_scalar("seq_chi_pred/%s" % name, val)

                model.train()

            log.advance_iteration()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/txt/resfiles/NATRO_all.txt:
--------------------------------------------------------------------------------
1 | 1 - 90 NATRO # set all residues in 3mx7_gt.pdb to NATIVE ROTAMERS (skips design at these positions entirely)
--------------------------------------------------------------------------------
/txt/resfiles/PIKAA_all_one_AA.txt:
--------------------------------------------------------------------------------
1 | 1 - 90 PIKAA C
--------------------------------------------------------------------------------
/txt/resfiles/full_example.txt:
--------------------------------------------------------------------------------
1 | APOLAR
2 | start
3 | 65 POLAR
4 | 34 PIKAA EHKNRQDST
5 | 35 - 40 PIKAA EHK
6 | 41 - 45 POLAR
7 | 80 - 85 NATRO
8 | 90 NATRO
--------------------------------------------------------------------------------
/txt/resfiles/generate_resfile.py:
--------------------------------------------------------------------------------
# Optional FASTA input (commented out; the hard-coded sequence below is used instead):
#
# import sys
#
# if len(sys.argv) < 2:
#     print("Please provide a path to a .fasta file with a sequence")
#     sys.exit()
#
# def get_sequence(path):
#     """
#     Get an initial sequence for the Protein Sequence Design Algorithm from a FASTA file
#     """
#     with open(path, "r") as f:
#         lines = f.readlines()
#     return lines[1].strip()  # line 0 is the FASTA header, line 1 the sequence
#
# sequence = get_sequence("../../../sequenced_results/1bkr_gt_sequenced/init_seq.fasta")

sequence = "TMPSTYAFKLPIQTETGVARVRSVIKKVSLTLSAYQVDYLLNTATVTSPVAWADMVDGVQAAGVEIQYGQFF"

# write one TPIKAA line per residue, fixing the design to this sequence
with open("init_seq_1cc8_gt.txt", "w") as file1:
    for i, aa in enumerate(sequence, start=1):
        file1.write("{} TPIKAA {}\n".format(i, aa))
--------------------------------------------------------------------------------
/txt/resfiles/init_seq_1acf_gt.txt:
--------------------------------------------------------------------------------
1 | 1 TPIKAA A
2 | 2 TPIKAA R
3 | 3 TPIKAA E
4 | 4 TPIKAA T
5 | 5 TPIKAA W
6 | 6 TPIKAA V
7 | 7 TPIKAA D
8 | 8 TPIKAA D
9 | 9 TPIKAA L
10 | 10 TPIKAA M
11 | 11 TPIKAA C
12 | 12 TPIKAA S
13 | 13 TPIKAA T
14 | 14 TPIKAA G
15 | 15 TPIKAA A
16 | 16 TPIKAA V
17 | 17 TPIKAA R
18 | 18 TPIKAA K
19 | 19 TPIKAA C
20 | 20 TPIKAA A
21 | 21 TPIKAA L
22 | 22 TPIKAA V
23 | 23
TPIKAA G 24 | 24 TPIKAA P 25 | 25 TPIKAA A 26 | 26 TPIKAA G 27 | 27 TPIKAA N 28 | 28 TPIKAA V 29 | 29 TPIKAA Y 30 | 30 TPIKAA A 31 | 31 TPIKAA Q 32 | 32 TPIKAA A 33 | 33 TPIKAA P 34 | 34 TPIKAA G 35 | 35 TPIKAA Y 36 | 36 TPIKAA E 37 | 37 TPIKAA V 38 | 38 TPIKAA S 39 | 39 TPIKAA D 40 | 40 TPIKAA R 41 | 41 TPIKAA Q 42 | 42 TPIKAA G 43 | 43 TPIKAA E 44 | 44 TPIKAA L 45 | 45 TPIKAA V 46 | 46 TPIKAA A 47 | 47 TPIKAA D 48 | 48 TPIKAA G 49 | 49 TPIKAA L 50 | 50 TPIKAA K 51 | 51 TPIKAA K 52 | 52 TPIKAA P 53 | 53 TPIKAA R 54 | 54 TPIKAA G 55 | 55 TPIKAA V 56 | 56 TPIKAA S 57 | 57 TPIKAA S 58 | 58 TPIKAA S 59 | 59 TPIKAA T 60 | 60 TPIKAA F 61 | 61 TPIKAA G 62 | 62 TPIKAA L 63 | 63 TPIKAA D 64 | 64 TPIKAA G 65 | 65 TPIKAA M 66 | 66 TPIKAA R 67 | 67 TPIKAA F 68 | 68 TPIKAA D 69 | 69 TPIKAA V 70 | 70 TPIKAA L 71 | 71 TPIKAA D 72 | 72 TPIKAA T 73 | 73 TPIKAA S 74 | 74 TPIKAA D 75 | 75 TPIKAA R 76 | 76 TPIKAA S 77 | 77 TPIKAA L 78 | 78 TPIKAA F 79 | 79 TPIKAA A 80 | 80 TPIKAA N 81 | 81 TPIKAA L 82 | 82 TPIKAA D 83 | 83 TPIKAA L 84 | 84 TPIKAA H 85 | 85 TPIKAA G 86 | 86 TPIKAA V 87 | 87 TPIKAA L 88 | 88 TPIKAA C 89 | 89 TPIKAA V 90 | 90 TPIKAA F 91 | 91 TPIKAA T 92 | 92 TPIKAA L 93 | 93 TPIKAA K 94 | 94 TPIKAA S 95 | 95 TPIKAA I 96 | 96 TPIKAA I 97 | 97 TPIKAA V 98 | 98 TPIKAA G 99 | 99 TPIKAA S 100 | 100 TPIKAA L 101 | 101 TPIKAA S 102 | 102 TPIKAA G 103 | 103 TPIKAA D 104 | 104 TPIKAA M 105 | 105 TPIKAA A 106 | 106 TPIKAA A 107 | 107 TPIKAA A 108 | 108 TPIKAA M 109 | 109 TPIKAA A 110 | 110 TPIKAA A 111 | 111 TPIKAA Q 112 | 112 TPIKAA L 113 | 113 TPIKAA V 114 | 114 TPIKAA E 115 | 115 TPIKAA G 116 | 116 TPIKAA L 117 | 117 TPIKAA A 118 | 118 TPIKAA E 119 | 119 TPIKAA A 120 | 120 TPIKAA L 121 | 121 TPIKAA M 122 | 122 TPIKAA V 123 | 123 TPIKAA Y 124 | 124 TPIKAA G 125 | 125 TPIKAA E 126 | -------------------------------------------------------------------------------- /txt/resfiles/init_seq_1bkr_gt.txt: -------------------------------------------------------------------------------- 1 | 1 TPIKAA E 2 | 2 TPIKAA I 3 | 3 TPIKAA R 4 | 4 TPIKAA K 5 | 5 TPIKAA Q 6 | 6 TPIKAA R 7 | 7 TPIKAA F 8 | 8 TPIKAA F 9 | 9 TPIKAA D 10 | 10 TPIKAA F 11 | 11 TPIKAA C 12 | 12 TPIKAA R 13 | 13 TPIKAA K 14 | 14 TPIKAA V 15 | 15 TPIKAA T 16 | 16 TPIKAA A 17 | 17 TPIKAA G 18 | 18 TPIKAA W 19 | 19 TPIKAA Q 20 | 20 TPIKAA N 21 | 21 TPIKAA V 22 | 22 TPIKAA N 23 | 23 TPIKAA L 24 | 24 TPIKAA T 25 | 25 TPIKAA D 26 | 26 TPIKAA F 27 | 27 TPIKAA A 28 | 28 TPIKAA S 29 | 29 TPIKAA N 30 | 30 TPIKAA F 31 | 31 TPIKAA R 32 | 32 TPIKAA H 33 | 33 TPIKAA G 34 | 34 TPIKAA F 35 | 35 TPIKAA C 36 | 36 TPIKAA F 37 | 37 TPIKAA Q 38 | 38 TPIKAA A 39 | 39 TPIKAA L 40 | 40 TPIKAA I 41 | 41 TPIKAA Q 42 | 42 TPIKAA K 43 | 43 TPIKAA V 44 | 44 TPIKAA V 45 | 45 TPIKAA P 46 | 46 TPIKAA E 47 | 47 TPIKAA L 48 | 48 TPIKAA F 49 | 49 TPIKAA N 50 | 50 TPIKAA F 51 | 51 TPIKAA S 52 | 52 TPIKAA D 53 | 53 TPIKAA M 54 | 54 TPIKAA K 55 | 55 TPIKAA K 56 | 56 TPIKAA E 57 | 57 TPIKAA E 58 | 58 TPIKAA P 59 | 59 TPIKAA K 60 | 60 TPIKAA T 61 | 61 TPIKAA N 62 | 62 TPIKAA L 63 | 63 TPIKAA E 64 | 64 TPIKAA N 65 | 65 TPIKAA A 66 | 66 TPIKAA F 67 | 67 TPIKAA K 68 | 68 TPIKAA Y 69 | 69 TPIKAA A 70 | 70 TPIKAA Q 71 | 71 TPIKAA R 72 | 72 TPIKAA K 73 | 73 TPIKAA L 74 | 74 TPIKAA G 75 | 75 TPIKAA I 76 | 76 TPIKAA P 77 | 77 TPIKAA E 78 | 78 TPIKAA I 79 | 79 TPIKAA I 80 | 80 TPIKAA K 81 | 81 TPIKAA P 82 | 82 TPIKAA A 83 | 83 TPIKAA E 84 | 84 TPIKAA V 85 | 85 TPIKAA A 86 | 86 TPIKAA Q 87 | 87 TPIKAA E 88 | 88 TPIKAA G 89 | 89 TPIKAA P 90 | 90 TPIKAA S 91 | 91 TPIKAA E 92 | 92 TPIKAA A 93 | 93 
TPIKAA D 94 | 94 TPIKAA V 95 | 95 TPIKAA L 96 | 96 TPIKAA Q 97 | 97 TPIKAA W 98 | 98 TPIKAA V 99 | 99 TPIKAA M 100 | 100 TPIKAA T 101 | 101 TPIKAA F 102 | 102 TPIKAA L 103 | 103 TPIKAA Q 104 | 104 TPIKAA Y 105 | 105 TPIKAA L 106 | 106 TPIKAA A 107 | 107 TPIKAA S 108 | 108 TPIKAA M 109 | -------------------------------------------------------------------------------- /txt/resfiles/init_seq_1cc8_gt.txt: -------------------------------------------------------------------------------- 1 | 1 TPIKAA T 2 | 2 TPIKAA M 3 | 3 TPIKAA P 4 | 4 TPIKAA S 5 | 5 TPIKAA T 6 | 6 TPIKAA Y 7 | 7 TPIKAA A 8 | 8 TPIKAA F 9 | 9 TPIKAA K 10 | 10 TPIKAA L 11 | 11 TPIKAA P 12 | 12 TPIKAA I 13 | 13 TPIKAA Q 14 | 14 TPIKAA T 15 | 15 TPIKAA E 16 | 16 TPIKAA T 17 | 17 TPIKAA G 18 | 18 TPIKAA V 19 | 19 TPIKAA A 20 | 20 TPIKAA R 21 | 21 TPIKAA V 22 | 22 TPIKAA R 23 | 23 TPIKAA S 24 | 24 TPIKAA V 25 | 25 TPIKAA I 26 | 26 TPIKAA K 27 | 27 TPIKAA K 28 | 28 TPIKAA V 29 | 29 TPIKAA S 30 | 30 TPIKAA L 31 | 31 TPIKAA T 32 | 32 TPIKAA L 33 | 33 TPIKAA S 34 | 34 TPIKAA A 35 | 35 TPIKAA Y 36 | 36 TPIKAA Q 37 | 37 TPIKAA V 38 | 38 TPIKAA D 39 | 39 TPIKAA Y 40 | 40 TPIKAA L 41 | 41 TPIKAA L 42 | 42 TPIKAA N 43 | 43 TPIKAA T 44 | 44 TPIKAA A 45 | 45 TPIKAA T 46 | 46 TPIKAA V 47 | 47 TPIKAA T 48 | 48 TPIKAA S 49 | 49 TPIKAA P 50 | 50 TPIKAA V 51 | 51 TPIKAA A 52 | 52 TPIKAA W 53 | 53 TPIKAA A 54 | 54 TPIKAA D 55 | 55 TPIKAA M 56 | 56 TPIKAA V 57 | 57 TPIKAA D 58 | 58 TPIKAA G 59 | 59 TPIKAA V 60 | 60 TPIKAA Q 61 | 61 TPIKAA A 62 | 62 TPIKAA A 63 | 63 TPIKAA G 64 | 64 TPIKAA V 65 | 65 TPIKAA E 66 | 66 TPIKAA I 67 | 67 TPIKAA Q 68 | 68 TPIKAA Y 69 | 69 TPIKAA G 70 | 70 TPIKAA Q 71 | 71 TPIKAA F 72 | 72 TPIKAA F 73 | -------------------------------------------------------------------------------- /txt/resfiles/init_seq_3mx7_gt.txt: -------------------------------------------------------------------------------- 1 | 1 TPIKAA F 2 | 2 TPIKAA F 3 | 3 TPIKAA N 4 | 4 TPIKAA L 5 | 5 TPIKAA V 6 | 6 TPIKAA G 7 | 7 TPIKAA V 8 | 8 TPIKAA W 9 | 9 TPIKAA E 10 | 10 TPIKAA V 11 | 11 TPIKAA D 12 | 12 TPIKAA L 13 | 13 TPIKAA S 14 | 14 TPIKAA D 15 | 15 TPIKAA G 16 | 16 TPIKAA S 17 | 17 TPIKAA H 18 | 18 TPIKAA R 19 | 19 TPIKAA I 20 | 20 TPIKAA V 21 | 21 TPIKAA F 22 | 22 TPIKAA Q 23 | 23 TPIKAA E 24 | 24 TPIKAA E 25 | 25 TPIKAA E 26 | 26 TPIKAA A 27 | 27 TPIKAA A 28 | 28 TPIKAA G 29 | 29 TPIKAA R 30 | 30 TPIKAA R 31 | 31 TPIKAA S 32 | 32 TPIKAA I 33 | 33 TPIKAA Y 34 | 34 TPIKAA C 35 | 35 TPIKAA D 36 | 36 TPIKAA D 37 | 37 TPIKAA H 38 | 38 TPIKAA E 39 | 39 TPIKAA I 40 | 40 TPIKAA Y 41 | 41 TPIKAA R 42 | 42 TPIKAA Q 43 | 43 TPIKAA D 44 | 44 TPIKAA N 45 | 45 TPIKAA V 46 | 46 TPIKAA P 47 | 47 TPIKAA L 48 | 48 TPIKAA L 49 | 49 TPIKAA R 50 | 50 TPIKAA S 51 | 51 TPIKAA Y 52 | 52 TPIKAA Q 53 | 53 TPIKAA V 54 | 54 TPIKAA L 55 | 55 TPIKAA P 56 | 56 TPIKAA L 57 | 57 TPIKAA S 58 | 58 TPIKAA K 59 | 59 TPIKAA G 60 | 60 TPIKAA R 61 | 61 TPIKAA V 62 | 62 TPIKAA S 63 | 63 TPIKAA G 64 | 64 TPIKAA F 65 | 65 TPIKAA M 66 | 66 TPIKAA E 67 | 67 TPIKAA I 68 | 68 TPIKAA T 69 | 69 TPIKAA P 70 | 70 TPIKAA Q 71 | 71 TPIKAA K 72 | 72 TPIKAA A 73 | 73 TPIKAA G 74 | 74 TPIKAA D 75 | 75 TPIKAA Y 76 | 76 TPIKAA R 77 | 77 TPIKAA Y 78 | 78 TPIKAA S 79 | 79 TPIKAA F 80 | 80 TPIKAA C 81 | 81 TPIKAA I 82 | 82 TPIKAA N 83 | 83 TPIKAA G 84 | 84 TPIKAA Q 85 | 85 TPIKAA Q 86 | 86 TPIKAA R 87 | 87 TPIKAA I 88 | 88 TPIKAA I 89 | 89 TPIKAA G 90 | 90 TPIKAA K 91 | -------------------------------------------------------------------------------- /txt/resfiles/resfile_1acf_gt_ex8.txt: 
-------------------------------------------------------------------------------- 1 | 97 POLAR 2 | 89 POLAR 3 | 22 POLAR 4 | 79 POLAR 5 | -------------------------------------------------------------------------------- /txt/resfiles/resfile_1bkr_gt_ex6.txt: -------------------------------------------------------------------------------- 1 | 37 POLAR 2 | 98 POLAR 3 | 31 POLAR 4 | 40 POLAR 5 | -------------------------------------------------------------------------------- /txt/resfiles/resfile_3mx7_gt_ex1.txt: -------------------------------------------------------------------------------- 1 | 65 POLAR 2 | 21 POLAR 3 | 32 POLAR 4 | -------------------------------------------------------------------------------- /txt/resfiles/resfile_3mx7_gt_ex2.txt: -------------------------------------------------------------------------------- 1 | 65 POLAR 2 | 21 POLAR 3 | 32 POLAR 4 | 52 POLAR # previously APOLAR (V) 5 | 79 POLAR # previously APOLAR (Y) 6 | -------------------------------------------------------------------------------- /txt/resfiles/some_PIKAA_one.txt: -------------------------------------------------------------------------------- 1 | 34 - 36 PIKAA C 2 | 30 - 33 POLAR 3 | -------------------------------------------------------------------------------- /txt/resfiles/testing_TPIKAA_TNOTAA.txt: -------------------------------------------------------------------------------- 1 | 30 - 40 PIKAA CDAK 2 | 41 PIKAA A 3 | 41 TPIKAA K 4 | 45 TPIKAA D 5 | 46 TNOTAA HKRDESTNQAVLIMFYWPG # has to be C 6 | -------------------------------------------------------------------------------- /txt/test_idx.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | --------------------------------------------------------------------------------
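
Editorial note on the resfile format illustrated above: each line is either `IDX CMD [AAS]` or `START - END CMD [AAS]`; `#` begins a trailing comment, and a default-command block (e.g. `APOLAR`) may precede a `start` line, as in full_example.txt. The repository's actual parser lives in seq_des/util/resfile_util.py; the sketch below is only an editorial illustration of this grammar (the function name is hypothetical, not the repo's API):

```
# Illustrative resfile-line parser (editorial sketch; not seq_des/util/resfile_util.py).

def parse_resfile_line(line):
    """Return (first_res, last_res, command, amino_acids), or None for headers/blanks."""
    line = line.split("#", 1)[0].strip()  # drop trailing comments
    if not line:
        return None
    tokens = line.split()
    if not tokens[0].isdigit():  # default-command header, e.g. "APOLAR" then "start"
        return None
    if len(tokens) >= 4 and tokens[1] == "-":  # range form: "35 - 40 PIKAA EHK"
        first, last = int(tokens[0]), int(tokens[2])
        command = tokens[3]
        aas = tokens[4] if len(tokens) > 4 else ""
    else:  # single-residue form: "34 PIKAA EHKNRQDST"
        first = last = int(tokens[0])
        command = tokens[1]
        aas = tokens[2] if len(tokens) > 2 else ""
    return first, last, command, aas


if __name__ == "__main__":
    for raw in [
        "34 PIKAA EHKNRQDST",
        "35 - 40 PIKAA EHK",
        "80 - 85 NATRO",
        "46 TNOTAA HKRDESTNQAVLIMFYWPG  # has to be C",
    ]:
        print(parse_resfile_line(raw))
```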