├── LICENSE
├── README.md
├── SETUP.md
├── common
│   ├── atoms.py
│   ├── logger.py
│   └── run_manager.py
├── imgs
│   ├── ex2_results.png
│   ├── ex3_results.png
│   ├── ex5_results.png
│   ├── ex6_results.png
│   └── tim.gif
├── load_and_save_bb_coords.py
├── load_and_save_coords.py
├── pdbs
│   ├── 1acf_gt.pdb
│   ├── 1acf_gt_crelax.pdb
│   ├── 1bkr_gt.pdb
│   ├── 1bkr_gt_crelax.pdb
│   ├── 1cc8_gt.pdb
│   ├── 1cc8_gt_crelax.pdb
│   ├── 3mx7_gt.pdb
│   ├── 3mx7_gt_crelax.pdb
│   └── tim10.pdb
├── requirements.txt
├── run.py
├── seq_des
│   ├── __init__.py
│   ├── models.py
│   ├── sampler.py
│   └── util
│       ├── README.md
│       ├── __init__.py
│       ├── acc_util.py
│       ├── canonicalize.py
│       ├── data.py
│       ├── pyrosetta_util.py
│       ├── resfile_util.py
│       ├── sampler_util.py
│       └── voxelize.py
├── seq_des_info.pdf
├── train_autoreg_chi.py
├── train_autoreg_chi_baseline.py
└── txt
    ├── resfiles
    │   ├── NATRO_all.txt
    │   ├── PIKAA_all_one_AA.txt
    │   ├── full_example.txt
    │   ├── generate_resfile.py
    │   ├── init_seq_1acf_gt.txt
    │   ├── init_seq_1bkr_gt.txt
    │   ├── init_seq_1cc8_gt.txt
    │   ├── init_seq_3mx7_gt.txt
    │   ├── resfile_1acf_gt_ex8.txt
    │   ├── resfile_1bkr_gt_ex6.txt
    │   ├── resfile_3mx7_gt_ex1.txt
    │   ├── resfile_3mx7_gt_ex2.txt
    │   ├── some_PIKAA_one.txt
    │   └── testing_TPIKAA_TNOTAA.txt
    ├── test_domains_s95.txt
    ├── test_idx.txt
    └── train_domains_s95.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020 The Leland Stanford Junior University (Stanford University), Namrata Anand-Achim.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Protein sequence design with a learned potential
2 |
3 | Code for the algorithm in our paper
4 |
5 | > Namrata Anand-Achim, Raphael R. Eguchi, Alexander Derry, Russ B. Altman, and Possu Huang. "Protein sequence design with a learned potential." bioRxiv (2020).
6 | > [[biorxiv]](https://www.biorxiv.org/content/10.1101/2020.01.06.895466v1) [[cite]](#citation)
7 |
8 | 
9 |
10 | An entirely AI-designed, four-fold symmetric TIM barrel
11 |
12 | ## Requirements
13 |
14 | * Python 3
15 | * [PyTorch](https://pytorch.org)
16 | * [PyRosetta4](http://www.pyrosetta.org/dow)
17 | * Python packages in requirements.txt
18 | * Download pretrained models [here](https://drive.google.com/file/d/1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt/view?usp=sharing)
19 |
20 | See [here](https://github.com/nanand2/protein_seq_des/blob/master/SETUP.md) for set-up instructions on Ubuntu 18.04 with Miniconda, Python 3.7, PyTorch 1.1.0, CUDA 9.0.
21 |
22 |
23 | ## Design
24 |
25 | If you'd like to use the pre-trained models to run design, jump to [[this section]](#running-design)
26 |
27 | ## Generating data
28 | Data is available [here](https://drive.google.com/drive/folders/1MD-tu32SoYtZGag04HwntuxcuOnYPDXs). See the README in the drive for more information about the uploaded files. The coordinates above were generated from the domain IDs listed in the .txt files (see txt/train_domains_s95.txt and txt/test_domains_s95.txt); these files are the inputs for regenerating the dataset. If you don't have the PDB files downloaded, the script will download them and save them to pdb_dir.
29 |
30 | If you'd like to generate the dataset or change the underlying data, run the following commands.
31 |
32 | To load and save coordinates for the backbone (BB) only model:
33 | ```
34 | python load_and_save_bb_coords.py --save_dir PATH_TO_SAVE_DATA --pdb_dir PATH_TO_PDB_FILES --log_dir PATH_TO_LOG_DIR --txt PATH_TO_DOMAIN_TXT_FILE
35 | ```
36 |
37 | To load and save coordinates for the main model:
38 | ```
39 | python load_and_save_coords.py --save_dir PATH_TO_SAVE_DATA --pdb_dir PATH_TO_PDB_FILES --log_dir PATH_TO_LOG_DIR --txt PATH_TO_DOMAIN_TXT_FILE
40 | ```
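Both loading scripts write shuffled chunks of up to `--chunk_size` examples with `torch.save`, named `data_0000.pt`, `data_0001.pt`, and so on under `--save_dir`. A quick way to inspect one chunk (a sketch, assuming the default `--save_dir ./coords`):
```
import torch

# each chunk is a tuple of aligned NumPy arrays:
# (atom coords, atom data, residue labels, domain ids, chi angles)
coords, atom_data, res_labels, domain_ids, chis = torch.load("coords/data_0000.pt")
print(coords.shape, res_labels.shape, len(domain_ids))
```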
41 |
42 | ## Training the models
43 |
44 | Pretrained models are available [here](https://drive.google.com/file/d/1cHoyeI0H_Jo9bqgFH4z0dfx2s9as9Jp1/view?usp=sharing) but you can also use the available scripts to train from scratch.
45 |
46 | To train the baseline model -- residue and autoregressive rotamer prediction conditioned on backbone (BB) atoms only model (no side-chains):
47 | ```
48 | python train_autoreg_chi_baseline.py --batchSize 4096 --workers 12 --lr 1.5e-4 --validation_frequency 100 --save_frequency 1000 --log_dir PATH_TO_LOG_DIR --data_dir PATH_TO_DATA
49 | ```
50 |
51 | To train the main model -- residue and autoregressive rotamer prediction conditioned on neighboring side-chains:
52 | ```
53 | python train_autoreg_chi.py --batchSize 2048 --workers 12 --lr 7.5e-5 --validation_frequency 200 --save_frequency 2000 --log_dir PATH_TO_LOG_DIR --data_dir PATH_TO_DATA
54 | ```
55 | Note that training was originally done across 8 V100 GPUs using PyTorch's DataParallel mode.
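For reference, a minimal sketch of the multi-GPU wrapping (the channel count mirrors the conditional model setup in `run.py`; checkpoints saved this way carry `module.`-prefixed keys, which `run.py` detects at load time):
```
import torch
import torch.nn as nn
import seq_des.models as models
import common.atoms

nic = len(common.atoms.atoms) + 1 + 21  # conditional model input channels (see run.py)
model = models.seqPred(nic=nic)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.cuda()
```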
56 |
57 |
58 |
59 | ## Running design
60 |
61 | To run a design trajectory, specify starting backbone with an input PDB.
62 |
63 | ```
64 | python run.py --pdb pdbs/3mx7_gt.pdb
65 | ```
66 |
67 | To run a rotamer repacking trajectory with the model, specify the repack-only option:
68 | ```
69 | python run.py --pdb pdbs/3mx7_gt.pdb --repack_only 1
70 | ```
71 |
72 | To enforce k-fold symmetry during design or packing, use the symmetry options:
73 | ```
74 | python run.py --pdb pdbs/tim10.pdb --symmetry 1 --k 4 [--repack_only 1]
75 | ```
76 |
77 | To constrain a subset of positions to remain fixed, point to a txt file with the fixed residue indices, for example:
78 | ```
79 | python run.py --pdb pdbs/tim10.pdb --fixed_idx txt/test_idx.txt
80 | ```
81 |
82 | To constrain a subset of positions to be designed, keeping all others fixed, point to a txt file with the variable residue indices, for example:
83 | ```
84 | python run.py --pdb pdbs/tim10.pdb --var_idx txt/test_idx.txt
85 | ```
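Here `txt/test_idx.txt` is a plain text file of 0-indexed pose positions, parsed by `seq_des/util/sampler_util.py`. An illustrative file (assuming whitespace-separated indices) might look like:
```
0 1 2 10 11 12
```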
86 |
87 | See [below](#design-parameters) for additional design parameters.
88 |
89 | ## Monitoring metrics
90 | Design metrics can be monitored using TensorBoard:
91 |
92 | ```
93 | tensorboard --logdir ./logs
94 | ```
95 |
96 | Note that the input PDB sequence and rotamers are considered 'ground-truth' for sequence and rotamer recovery metrics.
97 |
98 |
99 |
100 | ## Design parameters
101 |
102 | * Design inputs
103 | ```
104 | --pdb Path to input PDB
105 | --model_list Paths to conditional models. (Default: ['models/conditional_model_0.pt',
106 |                     'models/conditional_model_1.pt', 'models/conditional_model_2.pt',
107 |                     'models/conditional_model_3.pt'])
108 | --init_model Path to baseline model for sequence initialization.
109 | (Default: 'models/baseline_model.pt')
110 | ```
111 | * Saving / logging
112 | ```
113 | --log_dir Path to desired output log folder for designed
114 | structures. (Default: ./logs)
115 | --seed Random seed. Design runs are non-deterministic.
116 | (Default: 2)
117 | --save_rate How often to save intermediate designed structures
118 | (Default: 10)
119 |
120 | ```
121 | * Sequence initialization
122 | ```
123 | --randomize {0,1} Randomize starting sequence/rotamers for design.
124 | Toggle to 0 to keep starting sequence and rotamers.
125 | (Default: 1)
126 | --no_init_model {0,1} Do not use baseline model to predict initial sequence/rotamers.
127 | (Default: 0)
128 | --ala {0,1} Initialize sequence with poly-alanine. (Default: 0)
129 | --val {0,1} Initialize sequence with poly-valine. (Default: 0)
130 | ```
131 | * Rotamer repacking parameters
132 | ```
133 | --repack_only {0,1} Only run rotamer repacking. (Default: 0)
134 | --use_rosetta_packer {0,1}
135 | Use the Rosetta packer instead of the model for
136 | rotamer repacking during design. If in symmetry
137 | mode, rotamers are not packed symmetrically. (Default: 0)
138 | --pack_radius Radius in angstroms for Rosetta rotamer packing after
139 | residue mutation. Must set --use_rosetta_packer 1
140 |                       (Default: 5.0)
141 | ```
142 | * Design parameters
143 | ```
144 | --symmetry {0,1} Enforce symmetry during design (Default: 0)
145 | --k Enforce k-fold symmetry. Input pose length must be
146 | divisible by k. Requires --symmetry 1 (Default: 4)
147 | --restrict_gly {0,1} Enforce no glycines for non-loop backbone positions
148 | based on DSSP assignment. (Default: 1)
149 | --no_cys {0,1} Enforce no cysteines in design (Default: 0)
150 | --no_met {0,1} Enforce no methionines in design (Default: 0)
151 | --var_idx Path to txt file listing pose indices that should be
152 | designed/packed, all other side-chains will remain
153 | fixed. Cannot be specified if fixed_idx file given.
154 | Not supported with symmetry mode. 0-indexed
155 | --fixed_idx Path to txt file listing pose indices that should NOT
156 | be designed/packed, all other side-chains will be
157 | designed/packed. Cannot be specified if var_idx file given.
158 | Not supported with symmetry mode. 0-indexed
159 | --resfile Enforce resfile on particular residues. 0-indexed
160 | ```
161 |
162 | Learn more about the resfile format [here](https://github.com/ProteinDesignLab/protein_seq_des/tree/master/seq_des/util); an illustrative example follows.
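A small hypothetical resfile (a sketch following the Rosetta resfile conventions, with this repo's 0-indexed positions; see `txt/resfiles/` for the shipped examples):
```
ALLAA
start
0 A PIKAA H
5 A POLAR
9 A NATRO
```
Here all positions default to `ALLAA` (any amino acid), position 0 on chain A is restricted to histidine (`PIKAA H`), position 5 to polar residues, and position 9 keeps its native residue and rotamer (`NATRO`).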
163 |
164 | * Sampling / optimization parameters
165 | ```
166 | --anneal {0,1} Option to do simulated annealing of average negative
167 | model pseudo-log-likelihood. Toggle to 0 to do vanilla
168 | blocked sampling (Default: 1)
169 | --step_rate Multiplicative step rate for simulated annealing (Default: 0.995)
170 | --anneal_start_temp Starting temperature for simulated annealing (Default: 1)
171 | --anneal_final_temp Final temperature for simulated annealing (Default: 0)
172 | --n_iters Total number of iterations (Default: 2500)
173 | --threshold Threshold in angstroms for defining conditionally
174 | independent residues for blocked sampling (should be
175 | greater than ~17.3) (Default: 20)
176 | ```
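With the defaults above, the sampling temperature decays multiplicatively each iteration. A short sketch of the implied schedule (an illustration, assuming the temperature is floored at `--anneal_final_temp`):
```
temp = 1.0                          # --anneal_start_temp
for i in range(2500):               # --n_iters
    # ... blocked sampling at the current temperature ...
    temp = max(0.0, temp * 0.995)   # --anneal_final_temp, --step_rate
# 0.995**2500 is roughly 3.6e-6, so sampling is close to greedy by the end
```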
177 |
178 | Additional information
179 | * Code expects single chain PDB input.
180 | * Specifying fixed/variable indices not currently supported in symmetry mode.
181 | * Model rotamer packing in symmetry mode does symmetric rotamer packing, but using the Rosetta packer does not.
182 |
183 | ## Citation
184 | If you find our work relevant to your research, please cite:
185 | ```
186 | @article{anand2020protein,
187 | title={Protein sequence design with a learned potential},
188 | author={Anand, Namrata and Eguchi, Raphael Ryuichi and Derry, Alexander and Altman, Russ B and Huang, Possu},
189 | journal={bioRxiv},
190 | year={2020},
191 | publisher={Cold Spring Harbor Laboratory}
192 | }
193 | ```
194 |
--------------------------------------------------------------------------------
/SETUP.md:
--------------------------------------------------------------------------------
1 |
2 | ## Setup
3 |
4 | Instructions for setup on Ubuntu 18.04 with Miniconda and Python 3.7
5 |
6 | * Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html)
7 | * Create a conda env for the project
8 | ```
9 | conda create -y -n seq_des python=3.7 anaconda
10 | conda activate seq_des
11 | ```
12 | * Install [PyRosetta4](http://www.pyrosetta.org/dow) via conda
13 | * Get a license for PyRosetta
14 | * Add this to ~/.condarc
15 | ```
16 | channels:
17 | - https://USERNAME:PASSWORD@conda.graylab.jhu.edu
18 | - conda-forge
19 | - defaults
20 | ```
21 | * Install PyRosetta
22 | ```
23 | conda install pyrosetta
24 | ```
25 | * Install PyTorch 1.1.0 with CUDA 9.0
26 | ```
27 | conda install -y pytorch=1.1.0 torchvision cudatoolkit=9.0 -c pytorch
28 | ```
29 | * Clone this repo
30 | ```
31 | git clone https://github.com/nanand2/protein_seq_des.git
32 | ```
33 | * Install Python packages
34 | ```
35 | cd protein_seq_des
36 | pip install -r requirements.txt
37 | ```
38 |
39 | * Download the [pretrained models](https://drive.google.com/file/d/1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt/view?usp=sharing) to the current directory, then unzip them
40 |
41 | ```
42 | unzip models.zip
43 | ```
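* (Optional) Sanity-check the environment from a Python shell (a quick illustrative check)
```
import torch
import pyrosetta

pyrosetta.init()
print("CUDA available:", torch.cuda.is_available())
```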
44 |
45 |
--------------------------------------------------------------------------------
/common/atoms.py:
--------------------------------------------------------------------------------
1 | import string
2 |
3 | letters = string.ascii_uppercase
4 |
5 | rename_chains = {i: letters[i] for i in range(26)} # NOTE -- expect errors if you have more than 26 structures
6 | skip_res_list = [
7 | "HOH",
8 | "GOL",
9 | "EDO",
10 | "SO4",
11 | "EDO",
12 | "NAG",
13 | "PO4",
14 | "ACT",
15 | "PEG",
16 | "MAN",
17 | "BMA",
18 | "DMS",
19 | "MPD",
20 | "MES",
21 | "PG4",
22 | "TRS",
23 | "FMT",
24 | "PGE",
25 | "EPE",
26 | "NO3",
27 | "UNX",
28 | "UNL",
29 | "UNK",
30 | "IPA",
31 | "IMD",
32 | "GLC",
33 | "MLI",
34 | "1PE",
35 | "NO3",
36 | "SCN",
37 | "P6G",
38 | "OXY",
39 | "EOH",
40 | "NH4",
41 | "DTT",
42 | "BEN",
43 | "BCT",
44 | "FUL",
45 | "AZI",
46 | "DOD",
47 | "OH",
48 | "CYN",
49 | "NO",
50 | "NO2",
51 | "SO3",
52 | "H2S",
53 | "MOH",
54 | "URE",
55 | "CO2",
56 | "2NO",
57 | ] # top ions, sugars, small molecules with N/C/O/S/P that appear to be crystal artifacts or common surface bound co-factors -- to ignore
58 | skip_atoms = ["H", "D"]
59 | atoms = ["N", "C", "O", "S", "P", "other"]
60 | aa = [
61 | "ALA",
62 | "ARG",
63 | "ASN",
64 | "ASP",
65 | "CYS",
66 | "GLN",
67 | "GLU",
68 | "GLY",
69 | "HIS",
70 | "ILE",
71 | "LEU",
72 | "LYS",
73 | "MET",
74 | "PHE",
75 | "PRO",
76 | "SER",
77 | "THR",
78 | "TRP",
79 | "TYR",
80 | "VAL",
81 | "MSE",
82 | ]
83 | res_label_dict = {
84 | "HIS": 0,
85 | "LYS": 1,
86 | "ARG": 2,
87 | "ASP": 3,
88 | "GLU": 4,
89 | "SER": 5,
90 | "THR": 6,
91 | "ASN": 7,
92 | "GLN": 8,
93 | "ALA": 9,
94 | "VAL": 10,
95 | "LEU": 11,
96 | "ILE": 12,
97 | "MET": 13,
98 | "PHE": 14,
99 | "TYR": 15,
100 | "TRP": 16,
101 | "PRO": 17,
102 | "GLY": 18,
103 | "CYS": 19,
104 | "MSE": 13,
105 | } # MSE -- same as MET
106 | label_res_dict = {
107 | 0: "HIS",
108 | 1: "LYS",
109 | 2: "ARG",
110 | 3: "ASP",
111 | 4: "GLU",
112 | 5: "SER",
113 | 6: "THR",
114 | 7: "ASN",
115 | 8: "GLN",
116 | 9: "ALA",
117 | 10: "VAL",
118 | 11: "LEU",
119 | 12: "ILE",
120 | 13: "MET",
121 | 14: "PHE",
122 | 15: "TYR",
123 | 16: "TRP",
124 | 17: "PRO",
125 | 18: "GLY",
126 | 19: "CYS",
127 | } # , 20:'MSE'}
128 |
129 | chi_dict = {
130 | "ARG": {"chi_1": "CG", "chi_2": "CD", "chi_3": "NE", "chi_4": "CZ"},
131 | "LYS": {"chi_1": "CG", "chi_2": "CD", "chi_3": "CE", "chi_4": "NZ"},
132 | "GLN": {"chi_1": "CG", "chi_2": "CD", "chi_3": "OE1"},
133 | "GLU": {"chi_1": "CG", "chi_2": "CD", "chi_3": "OE1"},
134 | "MET": {"chi_1": "CG", "chi_2": "SD", "chi_3": "CE"},
135 | "ASP": {"chi_1": "CG", "chi_2": "OD1"},
136 | "ILE": {"chi_1": "CG1", "chi_2": "CD1"},
137 | "HIS": {"chi_1": "CG", "chi_2": "ND1"},
138 | "LEU": {"chi_1": "CG", "chi_2": "CD1"},
139 | "ASN": {"chi_1": "CG", "chi_2": "OD1"},
140 | "PHE": {"chi_1": "CG", "chi_2": "CD1"},
141 | "PRO": {"chi_1": "CG", "chi_2": "CD"},
142 | "TRP": {"chi_1": "CG", "chi_2": "CD1"},
143 | "TYR": {"chi_1": "CG", "chi_2": "CD1"},
144 | "VAL": {"chi_1": "CG1"},
145 | "THR": {"chi_1": "OG1"},
146 | "SER": {"chi_1": "OG"},
147 | "CYS": {"chi_1": "SG"},
148 | "GLY": {},
149 | "ALA": {},
150 | }
151 |
152 |
153 | chi_dict_old = {
154 | "ARG": 4,
155 | "LYS": 4,
156 | "GLN": 3,
157 | "GLU": 3,
158 | "MET": 3,
159 | "ASP": 2,
160 | "ILE": 2,
161 | "HIS": 2,
162 | "LEU": 2,
163 | "ASN": 2,
164 | "PHE": 2,
165 | "PRO": 3,
166 | "TRP": 2,
167 | "TYR": 2,
168 | "VAL": 1,
169 | "THR": 1,
170 | "SER": 1,
171 | "CYS": 1,
172 | "GLY": 0,
173 | "ALA": 0,
174 | }
175 | aa_map = {
176 | 0: "H",
177 | 1: "K",
178 | 2: "R",
179 | 3: "D",
180 | 4: "E",
181 | 5: "S",
182 | 6: "T",
183 | 7: "N",
184 | 8: "Q",
185 | 9: "A",
186 | 10: "V",
187 | 11: "L",
188 | 12: "I",
189 | 13: "M",
190 | 14: "F",
191 | 15: "Y",
192 | 16: "W",
193 | 17: "P",
194 | 18: "G",
195 | 19: "C",
196 | } # , 20: "M"} # caution methionine in place of MSE
197 | aa_inv = {
198 | "H": "HIS",
199 | "K": "LYS",
200 | "R": "ARG",
201 | "D": "ASP",
202 | "E": "GLU",
203 | "S": "SER",
204 | "T": "THR",
205 | "N": "ASN",
206 | "Q": "GLN",
207 | "A": "ALA",
208 | "V": "VAL",
209 | "L": "LEU",
210 | "I": "ILE",
211 | "M": "MET",
212 | "F": "PHE",
213 | "Y": "TYR",
214 | "W": "TRP",
215 | "P": "PRO",
216 | "G": "GLY",
217 | "C": "CYS",
218 | }
219 | aa_map_inv = {
220 | "H": 0,
221 | "K": 1,
222 | "R": 2,
223 | "D": 3,
224 | "E": 4,
225 | "S": 5,
226 | "T": 6,
227 | "N": 7,
228 | "Q": 8,
229 | "A": 9,
230 | "V": 10,
231 | "L": 11,
232 | "I": 12,
233 | "M": 13,
234 | "F": 14,
235 | "Y": 15,
236 | "W": 16,
237 | "P": 17,
238 | "G": 18,
239 | "C": 19,
240 | }
241 | aa_to_letter = {aa_inv[k]: k for k in aa_inv.keys()}
242 | label_res_single_dict = {
243 | 0: "H",
244 | 1: "K",
245 | 2: "R",
246 | 3: "D",
247 | 4: "E",
248 | 5: "S",
249 | 6: "T",
250 | 7: "N",
251 | 8: "Q",
252 | 9: "A",
253 | 10: "V",
254 | 11: "L",
255 | 12: "I",
256 | 13: "M",
257 | 14: "F",
258 | 15: "Y",
259 | 16: "W",
260 | 17: "P",
261 | 18: "G",
262 | 19: "C",
263 | }
264 | # resfile commands where values are amino acids allowed by that command
265 | resfile_commands = {
266 | "ALLAA": {'H', 'K', 'R', 'D', 'E', 'S', 'T', 'N', 'Q', 'A', 'V', 'L', 'I', 'M', 'F', 'Y', 'W', 'P', 'G', 'C'},
267 | "ALLAAwc": {'H', 'K', 'R', 'D', 'E', 'S', 'T', 'N', 'Q', 'A', 'V', 'L', 'I', 'M', 'F', 'Y', 'W', 'P', 'G', 'C'},
268 | "ALLAAxc": {'H', 'K', 'R', 'D', 'E', 'S', 'T', 'N', 'Q', 'A', 'V', 'L', 'I', 'M', 'F', 'Y', 'W', 'P', 'G'},
269 | "POLAR": {'E', 'H', 'K', 'N', 'R', 'Q', 'D', 'S', 'T'},
270 | "APOLAR": {'P', 'M', 'Y', 'V', 'F', 'L', 'I', 'A', 'C', 'W', 'G'},
271 | }
272 |
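# Illustrative round trips over the tables above (not executed):
#   res_label_dict["TRP"] == 16; label_res_dict[16] == "TRP"
#   aa_map[16] == "W"; aa_map_inv["W"] == 16; aa_inv["W"] == "TRP"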
--------------------------------------------------------------------------------
/common/logger.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | import numpy as np
3 | import os
4 | import datetime
5 | import subprocess
6 |
7 |
8 | class Logger(object):
9 | def __init__(self, log_dir="./logs", dummy=False, prefix="", suffix="", full_log_dir=None, rank=0):
10 | self.suffix = suffix
11 | self.prefix = prefix
12 | self.dummy = dummy
13 | if self.dummy:
14 | return
15 |
16 | self.iteration = 1
17 |
18 | if log_dir == "":
19 | log_dir = "./logs"
20 | if full_log_dir is None or log_dir == "":
21 | now = datetime.datetime.now()
22 | self.ts = now.strftime("%Y-%m-%d-%H-%M-%S")
23 | log_path = os.path.join(log_dir, self.prefix + self.ts + self.suffix)
24 | else:
25 | log_path = full_log_dir
26 |
27 | self.log_path = log_path
28 |         # SummaryWriter creates the log directory if it does not already exist
29 |         self.writer = SummaryWriter(log_dir=log_path)
30 | self.kvs = {}
31 |
32 | print(("Logging to", log_path))
33 |
34 | def log_args(self, args):
35 | with open("%s/args.txt" % self.log_path, "w") as f:
36 | for arg in vars(args):
37 | f.write("%s\t%s\n" % (arg, getattr(args, arg)))
38 |
39 | def advance_iteration(self):
40 | self.iteration += 1
41 |
42 | def reset_iteration(self):
43 | self.iteration = 0
44 |
45 | def log_scalar(self, name, value):
46 | if self.dummy:
47 | return
48 |
49 | if isinstance(value, list):
50 | assert len(value) == 1, (name, len(value), value)
51 | return self.log_scalar(name, value[0])
52 | try:
53 | self.writer.add_scalar(name, value, self.iteration)
54 | except Exception as e:
55 | print(("Failed on", name, value, type(value)))
56 | raise
57 |
58 | def log_kvs(self, **kwargs):
59 | if self.dummy:
60 | return
61 |
62 | for k, v in kwargs.items():
63 | assert isinstance(k, str)
64 | self.kvs[k] = v
65 |
66 | kv_strings = ["%s=%s" % (k, v) for k, v in sorted(self.kvs.items())]
67 |         val = "\n".join(kv_strings)
68 | self.writer.add_text("properties", val, global_step=self.iteration)
69 |
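# Minimal usage sketch (illustrative):
#   log = Logger(log_dir="./logs", prefix="design_")
#   log.log_scalar("sampler/rosetta_energy", -123.4)
#   log.log_kvs(seed=2, pdb="pdbs/3mx7_gt.pdb")
#   log.advance_iteration()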
--------------------------------------------------------------------------------
/common/run_manager.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import time
4 | import numpy as np
5 | from . import logger
6 | import torch
7 | import os
8 |
9 |
10 | class RunManager(object):
11 | def __init__(self):
12 |
13 | self.parser = argparse.ArgumentParser()
14 |
15 | self.parser.add_argument("--workers", type=int, help="number of data loading workers", default=0)
16 | self.parser.add_argument("--cuda", type=int, default=1, help="enables cuda")
17 |
18 | # training parameters
19 | self.parser.add_argument("--batchSize", type=int, default=64, help="input batch size")
20 | self.parser.add_argument("--ngpu", type=int, default=1, help="num gpus to parallelize over")
21 |
22 | self.parser.add_argument("--nf", type=int, default=64, help="base number of filters")
23 | self.parser.add_argument("--txt", type=str, default="txt/train_domains_s95.txt", help="default txt input file")
24 |
25 |         self.parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
26 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
27 | self.parser.add_argument("--reg", type=float, default=5e-6, help="L2 regularization")
28 | self.parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5")
29 | self.parser.add_argument("--momentum", type=float, default=0.01, help="momentum for batch norm")
30 |
31 | self.parser.add_argument(
32 | "--model", type=str, default="", help="path to saved pretrained model for resuming training",
33 | )
34 | self.parser.add_argument("--optimizer", type=str, default="", help="path to saved optimizer params")
35 | self.parser.add_argument(
36 | "--validation_frequency", type=int, default=500, help="how often to validate during training",
37 | )
38 | self.parser.add_argument("--save_frequency", type=int, default=2000, help="how often to save models")
39 | self.parser.add_argument("--sync_frequency", type=int, default=1000, help="how often to sync to GCP")
40 |
41 | self.parser.add_argument(
42 |             "--num_return", type=int, default=400, help="number of nearest non-side-chain atoms to return per voxel",
43 | )
44 | self.parser.add_argument("--chunk_size", type=int, default=10000, help="chunk size for saved coordinate tensors")
45 |
46 | self.parser.add_argument("--data_dir", type=str, default="/data/simdev_2tb/protein/sequence_design/data/coords")
47 | self.parser.add_argument("--pdb_dir", type=str, default="/data/drive2tb/protein/pdb")
48 | self.parser.add_argument("--save_dir", type=str, default="./coords")
49 |
50 | # design inputs
51 |         self.parser.add_argument("--file_dir", type=str, default="run", help="folder to store files (must be specified)")
52 | self.parser.add_argument("--pdb", type=str, default="pdbs/tim10.pdb", help="Input PDB")
53 | self.parser.add_argument(
54 | "--model_list",
55 | "--list",
56 | default=[
57 | "models/conditional_model_0.pt",
58 | "models/conditional_model_1.pt",
59 | "models/conditional_model_2.pt",
60 | "models/conditional_model_3.pt",
61 | ],
62 | nargs="+",
63 | help="Paths to conditional models",
64 | )
65 | self.parser.add_argument(
66 | "--init_model", type=str, default="models/baseline_model.pt", help="Path to baseline model (conditioned on backbone atoms only)",
67 | )
68 |
69 | # saving / logging
70 | self.parser.add_argument(
71 | "--log_dir", type=str, default="./logs", help="Path to desired output log folder for designed structures",
72 | )
73 | self.parser.add_argument("--seed", default=2, type=int, help="Random seed. Design runs are non-deterministic.")
74 | self.parser.add_argument(
75 | "--save_rate", type=int, default=10, help="How often to save intermediate designed structures",
76 | )
77 |
78 | # design parameters
79 | self.parser.add_argument(
80 |             "--no_init_model", type=int, default=0, choices=(0, 1), help="Do not use baseline model to initialize sequence/rotamers.",
81 | )
82 | self.parser.add_argument(
83 | "--randomize",
84 | type=int,
85 | default=1,
86 | choices=(0, 1),
87 | help="Randomize starting sequence/rotamers for design. Toggle OFF to keep starting sequence and rotamers",
88 | )
89 | self.parser.add_argument(
90 | "--repack_only", type=int, default=0, choices=(0, 1), help="Only run rotamer repacking (no design, keep sequence fixed)",
91 | )
92 | self.parser.add_argument(
93 | "--use_rosetta_packer",
94 | type=int,
95 | default=0,
96 | choices=(0, 1),
97 | help="Use the Rosetta packer instead of the model for rotamer repacking during design",
98 | )
99 | self.parser.add_argument(
100 | "--threshold",
101 | type=float,
102 | default=20,
103 | help="Threshold in angstroms for defining conditionally independent residues for blocked sampling (should be greater than ~17.3)",
104 | )
105 | self.parser.add_argument("--symmetry", type=int, default=0, choices=(0, 1), help="Enforce symmetry during design")
106 | self.parser.add_argument(
107 | "--k", type=int, default=4, help="Enforce k-fold symmetry. Input pose length must be divisible by k. Requires --symmetry 1",
108 | )
109 | self.parser.add_argument(
110 | "--ala", type=int, default=0, choices=(0, 1), help="Initialize sequence with poly-alanine",
111 | )
112 | self.parser.add_argument(
113 | "--val", type=int, default=0, choices=(0, 1), help="Initialize sequence with poly-valine",
114 | )
115 | self.parser.add_argument(
116 |             "--restrict_gly", type=int, default=1, choices=(0, 1), help="Disallow glycines at non-loop positions (based on DSSP assignment)",
117 | )
118 | self.parser.add_argument("--no_cys", type=int, default=0, choices=(0, 1), help="Enforce no cysteines in design")
119 | self.parser.add_argument("--no_met", type=int, default=0, choices=(0, 1), help="Enforce no methionines in design")
120 | self.parser.add_argument(
121 | "--pack_radius",
122 | type=float,
123 | default=5.0,
124 | help="Rosetta packer radius for rotamer packing after residue mutation. Must set --use_rosetta_packer 1.",
125 | )
126 | self.parser.add_argument(
127 | "--var_idx",
128 | type=str,
129 | default="",
130 | help="Path to txt file listing pose indices that should be designed/packed, all other side-chains will remain fixed. 0-indexed",
131 | )
132 | self.parser.add_argument(
133 | "--fixed_idx",
134 | type=str,
135 | default="",
136 | help="Path to txt file listing pose indices that should NOT be designed/packed, all other side-chains will be designed/packed. 0-indexed",
137 | )
138 |
139 | self.parser.add_argument("--resfile", type=str, default="", help="Specify path to a resfile to enforce constraints on particular residues")
140 |
141 | # optimization / sampling parameters
142 | self.parser.add_argument(
143 | "--anneal",
144 | type=int,
145 | default=1,
146 | choices=(0, 1),
147 | help="Option to do simulated annealing of average negative model pseudo-log-likelihood. Toggle OFF to do vanilla blocked sampling",
148 | )
149 | self.parser.add_argument("--do_mcmc", type=int, default=0, help="Option to do Metropolis-Hastings")
150 | self.parser.add_argument(
151 | "--step_rate", type=float, default=0.995, help="Multiplicative step rate for simulated annealing",
152 | )
153 | self.parser.add_argument(
154 | "--anneal_start_temp", type=float, default=1.0, help="Starting temperature for simulated annealing",
155 | )
156 | self.parser.add_argument(
157 | "--anneal_final_temp", type=float, default=0.0, help="Final temperature for simulated annealing",
158 | )
159 | self.parser.add_argument("--n_iters", type=int, default=2500, help="Total number of iterations")
160 |
161 | def add_argument(self, *args, **kwargs):
162 | self.parser.add_argument(*args, **kwargs)
163 |
164 | def parse_args(self):
165 | self.args = self.parser.parse_args()
166 |
167 | self.log = logger.Logger(log_dir=self.args.log_dir)
168 | self.log.log_kvs(**self.args.__dict__)
169 | self.log.log_args(self.args)
170 |
171 | np.random.seed(self.args.seed)
172 | random.seed(self.args.seed)
173 | torch.manual_seed(self.args.seed)
174 | torch.backends.cudnn.enabled = False
175 | torch.backends.cudnn.deterministic = True
176 | torch.backends.cudnn.benchmark = False
177 |
178 | return self.args
179 |
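# Typical usage (see run.py and the train scripts):
#   manager = common.run_manager.RunManager()
#   args = manager.parse_args()  # also seeds numpy/random/torch and builds manager.log
#   log = manager.log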
--------------------------------------------------------------------------------
/imgs/ex2_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex2_results.png
--------------------------------------------------------------------------------
/imgs/ex3_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex3_results.png
--------------------------------------------------------------------------------
/imgs/ex5_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex5_results.png
--------------------------------------------------------------------------------
/imgs/ex6_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/ex6_results.png
--------------------------------------------------------------------------------
/imgs/tim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/imgs/tim.gif
--------------------------------------------------------------------------------
/load_and_save_bb_coords.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | import common.run_manager
10 | import glob
11 | import seq_des.util.canonicalize as canonicalize
12 | import pickle
13 | import seq_des.util.data as datasets
14 | from torch.utils import data
15 |
16 |
17 | import resource
18 |
19 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
20 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
21 |
22 | """ script to load PDB coords, canonicalize, save """
23 |
24 | def main():
25 |
26 | manager = common.run_manager.RunManager()
27 |
28 | manager.parse_args()
29 | args = manager.args
30 | log = manager.log
31 |
32 | dataset = datasets.PDB_domain_spitter(txt_file=args.txt, pdb_path=args.pdb_dir, num_return=75, bb_only=1)
33 |
34 | dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args.workers)
35 |
36 | num_return = args.num_return
37 | gen = iter(dataloader)
38 | coords_out, data_out, ys, domain_ids, chis_out = [], [], [], [], []
39 |
40 | cs = args.chunk_size
41 | n = 0
42 |
43 | for it in tqdm(range(len(dataloader)), desc="loading and saving coords"):
44 |
45 |         out = next(gen)
46 |         if out is None or len(out) == 0:
47 | print("out is none")
48 | continue
49 | atom_coords, atom_data, res_label, domain_id, chis = out
50 | for i in range(len(atom_coords)):
51 | coords_out.extend(atom_coords[i][0].cpu().data.numpy())
52 | data_out.extend(atom_data[i][0].cpu().data.numpy())
53 | ys.extend(res_label[i][0].cpu().data.numpy())
54 | domain_ids.extend([domain_id[i][0]] * res_label[i][0].cpu().data.numpy().shape[0])
55 | chis_out.extend(chis[i][0].cpu().data.numpy())
56 |
57 | assert len(coords_out) == len(ys)
58 | assert len(coords_out) == len(data_out)
59 | assert len(coords_out) == len(domain_ids), (len(coords_out), len(domain_ids))
60 | assert len(coords_out) == len(chis_out)
61 |
62 | del atom_coords
63 | del atom_data
64 | del res_label
65 | del domain_id
66 |
67 | # intermittent save data
68 | if len(coords_out) > cs or it == len(dataloader) - 1:
69 | # shuffle then save
70 | print(n, len(coords_out)) # -- NOTE keep this
71 | idx = np.arange(min(cs, len(coords_out)))
72 | np.random.shuffle(idx)
73 | print(n, len(idx))
74 |
75 | c, d, y, di, ch = map(lambda arr: np.array(arr[: len(idx)])[idx], [coords_out, data_out, ys, domain_ids, chis_out])
76 |
77 | print("saving", args.save_dir + "/" + "data_%0.4d.pt" % (n))
78 | torch.save((c, d, y, di, ch), args.save_dir + "/" + "data_%0.4d.pt" % (n))
79 |
80 | print("Current num examples", (n) * cs + len(coords_out))
81 |
82 | n += 1
83 | coords_out, data_out, ys, domain_ids, chis_out = map(lambda arr: arr[len(idx) :], [coords_out, data_out, ys, domain_ids, chis_out])
84 |
85 |
86 | if __name__ == "__main__":
87 | main()
88 |
--------------------------------------------------------------------------------
/load_and_save_coords.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | import common.run_manager
10 | import glob
11 | import seq_des.util.canonicalize as canonicalize
12 | import pickle
13 | import seq_des.util.data as datasets
14 | from torch.utils import data
15 |
16 |
17 | import resource
18 |
19 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
20 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
21 |
22 | """ script to load PDB coords, canonicalize, save """
23 |
24 | def main():
25 |
26 | manager = common.run_manager.RunManager()
27 |
28 | manager.parse_args()
29 | args = manager.args
30 | log = manager.log
31 |
32 | dataset = datasets.PDB_domain_spitter(txt_file=args.txt, pdb_path=args.pdb_dir)
33 |
34 | dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args.workers)
35 |
36 | num_return = args.num_return
37 | gen = iter(dataloader)
38 | coords_out, data_out, ys, domain_ids, chis_out = [], [], [], [], []
39 |
40 | cs = args.chunk_size
41 | n = 0
42 |
43 | for it in tqdm(range(len(dataloader)), desc="loading and saving coords"):
44 |
45 |         out = next(gen)
46 |         if out is None or len(out) == 0:
47 | print("out is none")
48 | continue
49 | atom_coords, atom_data, res_label, domain_id, chis = out
50 | for i in range(len(atom_coords)):
51 | coords_out.extend(atom_coords[i][0].cpu().data.numpy())
52 | data_out.extend(atom_data[i][0].cpu().data.numpy())
53 | ys.extend(res_label[i][0].cpu().data.numpy())
54 | domain_ids.extend([domain_id[i][0]] * res_label[i][0].cpu().data.numpy().shape[0])
55 | chis_out.extend(chis[i][0].cpu().data.numpy())
56 |
57 | assert len(coords_out) == len(ys)
58 | assert len(coords_out) == len(data_out)
59 | assert len(coords_out) == len(domain_ids), (len(coords_out), len(domain_ids))
60 | assert len(coords_out) == len(chis_out)
61 |
62 | del atom_coords
63 | del atom_data
64 | del res_label
65 | del domain_id
66 |
67 | # intermittent save data
68 | if len(coords_out) > cs or it == len(dataloader) - 1:
69 | # shuffle then save
70 | print(n, len(coords_out)) # -- NOTE keep this
71 | idx = np.arange(min(cs, len(coords_out)))
72 | np.random.shuffle(idx)
73 | print(n, len(idx))
74 |
75 | c, d, y, di, ch = map(lambda arr: np.array(arr[: len(idx)])[idx], [coords_out, data_out, ys, domain_ids, chis_out])
76 |
77 | print("saving", args.save_dir + "/" + "data_%0.4d.pt" % (n))
78 | torch.save((c, d, y, di, ch), args.save_dir + "/" + "data_%0.4d.pt" % (n))
79 |
80 | print("Current num examples", (n) * cs + len(coords_out))
81 |
82 | n += 1
83 | coords_out, data_out, ys, domain_ids, chis_out = map(lambda arr: arr[len(idx) :], [coords_out, data_out, ys, domain_ids, chis_out])
84 |
85 |
86 | if __name__ == "__main__":
87 | main()
88 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | biopython
2 | matplotlib
3 | tqdm
4 | tensorboardX
5 | scipy
6 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 |
6 | from seq_des import *
7 | import seq_des.sampler as sampler
8 | import seq_des.models as models
9 |
10 | import common.run_manager
11 | import common.atoms
12 |
13 | import sys
14 | import pickle
15 | import glob
16 |
17 | from pyrosetta.rosetta.protocols.simple_filters import BuriedUnsatHbondFilterCreator, PackStatFilterCreator
18 | from pyrosetta.rosetta.protocols.denovo_design.filters import ExposedHydrophobicsFilterCreator
19 |
20 | from tqdm import tqdm
21 |
22 | __author__ = 'Namrata Anand-Achim'
23 |
24 |
25 | def log_metrics(run="sampler", args=None, log=None, iteration=0, design_sampler=None, prefix=""):
26 | # tensorboard logging
27 |
28 | # log structure / sequence metrics
29 | log.log_scalar("{run}/{prefix}rosetta_energy".format(run=run, prefix=prefix), design_sampler.rosetta_energy)
30 | log.log_scalar("{run}/{prefix}seq_overlap".format(run=run, prefix=prefix), design_sampler.seq_overlap)
31 | log.log_scalar("{run}/{prefix}anneal_start_temp".format(run=run, prefix=prefix), design_sampler.anneal_start_temp)
32 | log.log_scalar("{run}/{prefix}anneal_final_temp".format(run=run, prefix=prefix), design_sampler.anneal_final_temp)
33 | log.log_scalar("{run}/{prefix}log_p".format(run=run, prefix=prefix), design_sampler.log_p_mean.item())
34 | log.log_scalar("{run}/{prefix}chi_error".format(run=run, prefix=prefix), design_sampler.chi_error)
35 | log.log_scalar("{run}/{prefix}chi_rmsd".format(run=run, prefix=prefix), design_sampler.chi_rmsd)
36 |
37 | # log rosetta score terms
38 | for s in design_sampler.score_terms:
39 | log.log_scalar("{run}/z_{prefix}{s}".format(run=run, prefix=prefix, s=s), float(design_sampler.curr_score_terms[s].mean()))
40 |
41 | # log rosetta agnostic terms
42 | for n, s in design_sampler.filter_scores:
43 | log.log_scalar("{run}/y_{prefix}{n}".format(run=run, prefix=prefix, n=n), s)
44 |
45 |
46 |
47 | def load_model(model, use_cuda=True, nic=len(common.atoms.atoms)):
48 | classifier = models.seqPred(nic=nic)
49 | if use_cuda:
50 | classifier.cuda()
51 | if use_cuda:
52 | state = torch.load(model)
53 | else:
54 | state = torch.load(model, map_location="cpu")
55 | for k in state.keys():
56 | if "module" in k:
57 |             print("checkpoint was saved with DataParallel -- wrapping model")
58 | classifier = nn.DataParallel(classifier)
59 | break
60 |     # reuse the state dict loaded above instead of reading the checkpoint
61 |     # from disk a second time; `state` was loaded with the correct
62 |     # map_location for this device
63 |     classifier.load_state_dict(state)
64 | return classifier
65 |
66 |
67 | def load_models(model_list, use_cuda=True, nic=len(common.atoms.atoms)):
68 | classifiers = []
69 | for model in model_list:
70 | classifier = load_model(model, use_cuda=use_cuda, nic=nic)
71 | classifiers.append(classifier)
72 | return classifiers
73 |
74 |
75 | def main():
76 |
77 | manager = common.run_manager.RunManager()
78 |
79 | manager.parse_args()
80 | args = manager.args
81 | log = manager.log
82 |
83 | use_cuda = torch.cuda.is_available()
84 |
85 | # download pdb if not there already
86 | if not os.path.isfile(args.pdb):
87 | print("Downloading pdb to current directory...")
88 | os.system("wget -O {} https://files.rcsb.org/download/{}.pdb".format(args.pdb, args.pdb[:-4].upper()))
89 |
90 | assert os.path.isfile(args.pdb), "pdb not found"
91 |
92 | # load models
93 | if args.init_model != "":
94 | init_classifier = load_model(args.init_model, use_cuda=use_cuda, nic=len(common.atoms.atoms))
95 | init_classifier.eval()
96 | init_classifiers = [init_classifier]
97 | else:
98 |         assert not (args.ala and args.val), "cannot specify both poly-alanine and poly-valine"
99 | if args.randomize:
100 | if args.ala:
101 | init_scheme = "poly-alanine"
102 | elif args.val:
103 | init_scheme = "poly-valine"
104 | else:
105 | init_scheme = "random"
106 | else: init_scheme = 'using starting structure'
107 | print("No baseline model specified, initialization will be %s" % init_scheme)
108 | init_classifiers = None
109 |
110 | classifiers = load_models(args.model_list, use_cuda=use_cuda, nic=len(common.atoms.atoms) + 1 + 21)
111 | for classifier in classifiers:
112 | classifier.eval()
113 |
114 | # set up design_sampler
115 | design_sampler = sampler.Sampler(args, classifiers, init_classifiers, log=log, use_cuda=use_cuda)
116 |
117 | # initialize sampler
118 | design_sampler.init()
119 |
120 | # log metrics for gt seq/structure
121 | log_metrics(args=args, log=log, iteration=0, design_sampler=design_sampler, prefix="gt_")
122 | best_rosetta_energy = np.inf
123 | best_energy = np.inf
124 |
125 | # initialize design_sampler sequence with baseline model prediction or random/poly-alanine/poly-valine initial sequence, save initial model
126 | design_sampler.init_seq()
127 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_0.pdb")
128 |
129 | # save trajectories for logmeans and rosettas
130 | logmeans = np.zeros(int(args.n_iters))
131 | rosettas = np.zeros(int(args.n_iters))
132 |
133 | # run design
134 | with torch.no_grad():
135 | for i in tqdm(range(1, int(args.n_iters)), desc='running design'):
136 |
137 | # step
138 | design_sampler.step()
139 |
140 | # logging
141 | log_metrics(args=args, log=log, iteration=i, design_sampler=design_sampler)
142 |
143 | # save log_p_means and rosettas
144 | logmeans[i] = design_sampler.log_p_mean
145 | rosettas[i] = design_sampler.rosetta_energy
146 |
147 | if design_sampler.log_p_mean < best_energy:
148 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_best_log_p_%s.pdb" % log.ts)
149 | best_energy = design_sampler.log_p_mean
150 |
151 | if design_sampler.rosetta_energy < best_rosetta_energy:
152 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_best_rosetta_energy_%s.pdb" % log.ts)
153 | best_rosetta_energy = design_sampler.rosetta_energy
154 |
155 | # save intermediate models -- comment out if desired
156 | if (i==1) or (i % args.save_rate == 0) or (i == args.n_iters - 1):
157 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_%s_%s.pdb" % (i, log.ts))
158 |
159 | log.advance_iteration()
160 |
161 | # save final model
162 | design_sampler.pose.dump_pdb(log.log_path + "/" + args.file_dir + "/" + "curr_final.pdb")
163 |
164 | np.savetxt('{}/{}/logmeans.txt'.format(log.log_path, args.file_dir),logmeans, delimiter=',')
165 | np.savetxt('{}/{}/rosetta_energy.txt'.format(log.log_path, args.file_dir),rosettas, delimiter=',')
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
--------------------------------------------------------------------------------
/seq_des/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/seq_des/__init__.py
--------------------------------------------------------------------------------
/seq_des/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import seq_des.util.data as data
4 | import common.atoms
5 |
6 |
7 | def init_ortho_weights(self):
8 | for module in self.modules():
9 | if isinstance(module, nn.Conv2d):
10 | torch.nn.init.orthogonal_(module.weight)
11 | elif isinstance(module, nn.ConvTranspose2d):
12 | torch.nn.init.orthogonal_(module.weight)
13 |
14 |
15 | class seqPred(nn.Module):
16 | def __init__(self, nic, nf=64, momentum=0.01):
17 | super(seqPred, self).__init__()
18 | self.nic = nic
19 | self.model = nn.Sequential(
20 | # 20 -- 10
21 | nn.Conv3d(nic, nf, 4, 2, 1, bias=False),
22 | nn.BatchNorm3d(nf, momentum=momentum),
23 | nn.LeakyReLU(0.2, inplace=True),
24 | nn.Dropout(0.1),
25 | nn.Conv3d(nf, nf, 3, 1, 1, bias=False),
26 | nn.BatchNorm3d(nf, momentum=momentum),
27 | nn.LeakyReLU(0.2, inplace=True),
28 | nn.Dropout(0.1),
29 | nn.Conv3d(nf, nf, 3, 1, 1, bias=False),
30 | nn.BatchNorm3d(nf, momentum=momentum),
31 | nn.LeakyReLU(0.2, inplace=True),
32 | nn.Dropout(0.1),
33 | # 10 -- 5
34 | nn.Conv3d(nf, nf * 2, 4, 2, 1, bias=False),
35 | nn.BatchNorm3d(nf * 2, momentum=momentum),
36 | nn.LeakyReLU(0.2, inplace=True),
37 | nn.Dropout(0.1),
38 | nn.Conv3d(nf * 2, nf * 2, 3, 1, 1, bias=False),
39 | nn.BatchNorm3d(nf * 2, momentum=momentum),
40 | nn.LeakyReLU(0.2, inplace=True),
41 | nn.Dropout(0.1),
42 | nn.Conv3d(nf * 2, nf * 2, 3, 1, 1, bias=False),
43 | nn.BatchNorm3d(nf * 2, momentum=momentum),
44 | nn.LeakyReLU(0.2, inplace=True),
45 | nn.Dropout(0.1),
46 | # 5 -- 1
47 | nn.Conv3d(nf * 2, nf * 4, 5, 1, 0, bias=False),
48 | nn.BatchNorm3d(nf * 4, momentum=momentum),
49 | nn.LeakyReLU(0.2, inplace=True),
50 | nn.Dropout(0.1),
51 | nn.Conv3d(nf * 4, nf * 4, 3, 1, 1, bias=False),
52 | nn.BatchNorm3d(nf * 4, momentum=momentum),
53 | nn.LeakyReLU(0.2, inplace=True),
54 | nn.Dropout(0.1),
55 | nn.Conv3d(nf * 4, nf * 4, 3, 1, 1, bias=False),
56 | nn.BatchNorm3d(nf * 4, momentum=momentum),
57 | nn.LeakyReLU(0.2, inplace=True),
58 | nn.Dropout(0.1),
59 | )
60 |
61 | # res pred
62 | self.out = nn.Sequential(
63 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
64 | nn.BatchNorm1d(nf * 4, momentum=momentum),
65 | nn.LeakyReLU(0.2, inplace=True),
66 | nn.Dropout(0.1),
67 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
68 | nn.BatchNorm1d(nf * 4, momentum=momentum),
69 | nn.LeakyReLU(0.2, inplace=True),
70 | nn.Dropout(0.1),
71 | nn.Conv1d(nf * 4, len(common.atoms.label_res_dict.keys()), 3, 1, 1, bias=False),
72 | )
73 |
74 | # chi feat vec -- condition on residue and env feature vector
75 | self.chi_feat = nn.Sequential(
76 | nn.Conv1d(nf * 4 + 20, nf * 4, 3, 1, 1, bias=False),
77 | nn.BatchNorm1d(nf * 4, momentum=momentum),
78 | nn.LeakyReLU(0.2, inplace=True),
79 | nn.Dropout(0.1),
80 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
81 | nn.BatchNorm1d(nf * 4, momentum=momentum),
82 | nn.LeakyReLU(0.2, inplace=True),
83 | nn.Dropout(0.1),
84 | )
85 |
86 | # chi 1 pred -- condition on chi feat vec
87 | self.chi_1_out = nn.Sequential(
88 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
89 | nn.BatchNorm1d(nf * 4, momentum=momentum),
90 | nn.LeakyReLU(0.2, inplace=True),
91 | nn.Dropout(0.1),
92 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
93 | nn.BatchNorm1d(nf * 4, momentum=momentum),
94 | nn.LeakyReLU(0.2, inplace=True),
95 | nn.Dropout(0.1),
96 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False),
97 | )
98 |
99 | # chi 2 pred -- condition on chi 1 and chi feat vec
100 | self.chi_2_out = nn.Sequential(
101 | nn.Conv1d(nf * 4 + 1 * (len(data.CHI_BINS) - 1), nf * 4, 3, 1, 1, bias=False),
102 | nn.BatchNorm1d(nf * 4, momentum=momentum),
103 | nn.LeakyReLU(0.2, inplace=True),
104 | nn.Dropout(0.1),
105 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
106 | nn.BatchNorm1d(nf * 4, momentum=momentum),
107 | nn.LeakyReLU(0.2, inplace=True),
108 | nn.Dropout(0.1),
109 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False),
110 | )
111 |
112 | # chi 3 pred -- condition on chi 1, chi 2, and chi feat vec
113 | self.chi_3_out = nn.Sequential(
114 | nn.Conv1d(nf * 4 + 2 * (len(data.CHI_BINS) - 1), nf * 4, 3, 1, 1, bias=False),
115 | nn.BatchNorm1d(nf * 4, momentum=momentum),
116 | nn.LeakyReLU(0.2, inplace=True),
117 | nn.Dropout(0.1),
118 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
119 | nn.BatchNorm1d(nf * 4, momentum=momentum),
120 | nn.LeakyReLU(0.2, inplace=True),
121 | nn.Dropout(0.1),
122 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False),
123 | )
124 |
125 | # chi 4 pred -- condition on chi 1, chi 2, chi 3, and chi feat vec
126 | self.chi_4_out = nn.Sequential(
127 | nn.Conv1d(nf * 4 + 3 * (len(data.CHI_BINS) - 1), nf * 4, 3, 1, 1, bias=False),
128 | nn.BatchNorm1d(nf * 4, momentum=momentum),
129 | nn.LeakyReLU(0.2, inplace=True),
130 | nn.Dropout(0.1),
131 | nn.Conv1d(nf * 4, nf * 4, 3, 1, 1, bias=False),
132 | nn.BatchNorm1d(nf * 4, momentum=momentum),
133 | nn.LeakyReLU(0.2, inplace=True),
134 | nn.Dropout(0.1),
135 | nn.Conv1d(nf * 4, (len(data.CHI_BINS) - 1), 3, 1, 1, bias=False),
136 | )
137 |
138 | def res_pred(self, input):
139 | bs = input.size()[0]
140 | feat = self.model(input).view(bs, -1, 1)
141 | res_pred = self.out(feat).view(bs, -1)
142 | return res_pred, feat
143 |
144 | def get_chi_init_feat(self, feat, res_onehot):
145 | chi_init = torch.cat([feat, res_onehot[..., None]], 1)
146 | chi_feat = self.chi_feat(chi_init)
147 | return chi_feat
148 |
149 | def get_chi_1(self, chi_feat):
150 | chi_1_pred = self.chi_1_out(chi_feat).view(chi_feat.size()[0], -1)
151 | return chi_1_pred
152 |
153 | def get_chi_2(self, chi_feat, chi_1_onehot):
154 | chi_2_pred = self.chi_2_out(torch.cat([chi_feat, chi_1_onehot[..., None]], 1)).view(chi_feat.size()[0], -1)
155 | return chi_2_pred
156 |
157 | def get_chi_3(self, chi_feat, chi_1_onehot, chi_2_onehot):
158 | chi_3_pred = self.chi_3_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None]], 1)).view(chi_feat.size()[0], -1)
159 | return chi_3_pred
160 |
161 | def get_chi_4(self, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot):
162 | chi_4_pred = self.chi_4_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None], chi_3_onehot[..., None]], 1)).view(
163 | chi_feat.size()[0], -1
164 | )
165 | return chi_4_pred
166 |
167 | def get_feat(self, input, res_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot):
168 | bs = input.size()[0]
169 | feat = self.model(input).view(bs, -1, 1)
170 | res_pred = self.out(feat).view(bs, -1)
171 |
172 | # condition on res type and env feat
173 | chi_init = torch.cat([feat, res_onehot[..., None]], 1)
174 | chi_feat = self.chi_feat(chi_init)
175 |
176 | # condition on true residue type and previous ground-truth rotamer angles
177 | chi_1_pred = self.chi_1_out(chi_feat).view(bs, -1)
178 | chi_2_pred = self.chi_2_out(torch.cat([chi_feat, chi_1_onehot[..., None]], 1)).view(bs, -1)
179 | chi_3_pred = self.chi_3_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None]], 1)).view(bs, -1)
180 | chi_4_pred = self.chi_4_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None], chi_3_onehot[..., None]], 1)).view(bs, -1)
181 | return feat, res_pred, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred
182 |
183 |
184 | def forward(self, input, res_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot):
185 | bs = input.size()[0]
186 | feat = self.model(input).view(bs, -1, 1)
187 | res_pred = self.out(feat).view(bs, -1)
188 |
189 | # condition on res type and env feat
190 | chi_init = torch.cat([feat, res_onehot[..., None]], 1)
191 | chi_feat = self.chi_feat(chi_init)
192 |
193 | # condition on true residue type and previous ground-truth rotamer angles
194 | chi_1_pred = self.chi_1_out(chi_feat).view(bs, -1)
195 | chi_2_pred = self.chi_2_out(torch.cat([chi_feat, chi_1_onehot[..., None]], 1)).view(bs, -1)
196 | chi_3_pred = self.chi_3_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None]], 1)).view(bs, -1)
197 | chi_4_pred = self.chi_4_out(torch.cat([chi_feat, chi_1_onehot[..., None], chi_2_onehot[..., None], chi_3_onehot[..., None]], 1)).view(bs, -1)
198 |
199 | return res_pred, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred
200 |
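if __name__ == "__main__":
    # Smoke test -- an illustrative sketch, assuming 20x20x20 input voxel grids
    # (matching the 20 -> 10 -> 5 -> 1 downsampling comments above). The baseline
    # model uses len(common.atoms.atoms) input channels; the conditional model
    # uses len(common.atoms.atoms) + 1 + 21 (see run.py).
    nbins = len(data.CHI_BINS) - 1
    net = seqPred(nic=len(common.atoms.atoms))
    x = torch.randn(2, len(common.atoms.atoms), 20, 20, 20)
    res_onehot = torch.zeros(2, 20)
    chi_onehot = torch.zeros(2, nbins)
    preds = net(x, res_onehot, chi_onehot, chi_onehot, chi_onehot)
    print([p.size() for p in preds])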
--------------------------------------------------------------------------------
/seq_des/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import math
4 | import sys
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.distributions.categorical import Categorical
9 |
10 | import seq_des.util.pyrosetta_util as putil
11 | import seq_des.util.sampler_util as sampler_util
12 | import seq_des.util.canonicalize as canonicalize
13 | import seq_des.util.data as data
14 | import seq_des.util.resfile_util as resfile_util
15 | import common.atoms
16 |
17 | from pyrosetta.rosetta.protocols.simple_filters import (
18 | BuriedUnsatHbondFilterCreator,
19 | PackStatFilterCreator,
20 | )
21 | from pyrosetta.rosetta.protocols.denovo_design.filters import ExposedHydrophobicsFilterCreator
22 | from pyrosetta.rosetta.core.scoring import automorphic_rmsd
23 |
24 |
25 | class Sampler(object):
26 | def __init__(self, args, models, init_models=None, log=None, use_cuda=True):
27 | super(Sampler, self).__init__()
28 | self.models = models
29 | for model in self.models:
30 | model.eval()
31 |
32 | if init_models is not None:
33 | self.init_models = init_models
34 | for init_model in self.init_models:
35 | init_model.eval()
36 | else:
37 | self.init_models = None
38 | self.no_init_model = args.no_init_model
39 |
40 | self.pdb = args.pdb
41 | self.log = log
42 | self.use_cuda = use_cuda
43 |
44 | self.threshold = args.threshold
45 | self.pack_radius = args.pack_radius
46 | self.iteration = 0
47 | self.randomize = args.randomize
48 | self.rotamer_repack = args.repack_only
49 | self.use_rosetta_packer = args.use_rosetta_packer
50 | self.no_cys = args.no_cys
51 | self.no_met = args.no_met
52 | self.symmetry = args.symmetry
53 | self.k = args.k
54 | self.restrict_gly = args.restrict_gly
55 | self.ala = args.ala
56 | self.val = args.val
57 | assert not (self.ala and self.val), "only ala or val settings can be on for a given run"
58 | self.chi_mask = None
59 |
60 | self.anneal = args.anneal
61 | self.anneal_start_temp = args.anneal_start_temp
62 | self.anneal_final_temp = args.anneal_final_temp
63 | self.step_rate = args.step_rate
64 | self.accept_prob = 1
65 |
66 | # load fixed idx if applicable
67 | if args.fixed_idx != "":
68 | # assert not self.symmetry, 'fixed idx not supported in symmetry mode'
69 | self.fixed_idx = sampler_util.get_idx(args.fixed_idx)
70 | else:
71 | self.fixed_idx = []
72 |
73 | # resfile restrictions handling (see util/resfile_util.py)
74 | self.resfile = args.resfile
75 | if self.resfile:
76 | # get resfile NATRO (used to skip designing/packing at all)
77 | self.fixed_idx = resfile_util.get_natro(self.resfile)
78 | # get resfile commands (used to restrict amino acid probability distribution)
79 | self.resfile = resfile_util.read_resfile(self.resfile)
80 | # get initial resfile sequence (used to initialize the sequence)
81 | self.init_seq_resfile = self.resfile[2]
82 |
83 | # the initial sequence must be randomized (avoid running the baseline model)
84 | if self.init_seq_resfile:
85 | self.randomize = 0
86 |
87 | # load var idx if applicable
88 | if args.var_idx != "":
89 | # assert not self.symmetry, 'var idx not supported in symmetry mode'
90 | self.var_idx = sampler_util.get_idx(args.var_idx)
91 | else:
92 | self.var_idx = []
93 |
94 | assert not ((len(self.fixed_idx) > 0) and (len(self.var_idx) > 0)), "cannot specify both fixed and variable indices"
95 |
96 | if self.rotamer_repack:
97 | assert self.init_models is not None, "baseline model must be used to initialize rotamer repacking"
98 |
99 | if self.no_init_model:
100 | assert not self.rotamer_repack, "baseline model must be used for initializing rotamer repacking"
101 |
102 | if self.symmetry:
103 |             assert len(self.fixed_idx) == 0, "specifying fixed idx not supported in symmetry mode"
104 |             assert len(self.var_idx) == 0, "specifying var idx not supported in symmetry mode"
105 |
106 | def init(self):
107 | """ initialize sampler
108 | - initialize rosetta filters
109 | - score starting (ground-truth) sequence
110 | - set up constraints on glycines
111 | - set up symmetry
112 | - eval metrics on starting (ground-truth) sequence
113 | - get blocks for blocked sampling
114 | """
115 |
116 | # initialize sampler
117 | self.init_rosetta_filters()
118 | # score starting (ground-truth) pdb, get gt energies
119 | self.gt_pose = putil.get_pose(self.pdb)
120 | self.gt_seq = self.gt_pose.sequence()
121 | (
122 | _,
123 | self.log_p_per_res,
124 | self.log_p_mean,
125 | self.logits,
126 | self.chi_feat,
127 | self.gt_chi_angles,
128 | self.gt_chi_mask,
129 | self.gt_chi,
130 | ) = sampler_util.get_energy(
131 | self.models, pose=self.gt_pose, return_chi=1, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda,
132 | )
133 | self.chi_error = 0
134 | self.re = putil.score_pose(self.gt_pose)
135 | self.gt_pose.dump_pdb(self.log.log_path + "/" + "gt_" + self.pdb)
136 | self.gt_score_terms = self.gt_pose.energies().residue_total_energies_array()
137 | self.score_terms = list(self.gt_score_terms.dtype.fields)
138 |
139 |         # set no-gly indices -- non-loop positions where glycine may be restricted
140 | self.gt_pose.display_secstruct()
141 | ss = self.gt_pose.secstruct()
142 | self.no_gly_idx = [i for i in range(len(ss)) if ss[i] != "L"]
143 | self.n = self.gt_pose.residues.__len__()
144 |
145 | # handle symmetry
146 | if self.symmetry:
147 | if "tim" in self.pdb:
148 | # handle tim case
149 | self.n_k = math.ceil((self.n + 1) / self.k) if (self.n + 1) % 2 == 0 else math.ceil((self.n) / self.k)
150 | else:
151 | self.n_k = self.n // self.k
152 | assert self.n % self.k == 0, 'length of protein must be divisible by k for k-fold symm design'
153 | idx = [[i + j * (self.n_k) for j in range(self.k) if i + j * (self.n_k) < self.n] for i in range(self.n_k)]
154 | self.symmetry_idx = {}
155 | for idx_set in idx:
156 | for i in idx_set:
157 | self.symmetry_idx[i] = idx_set
158 |
159 | # updated fixed/var idx to reflect symmetry
160 | for i in self.fixed_idx:
161 | assert (
162 | i in self.symmetry_idx.keys()
163 | ), "fixed idx must only be specified for first symmetric unit in symmetry mode (within first n_k residues)"
164 | for i in self.var_idx:
165 | assert (
166 | i in self.symmetry_idx.keys()
167 | ), "var idx must only be specified for first symmetric unit in symmetry mode (within first n_k residues)"
168 |
169 | # get gt data -- monitor distance to initial sequence
170 | if '/' in self.pdb:
171 | pdb_idx = self.pdb.rfind("/") + 1
172 | pdb_dir = self.pdb[: self.pdb.rfind("/")]
173 | else:
174 | pdb_idx = 0
175 | pdb_dir = './'
176 | (self.gt_atom_coords, self.gt_atom_data, self.gt_residue_bb_index_list, res_data, self.gt_res_label, chis,) = data.get_pdb_data(
177 | self.pdb[pdb_idx : -4], data_dir=pdb_dir, assembly=0, skip_download=1,
178 | )
179 | self.eval_metrics(self.gt_pose, self.gt_res_label)
180 |
181 |         # get conditionally independent blocks via greedy k-coloring of backbone 'graph'
182 | self.get_blocks()
183 |
184 |
185 | def init_seq(self):
186 | # initialize starting sequence
187 |
188 |         # initialize sequence randomly or as poly-alanine / poly-valine, then pack rotamers
189 | self.pose = putil.get_pose(self.pdb)
190 | if self.randomize:
191 | if (not self.no_init_model) and not (self.ala or self.val):
192 | # get features --> BB only
193 | (
194 | res_label,
195 | self.log_p_per_res_temp,
196 | self.log_p_mean_temp,
197 | self.logits_temp,
198 | self.chi_feat_temp,
199 | self.chi_angles_temp,
200 | self.chi_mask_temp,
201 | ) = sampler_util.get_energy(
202 | self.init_models, self.pose, bb_only=1, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda,
203 | )
204 |
205 | # set sequence
206 | if not self.rotamer_repack:
207 | # sample res from logits
208 | if not self.symmetry:
209 | res, idx, res_label = self.sample(self.logits_temp, np.arange(len(res_label)))
210 | else:
211 | res, idx, res_label = self.sample(self.logits_temp, np.arange(self.n_k))
212 | # mutate pose residues based on baseline prediction
213 | self.pose = putil.mutate_list(self.pose, idx, res, pack_radius=0, fixed_idx=self.fixed_idx, var_idx=self.var_idx)
214 | else:
215 | res = [i for i in self.gt_seq]
216 | if self.symmetry:
217 | res_label = res_label[: self.n_k]
218 |
219 | # sample and set rotamers
220 | if self.symmetry:
221 | if not self.rotamer_repack:
222 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer(
223 | np.arange(self.n_k), [res_label[i] for i in range(0, len(res_label), self.k)], self.chi_feat_temp, bb_only=1,
224 | )
225 | else:
226 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer(
227 | np.arange(self.n_k), res_label, self.chi_feat_temp, bb_only=1,
228 | )
229 | else:
230 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer(
231 | np.arange(len(res_label)), res_label, self.chi_feat_temp, bb_only=1,
232 | )
233 | res = [common.atoms.label_res_single_dict[k] for k in res_idx]
234 | self.pose = self.set_rotamer(self.pose, res, idx, self.chi_1, self.chi_2, self.chi_3, self.chi_4, fixed_idx=self.fixed_idx, var_idx=self.var_idx)
235 |
236 | # Randomize sequence/rotamers
237 | else:
238 | if not self.rotamer_repack:
239 | random_seq = np.random.choice(20, size=len(self.pose))
240 | if not self.ala and not self.val and self.symmetry:
241 | # random sequence must be symmetric
242 | random_seq = np.concatenate([random_seq[: self.n_k] for i in range(self.k)])
243 | random_seq = random_seq[: len(self.pose)]
244 | self.pose, _ = putil.randomize_sequence(
245 | random_seq,
246 | self.pose,
247 | pack_radius=self.pack_radius,
248 | ala=self.ala,
249 | val=self.val,
250 | resfile_init_seq=self.init_seq_resfile,
251 | fixed_idx=self.fixed_idx,
252 | var_idx=self.var_idx,
253 | repack_rotamers=1,)
254 | else:
255 | assert False, "baseline model must be used for initializing rotamer repacking"
256 |
257 | # evaluate energy for starting structure/sequence
258 | (self.res_label, self.log_p_per_res, self.log_p_mean, self.logits, self.chi_feat, self.chi_angles, self.chi_mask,) = sampler_util.get_energy(
259 | self.models, self.pose, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda,
260 | )
261 | if self.rotamer_repack:
262 | assert np.all(self.chi_mask == self.gt_chi_mask), "gt and current pose chi masks should be the same when doing rotamer repacking"
263 |
264 | if self.anneal:
265 | self.pose.dump_pdb(self.log.log_path + "/" + "curr_pose_%s.pdb" % self.log.ts)
266 |
267 | def init_rosetta_filters(self):
268 | # initialize pyrosetta filters
269 | hbond_filter_creator = BuriedUnsatHbondFilterCreator()
270 | hydro_filter_creator = ExposedHydrophobicsFilterCreator()
271 | ps_filter_creator = PackStatFilterCreator()
272 | self.packstat_filter = ps_filter_creator.create_filter()
273 | self.exposed_hydrophobics_filter = hydro_filter_creator.create_filter()
274 | self.sc_buried_unsats_filter = hbond_filter_creator.create_filter()
275 | self.bb_buried_unsats_filter = hbond_filter_creator.create_filter()
276 | self.bb_buried_unsats_filter.set_report_bb_heavy_atom_unsats(True)
277 | self.sc_buried_unsats_filter.set_report_sc_heavy_atom_unsats(True)
278 | self.filters = [
279 | ("packstat", self.packstat_filter),
280 | ("exposed_hydrophobics", self.exposed_hydrophobics_filter),
281 | ("sc_buried_unsats", self.sc_buried_unsats_filter),
282 | ("bb_buried_unsats", self.bb_buried_unsats_filter),
283 | ]
284 |
285 | def get_blocks(self, single_res=False):
286 | # get node blocks for blocked sampling
287 | D = sampler_util.get_CB_distance(self.gt_atom_coords, self.gt_residue_bb_index_list)
288 | if single_res: # no blocked gibbs -- sampling one res at a time
289 | self.blocks = [[i] for i in np.arange(D.shape[0])]
290 | self.n_blocks = len(self.blocks)
291 | else:
292 | A = sampler_util.get_graph_from_D(D, self.threshold)
293 |             # if symmetry holds --> collapse graph s.t. the neighbors of node i include the neighbors of each symmetric copy i + j*n_k
294 | if self.symmetry:
295 |                 for i in range(self.n_k):
296 | A[i] = np.sum(np.concatenate([A[i + j * self.n_k][None] for j in range(self.k) if i + j * self.n_k < self.n]), axis=0,)
297 | for i in range(self.n_k):
298 | A[:, i] = np.sum(np.concatenate([A[:, i + j * self.n_k][None] for j in range(self.k) if i + j * self.n_k < self.n]), axis=0,)
299 | A[A > 1] = 1
300 | A = A[: self.n_k, : self.n_k]
301 |
302 | self.graph = {i: np.where(A[i, :] == 1)[0] for i in range(A.shape[0])}
303 | # min k-color of graph by greedy search
304 | nodes = np.arange(A.shape[0])
305 | np.random.shuffle(nodes)
306 | # eliminate fixed indices from list
307 | if self.symmetry:
308 | nodes = [n for n in range(self.n_k)]
309 | if len(self.fixed_idx) > 0:
310 | nodes = [n for n in nodes if n not in self.fixed_idx]
311 | elif len(self.var_idx) > 0:
312 | nodes = [n for n in nodes if n in self.var_idx]
313 | self.colors = sampler_util.color_nodes(self.graph, nodes)
314 | self.n_blocks = 0
315 |             if self.colors:  # check if there are any colored nodes to get n-blocks (might be empty if running NATRO on all residues in resfile)
316 | self.n_blocks = sorted(list(set(self.colors.values())))[-1] + 1
317 | self.blocks = {}
318 | for k in self.colors.keys():
319 | if self.colors[k] not in self.blocks.keys():
320 | self.blocks[self.colors[k]] = []
321 | self.blocks[self.colors[k]].append(k)
322 |
323 | self.reset_block_rate = self.n_blocks
324 |
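
The coloring step above deserves a note: residues that receive the same color share no edge in the contact graph (no CB-CB pair within `threshold`), so their conditionals are treated as independent and the whole color class is sampled as one block. Below is a minimal greedy-coloring sketch in the spirit of `sampler_util.color_nodes` (an assumption -- the repo's actual implementation may differ):

    def greedy_color(graph, nodes):
        # graph: {node: iterable of neighbor nodes}; returns {node: color}
        colors = {}
        for node in nodes:
            taken = {colors[nb] for nb in graph[node] if nb in colors}
            color = 0
            while color in taken:
                color += 1
            colors[node] = color
        return colors

    # residues 0-1 and 1-2 are in contact, 0-2 are not -> 0 and 2 share a block
    assert greedy_color({0: [1], 1: [0, 2], 2: [1]}, [0, 1, 2]) == {0: 0, 1: 1, 2: 0}
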
325 | def eval_metrics(self, pose, res_label):
326 | self.rosetta_energy = putil.score_pose(pose)
327 | self.curr_score_terms = pose.energies().residue_total_energies_array()
328 | self.seq_overlap = (res_label == self.gt_res_label).sum()
329 | self.filter_scores = []
330 | for n, filter in self.filters:
331 | self.filter_scores.append((n, filter.score(pose)))
332 | if self.rotamer_repack:
333 | self.chi_rmsd = sum([automorphic_rmsd(self.gt_pose.residue(i + 1), pose.residue(i + 1), True) for i in range(len(pose))]) / len(pose)
334 | else:
335 | self.chi_rmsd = 0
336 | self.seq = pose.sequence()
337 | if self.chi_mask is not None and self.rotamer_repack:
338 | chi_error = self.chi_mask * np.sqrt(
339 | (np.sin(self.chi_angles) - np.sin(self.gt_chi_angles)) ** 2 + (np.cos(self.chi_angles) - np.cos(self.gt_chi_angles)) ** 2
340 | )
341 | self.chi_error = np.sum(chi_error) / np.sum(self.chi_mask)
342 | else:
343 | self.chi_error = 0
344 |
345 | def enforce_resfile(self, logits, idx):
346 | """
347 |         enforces resfile constraints by setting restricted logits to a large negative value (effectively -inf; PyTorch's Categorical renormalizes the remaining logits)
348 |
349 |         logits - tensor where rows index residues and columns index amino acid logits
350 | idx - residue ids
351 | """
352 | constraints, header = self.resfile[0], self.resfile[1]
353 | # iterate over all residues and check if they're to be constrained
354 | for i in idx:
355 | if i in constraints.keys():
356 | # set of amino acids to restrict in the tensor
357 | aa_to_restrict = constraints[i]
358 | for aa in aa_to_restrict:
359 | logits[i, common.atoms.aa_map_inv[aa]] = -99999
360 | elif header: # if not in the constraints, apply header (see util/resfile_util.py)
361 | aa_to_restrict = header["DEFAULT"]
362 | for aa in aa_to_restrict:
363 | logits[i, common.atoms.aa_map_inv[aa]] = -99999
364 | return logits
365 |
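
A note on the large negative constant above: `torch.distributions.Categorical` renormalizes whatever logits it receives, so a heavily penalized entry ends up with essentially zero probability mass. A standalone sketch (not part of sampler.py):

    import torch
    from torch.distributions import Categorical

    logits = torch.zeros(1, 20)        # uniform over the 20 amino acid types
    logits[0, 5] = -99999              # restrict amino acid index 5, as above
    dist = Categorical(logits=logits)  # softmax renormalization happens here
    assert dist.probs[0, 5].item() < 1e-12
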
366 | def enforce_constraints(self, logits, idx):
367 | if self.resfile:
368 | logits = self.enforce_resfile(logits, idx)
369 | # enforce idx-wise constraints
370 | if self.no_cys:
371 | logits = logits[..., :-1]
372 | no_gly_idx = [i for i in idx if i in self.no_gly_idx]
373 |         # note -- there are certainly more careful ways to enforce the met/gly/cys constraints
374 | for i in idx:
375 | if self.restrict_gly:
376 | if i in self.no_gly_idx:
377 | logits[i, 18] = torch.min(logits[i])
378 | if self.no_met:
379 | logits[i, 13] = torch.min(logits[i])
380 | if self.no_cys:
381 | logits[i, 19] = torch.min(logits[i])
382 | if self.symmetry:
383 |             # average logits across all symmetry positions
384 | for i in idx:
385 | logits[i] = torch.cat([logits[j][None] for j in self.symmetry_idx[i] if j < self.n], 0).mean(0)
386 | return logits
387 |
388 | def sample_rotamer(self, idx, res_idx, feat, bb_only=0):
389 | # idx --> (block) residue indices (on chain)
390 | # res_idx --> idx of residue *type* (AA type)
391 | # feat --> initial env features from conv net
392 | assert len(idx) == len(res_idx), (len(idx), len(res_idx))
393 | if bb_only:
394 | curr_models = self.init_models
395 | else:
396 | curr_models = self.models
397 |
398 | if not self.symmetry:
399 |
400 | # get residue onehot vector
401 | res_idx_long = torch.LongTensor(res_idx)
402 | res_onehot = sampler_util.make_onehot(res_idx_long.size()[0], 20, res_idx_long[:, None], use_cuda=self.use_cuda,)
403 |
404 | # get chi feat
405 | chi_feat = sampler_util.get_chi_init_feat(curr_models, feat[idx], res_onehot)
406 | # predict and sample chi angles
407 | chi_1_pred_out = sampler_util.get_chi_1_logits(curr_models, chi_feat)
408 | chi_1, chi_1_real, chi_1_onehot = sampler_util.sample_chi(chi_1_pred_out, use_cuda=self.use_cuda)
409 | chi_2_pred_out = sampler_util.get_chi_2_logits(curr_models, chi_feat, chi_1_onehot)
410 | chi_2, chi_2_real, chi_2_onehot = sampler_util.sample_chi(chi_2_pred_out, use_cuda=self.use_cuda)
411 | chi_3_pred_out = sampler_util.get_chi_3_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot)
412 | chi_3, chi_3_real, chi_3_onehot = sampler_util.sample_chi(chi_3_pred_out, use_cuda=self.use_cuda)
413 | chi_4_pred_out = sampler_util.get_chi_4_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot)
414 | chi_4, chi_4_real, chi_4_onehot = sampler_util.sample_chi(chi_4_pred_out, use_cuda=self.use_cuda)
415 |
416 | return chi_1_real, chi_2_real, chi_3_real, chi_4_real, idx, res_idx
417 |
418 | else:
419 |
420 | # symmetric rotamer sampling
421 |
422 | # get symmetry indices
423 | symm_idx = []
424 | for i in idx:
425 | symm_idx.extend([j for j in self.symmetry_idx[i]])
426 |
427 | res_idx_symm = []
428 | for i, idx_i in enumerate(idx):
429 | res_idx_symm.extend([res_idx[i] for j in self.symmetry_idx[idx_i]])
430 |
431 | # get residue onehot vector
432 | res_idx_long = torch.LongTensor(res_idx_symm)
433 | res_onehot = sampler_util.make_onehot(res_idx_long.size()[0], 20, res_idx_long[:, None], use_cuda=self.use_cuda,)
434 |
435 | symm_idx_ptr = []
436 | count = 0
437 | for i, idx_i in enumerate(idx):
438 | symm_idx_ptr.append([count + j for j in range(len(self.symmetry_idx[idx_i]))])
439 | count = count + len(self.symmetry_idx[idx_i])
440 |
441 | # get chi feature vector
442 | chi_feat = sampler_util.get_chi_init_feat(curr_models, feat[symm_idx], res_onehot)
443 |
444 | # predict and sample chi for each symmetry position
445 | chi_1_pred_out = sampler_util.get_chi_1_logits(curr_models, chi_feat)
446 | chi_1_real, chi_1_onehot = sampler_util.get_symm_chi(chi_1_pred_out, symm_idx_ptr, use_cuda=self.use_cuda)
447 |
448 | chi_2_pred_out = sampler_util.get_chi_2_logits(curr_models, chi_feat, chi_1_onehot)
449 |             # set debug=True below to reproduce the bioRxiv results: chi_2 is then sampled uniformly 2x from the predicted rotamer bin, reproducing a small bug in the TIM-barrel symmetry experiments.
450 | chi_2_real, chi_2_onehot = sampler_util.get_symm_chi(chi_2_pred_out, symm_idx_ptr, use_cuda=self.use_cuda, debug=True)
451 |
452 | chi_3_pred_out = sampler_util.get_chi_3_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot)
453 | chi_3_real, chi_3_onehot = sampler_util.get_symm_chi(chi_3_pred_out, symm_idx_ptr, use_cuda=self.use_cuda)
454 |
455 | chi_4_pred_out = sampler_util.get_chi_4_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot)
456 | chi_4_real, chi_4_onehot = sampler_util.get_symm_chi(chi_4_pred_out, symm_idx_ptr, use_cuda=self.use_cuda)
457 |
458 | return (
459 | chi_1_real,
460 | chi_2_real,
461 | chi_3_real,
462 | chi_4_real,
463 | symm_idx,
464 | res_idx_symm,
465 | )
466 |
467 | def set_rotamer(self, pose, res, idx, chi_1, chi_2, chi_3, chi_4, fixed_idx=[], var_idx=[]):
468 | # res -- residue type ID
469 | # idx -- residue index on BB (0-indexed)
470 | assert len(res) == len(idx)
471 | assert len(idx) == len(chi_1), (len(idx), len(chi_1))
472 | for i, r_idx in enumerate(idx):
473 | if len(fixed_idx) > 0 and r_idx in fixed_idx:
474 | continue
475 | elif len(var_idx) > 0 and r_idx not in var_idx:
476 | continue
477 | res_i = res[i]
478 | chi_i = common.atoms.chi_dict[common.atoms.aa_inv[res_i]]
479 | if "chi_1" in chi_i.keys():
480 | pose.set_chi(1, r_idx + 1, chi_1[i] * (180 / np.pi))
481 | assert np.abs(pose.chi(1, r_idx + 1) - chi_1[i] * (180 / np.pi)) <= 1e-5, (pose.chi(1, r_idx + 1), chi_1[i] * (180 / np.pi))
482 | if "chi_2" in chi_i.keys():
483 | pose.set_chi(2, r_idx + 1, chi_2[i] * (180 / np.pi))
484 | assert np.abs(pose.chi(2, r_idx + 1) - chi_2[i] * (180 / np.pi)) <= 1e-5, (pose.chi(2, r_idx + 1), chi_2[i] * (180 / np.pi))
485 | if "chi_3" in chi_i.keys():
486 | pose.set_chi(3, r_idx + 1, chi_3[i] * (180 / np.pi))
487 | assert np.abs(pose.chi(3, r_idx + 1) - chi_3[i] * (180 / np.pi)) <= 1e-5, (pose.chi(3, r_idx + 1), chi_3[i] * (180 / np.pi))
488 | if "chi_4" in chi_i.keys():
489 | pose.set_chi(4, r_idx + 1, chi_4[i] * (180 / np.pi))
490 | assert np.abs(pose.chi(4, r_idx + 1) - chi_4[i] * (180 / np.pi)) <= 1e-5, (pose.chi(4, r_idx + 1), chi_4[i] * (180 / np.pi))
491 |
492 | return pose
493 |
494 | def sample(self, logits, idx):
495 | # sample residue from model conditional prob distribution at idx with current logits
496 | logits = self.enforce_constraints(logits, idx)
497 | dist = Categorical(logits=logits[idx])
498 | res_idx = dist.sample().cpu().data.numpy()
499 | idx_out = []
500 | res = []
501 | assert len(res_idx) == len(idx), (len(idx), len(res_idx))
502 |
503 | for k in list(res_idx):
504 | res.append(common.atoms.label_res_single_dict[k])
505 |
506 | if self.symmetry:
507 | idx_out = []
508 | for i in idx:
509 | idx_out.extend([j for j in self.symmetry_idx[i] if j < self.n])
510 | res_out = []
511 | for i, idx_i in enumerate(idx):
512 | res_out.extend([res[i] for j in self.symmetry_idx[idx_i] if j < self.n])
513 | res_idx_out = []
514 | for i, idx_i in enumerate(idx):
515 | res_idx_out.extend([res_idx[i] for j in self.symmetry_idx[idx_i] if j < self.n])
516 |
517 | assert len(idx_out) == len(res_out), (len(idx_out), len(res_out))
518 | assert len(idx_out) == len(res_idx_out), (len(idx_out), len(res_idx_out))
519 |
520 | return res_out, idx_out, res_idx_out
521 |
522 | return res, idx, res_idx
523 |
524 | def sim_anneal_step(self, e, e_old):
525 | delta_e = e - e_old
526 | if delta_e < 0:
527 | accept_prob = 1.0
528 | else:
529 | if self.anneal_start_temp == 0:
530 | accept_prob = 0
531 | else:
532 | accept_prob = torch.exp(-(delta_e) / self.anneal_start_temp).item()
533 | return accept_prob
534 |
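
For intuition, a worked example of the Metropolis-style rule above (illustrative numbers only): a proposal that raises the energy by 0.5 at temperature 1.0 is still accepted about 61% of the time.

    import math

    delta_e, T = 0.5, 1.0
    print(math.exp(-delta_e / T))  # ~0.607 -> accept with probability ~0.61
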
535 | def step_T(self):
536 | # anneal temperature
537 | self.anneal_start_temp = max(self.anneal_start_temp * self.step_rate, self.anneal_final_temp)
538 |
539 | def step(self):
540 | # no blocks to sample (NATRO for all residues)
541 | if self.n_blocks == 0:
542 | self.step_anneal()
543 | return
544 |
545 | # random idx selection, draw sample
546 | idx = self.blocks[np.random.choice(self.n_blocks)]
547 |
548 | if not self.rotamer_repack:
549 | # sample new residue indices/ residues
550 | res, idx, res_idx = self.sample(self.logits, idx)
551 | else:
552 | # residue idx is fixed (identity fixed) for rotamer repacking
553 | res = [self.gt_seq[i] for i in idx]
554 | res_idx = [common.atoms.aa_map_inv[self.gt_seq[i]] for i in idx]
555 |
556 | # sample rotamer using precomputed chi_feat vector
557 | (self.chi_1, self.chi_2, self.chi_3, self.chi_4, idx, res_idx,) = self.sample_rotamer(idx, res_idx, self.chi_feat)
558 | if self.anneal:
559 | self.pose = putil.get_pose(self.log.log_path + "/" + "curr_pose_%s.pdb" % self.log.ts)
560 |
561 | # mutate residues, set rotamers
562 | res = [common.atoms.label_res_single_dict[k] for k in res_idx]
563 |
564 | if not self.use_rosetta_packer:
565 | # mutate center residue
566 | if not self.rotamer_repack:
567 | self.pose_temp = putil.mutate_list(self.pose, idx, res, pack_radius=0, fixed_idx=self.fixed_idx, var_idx=self.var_idx)
568 |
569 | else:
570 | self.pose_temp = self.pose
571 |
572 | # sample and set center residue rotamer
573 | self.pose_temp = self.set_rotamer(self.pose_temp, res, idx, self.chi_1, self.chi_2, self.chi_3, self.chi_4, fixed_idx=self.fixed_idx, var_idx=self.var_idx)
574 |
575 | else:
576 |             # PyRosetta mutation and rotamer repacking
577 | self.pose_temp = putil.mutate_list(
578 | self.pose, idx, res, pack_radius=self.pack_radius, fixed_idx=self.fixed_idx, var_idx=self.var_idx, repack_rotamers=1
579 | )
580 |
581 | # get log prob under model
582 | (
583 | self.res_label_temp,
584 | self.log_p_per_res_temp,
585 | self.log_p_mean_temp,
586 | self.logits_temp,
587 | self.chi_feat_temp,
588 | self.chi_angles_temp,
589 | self.chi_mask_temp,
590 | ) = sampler_util.get_energy(self.models, self.pose_temp, log_path=self.log.log_path, include_rotamer_probs=1, use_cuda=self.use_cuda,)
591 | if self.anneal:
592 | # simulated annealing accept/reject step
593 | self.accept_prob = self.sim_anneal_step(self.log_p_mean_temp, self.log_p_mean)
594 | r = np.random.uniform(0, 1)
595 | else:
596 | # vanilla sampling step
597 | self.accept_prob = 1
598 | r = 0
599 |
600 | if r < self.accept_prob:
601 | if self.anneal:
602 | self.pose_temp.dump_pdb(self.log.log_path + "/" + "curr_pose_%s.pdb" % self.log.ts)
603 | # update pose
604 | self.pose = self.pose_temp
605 | (self.log_p_mean, self.log_p_per_res, self.logits, self.chi_feat, self.res_label,) = (
606 | self.log_p_mean_temp,
607 | self.log_p_per_res_temp,
608 | self.logits_temp,
609 | self.chi_feat_temp,
610 | self.res_label_temp,
611 | )
612 | self.chi_angles, self.chi_mask = self.chi_angles_temp, self.chi_mask_temp
613 |
614 | # eval all metrics
615 | self.eval_metrics(self.pose, self.res_label)
616 |
617 | self.step_anneal()
618 |
619 | def step_anneal(self):
620 | # ending for step()
621 | if self.anneal:
622 | self.step_T()
623 |
624 | self.iteration += 1
625 |
626 | # reset blocks
627 | if self.reset_block_rate != 0 and (self.iteration % self.reset_block_rate == 0):
628 | self.get_blocks()
629 |
--------------------------------------------------------------------------------
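
Putting the pieces together, the sampler above is driven roughly as follows. This is a hedged, pseudocode-style sketch -- the real entry point is run.py, and `args`, `models`, `init_models`, `log`, and `n_steps` are placeholders for objects it constructs:

    # Hedged driver sketch -- args, models, init_models, log, n_steps are
    # hypothetical stand-ins for what run.py builds.
    sampler = Sampler(args, models, init_models=init_models, log=log, use_cuda=True)
    sampler.init()       # rosetta filters, ground-truth scoring, blocks via coloring
    sampler.init_seq()   # randomized / baseline-model initialization of sequence and rotamers
    for _ in range(n_steps):
        sampler.step()   # one blocked sampling step, with optional simulated annealing
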
/seq_des/util/README.md:
--------------------------------------------------------------------------------
1 | # Resfile Interface
2 |
3 | Authors: Damir Temir, Christian Choe
4 |
5 | ## Overview
6 |
7 | The resfile interface controls the amino acid distributions produced by the baseline and conditional models.
8 | It can be used to restrict particular amino acids at certain residues,
9 | thus guiding the Protein Sequence Design algorithm toward desired sequences.
10 |
11 | Example of a resfile:
12 |
13 | ALLAA # set a default command for all residues not listed below
14 | START
15 | 34 ALLAAwc # allow all amino acids at residue #34
16 | 65 POLAR # allow only polar amino acids at residue #65
17 |     36 - 38 ALLAAxc # allow all amino acids except cysteine at residues #36 to #38 (inclusive)
18 |     34 TPIKAA C # set the initial pose sequence position at residue #34 to cysteine
19 | 55 - 58 NOTAA EHKNRQDST # disallow the listed amino acids at residues #55 to #58
20 | 20 NATRO # do not design the residue #20 at all
21 |
22 | ## Using resfile
23 |
24 | To use a resfile, create a new `.txt` file in which you specify the desired commands, then run:
25 |
26 | python3 run.py --pdb pdbs/3mx7_gt.pdb --resfile txt/resfiles/.txt
27 |
28 | ## List of Functions
29 |
30 | ### Body
31 |
32 | This is a **complete list of the commands that can be specified in the body** for particular residue ids:
33 |
34 | | Command | Description |
35 | | ------ | ----- |
36 | |ALLAA|Allows all amino acids|
37 | |ALLAAwc|Allows all amino acids (including cysteine)|
38 | |ALLAAxc|Allows all amino acids (excluding cysteine)|
39 | |POLAR|Allows only polar amino acids (DEHKNQRST)|
40 | |APOLAR|Allows only non-polar amino acids (ACFGILMPVWY)|
41 | |PIKAA|Allows only the specified amino acids|
42 | |NOTAA|Allows only those other than the specified amino acids|
43 | |NATRO|Disallows designing for that residue|
44 | |TPIKAA|Sets the specified amino acid in the initial sequence|
45 | |TNOTAA|Sets an amino acid other than those specified in the initial sequence|
46 |
47 | ### Header
48 |
49 | The header _can take_ these commands to limit **all residues not specified in the body**:
50 |
51 | | Command | Description |
52 | | ------ | ----- |
53 | |ALLAA|Allows all amino acids|
54 | |ALLAAwc|Allows all amino acids (including cysteine)|
55 | |ALLAAxc|Allows all amino acids (excluding cysteine)|
56 | |POLAR|Allows only polar amino acids (DEHKNQRST)|
57 | |APOLAR|Allows only non-polar amino acids (ACFGILMPVWY)|
58 | |PIKAA|Allows only the specified amino acids|
59 | |NOTAA|Allows only those other than the specified amino acids|
60 |
61 | **NOTE**: The header command must be followed by the keyword **START** on a new line.
62 |
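For example, a header plus body might look like this (an illustrative file, not one of the shipped resfiles):

    NOTAA CW # header: by default, disallow cysteine and tryptophan at all unlisted residues
    START
    10 PIKAA DE # residue #10: allow only aspartate or glutamate
    11 - 15 NATRO # residues #11 to #15: do not design at all
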
63 | The header _cannot take_ these commands for the following reasons:
64 |
65 | | Command | Reason |
66 | | ---- | ----- |
67 | |NATRO|Extracting residues that the algorithm should not design is a separate process. Please specify the range of residues to preserve in the body instead `ex. 1 - 90 NATRO`|
68 | |TPIKAA|Setting particular residues in the initial sequence is a separate process. Please specify each amino acid for each residue in the body instead `ex. 5 TPIKAA C`|
69 | |TNOTAA|For the same reason as above. Please specify all amino acids to avoid in initializing for each residue instead `ex. 5 TNOTAA HKRDESTNQAVLIMFYWPG`|
70 |
71 | ### Ranges
72 |
73 | You can specify the ranges for which the command should apply. For example:
74 |
75 | 1 - 90 NATRO # will preserve all residues from residue #1 to #90 (including #90)
76 |
77 | The ranges can be specified for _all_ body commands, but **cannot be specified in the header section**.
78 |
79 | ### Initial Sequencing
80 |
81 | With the `TPIKAA` and `TNOTAA` commands we can initialize the sequence with particular amino acids.
82 |
83 | 1 TPIKAA C
84 | 2 TPIKAA T
85 | 3 TPIKAA Y
86 | 4 TNOTAA ACFGILMPVWYDEHKNQRS # will set res #4 to T since it's the only one not restricted
87 | ...
88 |
89 | This will result in the initial sequence `CTYT...`
90 |
91 | **NOTE**: you can still specify other commands for those residues; those commands will restrict them during subsequent design steps, which use the conditional model rather than the baseline model.
92 |
93 | ## Results
94 |
95 | An example of a designed all-beta structure using the **backbone [3mx7_gt.pdb](../../pdbs/3mx7_gt.pdb)** with the **[resfile](../../txt/resfiles/resfile_3mx7_gt_ex1.txt)**:
96 |
97 | Before | After |
98 | :------:|:------|
99 | *(image omitted)* | *(image omitted)* |
100 |
101 | An example of a designed all-alpha structure using the **backbone [1bkr_gt.pdb](../../pdbs/1bkr_gt.pdb)** with the **[resfile](../../txt/resfiles/resfile_1bkr_gt_ex6.txt)**:
102 |
103 | Before | After |
104 | :------:|:------|
105 | *(image omitted)* | *(image omitted)* |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/seq_des/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/seq_des/util/__init__.py
--------------------------------------------------------------------------------
/seq_des/util/acc_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """ accuracy eval fns """
4 |
5 | label_coarse = {0:0,1:0,2:0, 3:1,4:1, 5:2,6:2, 7:3,8:3, 9:4,10:4,11:4,12:4,13:4, 14:5,15:5,16:5, 17:6,18:7, 19:8}
6 | label_res_single_dict_coarse= {0:'(+)', 1:'(-)', 2: 'ST', 3: 'NQ', 4: 'SH', 5:'LH', 6:'P', 7:'G', 8:'C'}
7 |
8 | label_polar = {0:0, 1:0,2:0,3:0,4:0,5:0,6:0, 7:0,8:0, 9:1,10:1,11:1,12:1, 13:2, 14:1, 15:2,16:2, 17:3,18:4, 19:0}
9 | label_res_single_dict_polar={0:'polar', 1: 'nonpolar', 2: 'amphipathic', 3: 'proline', 4:'glycine'}
10 |
11 | def get_acc(logits, label, cm=None, label_dict=None, ignore_idx=None):
12 |
13 | pred = torch.argmax(logits, 1)
14 |
15 | if label_dict is not None:
16 | pred = torch.LongTensor([label_dict[p] for p in pred.cpu().data.numpy()])
17 | label = torch.LongTensor([label_dict[l] for l in label.cpu().data.numpy()])
18 |
19 | if ignore_idx is None:
20 | acc = float((pred == label).sum(-1)) / label.size()[0]
21 | else:
22 | if len(label[label != ignore_idx]) == 0:
23 | # case when all data in a batch is to be ignored
24 | acc = 0.0
25 | else:
26 | acc = float((pred[label != ignore_idx ] == label[label != ignore_idx]).sum(-1)) / len(label[label != ignore_idx])
27 |
28 | if cm is not None:
29 | if ignore_idx is None:
30 | for i in range(pred.size()[0]):
31 | # NOTE -- do not try to un-for loop this... errors
32 | cm[label[i], pred[i]] += 1
33 |
34 | else:
35 | for i in range(pred.size()[0]):
36 | # NOTE -- do not try to un-for loop this... errors
37 | if label[i] != ignore_idx:
38 | cm[label[i], pred[i]] += 1
39 |
40 | return acc, cm
41 |
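
Passing one of the dictionaries above as `label_dict` remaps both predictions and labels before scoring, turning `get_acc` into a coarse-grained accuracy. A standalone example (assuming the package is importable):

    import torch
    from seq_des.util.acc_util import get_acc, label_coarse

    logits = torch.zeros(2, 20)
    logits[0, 0] = 5.0               # predict class 0 (true label is 1)
    logits[1, 2] = 5.0               # predict class 2 (true label is 3)
    label = torch.LongTensor([1, 3])

    acc, _ = get_acc(logits, label)                             # 0.0 exact-match accuracy
    acc9, _ = get_acc(logits, label, label_dict=label_coarse)   # 0.5: classes 0-2 all map to coarse group 0, '(+)'
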
42 |
43 | def get_chi_acc(logits, label, res_label, cm_dict=None, label_dict=None, ignore_idx=None):
44 |
45 | pred = torch.argmax(logits, 1)
46 |
47 | if label_dict is not None:
48 | pred = torch.LongTensor([label_dict[p] for p in pred.cpu().data.numpy()])
49 | label = torch.LongTensor([label_dict[l] for l in label.cpu().data.numpy()])
50 |
51 | if ignore_idx is None:
52 | acc = float((pred == label).sum(-1)) / label.size()[0]
53 | else:
54 | if len(label[label != ignore_idx]) == 0:
55 | # case when all data in a batch is to be ignored
56 | acc = 0.0
57 | else:
58 | acc = float((pred[label != ignore_idx ] == label[label != ignore_idx]).sum(-1)) / len(label[label != ignore_idx])
59 |
60 | if cm_dict is not None:
61 | if ignore_idx is None:
62 | for i in range(pred.size()[0]):
63 | # NOTE -- do not try to un-for loop this... errors
64 | cm_dict[res_label[i].item()][label[i], pred[i]] += 1
65 |
66 | else:
67 | for i in range(pred.size()[0]):
68 | # NOTE -- do not try to un-for loop this... errors
69 | if label[i] != ignore_idx:
70 | cm_dict[res_label[i].item()][label[i], pred[i]] += 1
71 |
72 | return acc, cm_dict
73 |
74 |
75 | def get_chi_EV(probs, label, res_label, cm_dict=None, label_dict=None, ignore_idx=None):
76 |
77 |
78 | if cm_dict is not None:
79 | if ignore_idx is None:
80 |             for i in range(probs.shape[0]):
81 | # NOTE -- do not try to un-for loop this... errors
82 | cm_dict[res_label[i].item()]['ev'] += probs[i]
83 | cm_dict[res_label[i].item()]['n']+= 1
84 |
85 | else:
86 |             for i in range(probs.shape[0]):
87 | # NOTE -- do not try to un-for loop this... errors
88 | if label[i] != ignore_idx:
89 | cm_dict[res_label[i].item()]['ev'] += probs[i]
90 | cm_dict[res_label[i].item()]['n']+= 1
91 |
92 | return cm_dict
93 |
94 |
95 | # from pytorch ...
96 | def get_top_k_acc(output, target, k=3, ignore_idx=None):
97 | """Computes the accuracy over the k top predictions for the specified values of k"""
98 | with torch.no_grad():
99 | batch_size = target.size(0)
100 |
101 | _, pred = output.topk(k, 1, True, True)
102 | pred = pred.t()
103 | if ignore_idx is not None:
104 | pred = pred[target !=ignore_idx]
105 | target = target[target !=ignore_idx]
106 |
107 | correct = pred.eq(target.view(1, -1).expand_as(pred))
108 |
109 | res = []
110 | correct = correct.contiguous()
111 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
112 | #res.append(correct_k.mul_(100.0 / batch_size))
113 | return correct_k.mul_(1.0 / batch_size).item()
114 |
115 | ###
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------
/seq_des/util/canonicalize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import glob
4 | import pickle
5 |
6 |
7 | gly_CB_mu = np.array([-0.5311191 , -0.75842446, 1.2198311 ])  # precomputed mean glycine CB position; originally pickle.load(open("pkl/CB_mu.pkl", "rb"))
8 |
9 |
10 | def get_len(v):
11 | return np.sqrt(np.sum(v ** 2, -1))
12 |
13 |
14 | def get_unit_normal(ab, bc):
15 | n = np.cross(ab, bc, -1)
16 | length = get_len(n)
17 | if len(n.shape) > 2:
18 | length = length[..., None]
19 | return n / length
20 |
21 |
22 | def get_angle(v1, v2):
23 |     # get in-plane angle between v1, v2 -- cos^-1(v1.v2 / (||v1|| ||v2||))
24 |     return np.arccos(np.sum(v1 * v2, -1) / (get_len(v1) * get_len(v2)))  # parenthesized so both lengths divide the dot product
25 |
26 |
27 | def bdot(a, b):
28 | return np.matmul(a, b)
29 |
30 |
31 | def return_align_f(axis, theta):
32 | c_theta = np.cos(theta)[..., None]
33 | s_theta = np.sin(theta)[..., None]
34 | f_rot = lambda v: c_theta * v + s_theta * np.cross(axis, v, axis=-1) + (1 - c_theta) * bdot(axis, v.transpose(0, 2, 1)) * axis
35 | return f_rot
36 |
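
`return_align_f` above is Rodrigues' rotation formula, v' = v cos(theta) + (k x v) sin(theta) + k (k . v)(1 - cos(theta)), batched over residues. A standalone sanity check, with shapes mimicking the (n_res, 1, 3) layout used elsewhere in this file (assumes the package is importable):

    import numpy as np
    from seq_des.util.canonicalize import return_align_f

    axis = np.array([[[0.0, 0.0, 1.0]]])   # rotate about z-hat
    theta = np.array([[np.pi / 2]])        # by 90 degrees
    f = return_align_f(axis, theta)
    v = np.array([[[1.0, 0.0, 0.0]]])      # x-hat
    print(np.round(f(v), 6))               # -> [[[0. 1. 0.]]], i.e. y-hat
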
37 |
38 | def return_batch_align_f(axis, theta, n):
39 | # n is total number of atoms
40 | c_theta = np.cos(theta)
41 | s_theta = np.sin(theta)
42 | axis = np.repeat(axis, n, axis=1)[:, :, None]
43 | c_theta = np.repeat(c_theta, n, axis=1)[:, :, None, None]
44 | s_theta = np.repeat(s_theta, n, axis=1)[:, :, None, None]
45 |
46 | f_rot = lambda v: c_theta * v + s_theta * np.cross(axis, v, axis=-1) + (1 - c_theta) * bdot(axis, v.transpose(0, 1, 3, 2)) * axis
47 | return f_rot
48 |
49 |
50 | def get_batch_N_CA_C_align(normal, r, n):
51 | # get fn to align n to positive z_hat, via rotation about x axis (assume N-CA already along x_hat)
52 | # r is number of residues
53 | z = np.repeat(np.array([[0, 0, 1]]), r, 0)[:, None]
54 | theta = get_angle(normal, z)
55 | axis = get_unit_normal(normal, z)
56 | return return_align_f(axis, theta), return_batch_align_f(axis, theta, n=n)
57 |
58 |
59 | def get_batch_N_CA_align(v, r, n):
60 | # assuming ca is at (0,0,0), return fn to batch align CA--N to positive x axis
61 | # v = n - ca
62 | x = np.repeat(np.array([[1, 0, 0]])[None], r, 0)
63 | axis = get_unit_normal(v, x)
64 | theta = get_angle(v, x)
65 | return return_align_f(axis, theta), return_batch_align_f(axis, theta, n=n)
66 |
67 |
68 | def batch_canonicalize_coords(atom_coords, atom_data, residue_bb_index_list, res_idx=None, num_return=400, bb_only=0):
69 | """Function to get batch canonicalize atoms about all residues in a structure and mask out residue of interest.
70 |
71 | Args:
72 | atom_coords (np.array): num_atoms x 3 coordinates of all retained atoms in structure
73 | atom_data (np.array): num_atoms x 4 data for atoms -- [residue idx, BB ind, atom type, res type]
74 | residue_bb_index_list (np.array): num_res x 4 mapping from residue idx to atom indices for backbone atoms (N, CA, C, CB) used for canonicalization
75 | res_idx (np.array): num_output_res x 1 -- residue indices for subsampling residues ahead of canonicalization
76 | num_return (int): number of atoms to preserve about residue in environment
77 | Returns:
78 | x_coords (np.array): num_output_res x num_return x 1 x 3 -- canonicalized coordinates about each residue with center residue masked
79 | x_data (np.array): num_output_res x num_return x 1 x 4 -- metadata for canonicalized atoms for each environment
80 | """
81 |
82 | n_atoms = atom_coords.shape[0]
83 |
84 | # subsampling residues to canonicalize
85 | if res_idx is not None:
86 | residue_bb_index_list = residue_bb_index_list[res_idx]
87 | n_res = len(res_idx)
88 | else:
89 | n_res = residue_bb_index_list.shape[0]
90 |
91 | num_return = min(num_return, n_atoms - 15)
92 |
93 | idx_N, idx_CA, idx_C, idx_CB = residue_bb_index_list[:, 0], residue_bb_index_list[:, 1], residue_bb_index_list[:, 2], residue_bb_index_list[:, 3]
94 | x = atom_coords.copy()
95 |
96 | center = x[idx_CA].copy()
97 | x_idxN, x_idxC, x_idxCA, x_idxCB = x[idx_N] - center, x[idx_C] - center, x[idx_CA] - center, x[idx_CB] - center
98 | x_data = atom_data.copy()
99 |
100 | x = np.repeat(x[None], n_res, axis=0)
101 | x_data = np.repeat(x_data[None], n_res, axis=0)
102 |
103 | # center coordinates at CA position
104 | x = x - center[:, None]
105 |
106 | # for each residue, eliminate side chain residue coordinates here --
107 | bs, _, _, x_dim = x.shape
108 | x_data_dim = x_data.shape[-1]
109 |
110 | if res_idx is None:
111 | res_idx = np.arange(n_res)
112 |
113 | res_idx = np.tile(res_idx[:, None], (1, n_atoms)).reshape(-1)
114 | x = x.reshape(-1, x_dim)
115 | x_data = x_data.reshape(-1, x_data_dim)
116 | # get res_idx, indicator of bb atom
117 | x_res, x_bb, x_res_type = x_data[..., 0], x_data[..., 1], x_data[..., -1]
118 | assert len(x_res) == len(res_idx)
119 |
120 | if not bb_only:
121 | # exclude atoms on residue of interest that are not BB atoms
122 | exclude_idx = np.where((x_res == res_idx) & (x_bb != 1))[0]
123 | else:
124 | # exclude all side-chain atoms (bb only)
125 | exclude_idx = np.where((x_bb != 1))[0]
126 |
127 | # mask res type for all current residue atoms (no cheating!)
128 | res_type_exclude_idx = np.where((x_res == res_idx))[0]
129 |     x_res_type[res_type_exclude_idx] = 21  # set to an index higher than the highest real residue type
130 |
131 | # move coordinates for non-include residues well out of frame of reference -- will be omitted in next step or voxelize
132 | x[exclude_idx] = x[exclude_idx] + np.array([-1000.0, -1000.0, -1000.0])
133 | x = x.reshape(bs, n_atoms, x_dim)[:, :, None]
134 |
135 | x_data = x_data.reshape(bs, n_atoms, x_data_dim)[:, :, None]
136 |
137 | # select num_return nearest atoms to env center
138 | d_x_out = np.sqrt(np.sum(x ** 2, -1))
139 | idx = np.argpartition(d_x_out, kth=num_return, axis=1)
140 | idx = idx[:, :num_return]
141 |
142 | x = np.take_along_axis(x, idx[..., None], axis=1)
143 | x_data = np.take_along_axis(x_data, idx[..., None], axis=1)
144 |
145 | n = num_return
146 |
147 | # align N-CA along x axis
148 |     f_R, f_bR = get_batch_N_CA_align(x_idxN - x_idxCA, r=n_res, n=n)
149 | x = f_bR(x)
150 | x_idxN, x_idxC, x_idxCA, x_idxCB = f_R(x_idxN), f_R(x_idxC), f_R(x_idxCA), f_R(x_idxCB)
151 |
152 | # rotate so that normal of N-CA-C plane aligns to positive z_hat
153 | normal = get_unit_normal(x_idxN, x_idxC)
154 |     f_R, f_bR = get_batch_N_CA_C_align(normal, r=n_res, n=n)
155 | x_idxN, x_idxC, x_idxCA, x_idxCB = f_R(x_idxN), f_R(x_idxC), f_R(x_idxCA), f_R(x_idxCB)
156 | x = f_bR(x)
157 |
158 | # recenter at CB
159 | fixed_CB = np.ones((x_idxCB.shape[0], 1, 3)) * gly_CB_mu
160 | x = x - fixed_CB[:, None]
161 |
162 | return x, x_data
163 |
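
A hedged end-to-end usage sketch for this function, assuming the repo's `pdbs/` directory is present and the package is importable (`get_pdb_data` is defined in seq_des/util/data.py below):

    from seq_des.util.data import get_pdb_data
    from seq_des.util.canonicalize import batch_canonicalize_coords

    coords, atom_data, bb_idx, res_data, res_label, chis = get_pdb_data(
        "1cc8_gt", data_dir="pdbs", assembly=0, skip_download=1
    )
    x, x_data = batch_canonicalize_coords(coords, atom_data, bb_idx, num_return=400)
    # x: (n_res, num_return, 1, 3) -- each residue's local environment, centered and
    # axis-aligned on its own backbone frame with its side-chain atoms masked out.
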
--------------------------------------------------------------------------------
/seq_des/util/data.py:
--------------------------------------------------------------------------------
1 | import Bio.PDB
2 | import Bio.PDB.vectors
3 |
4 | import torch
5 | from torch.utils import data
6 | import torch.nn.functional as F
7 |
8 | import json
9 | import numpy as np
10 | import os
11 | import re
12 | import glob
13 |
14 | import common.atoms
15 | import seq_des.util.canonicalize as canonicalize
16 | import seq_des.util.voxelize as voxelize
17 |
18 |
19 | CHI_BINS = np.linspace(-np.pi, np.pi, num=25)
20 |
21 | def read_domain_ids_per_chain_from_txt(txt_file):
22 | pdbs = []
23 | ids_chains = {}
24 | with open(txt_file, 'r') as f:
25 | for line in f:
26 | line = line.strip('\n').split()
27 | pdbs.append(line[0][:4])
28 | ids_chains[line[0][:4]] = []
29 | with open(txt_file, 'r') as f:
30 | for line in f:
31 | line = line.strip('\n').split()
32 | if len(line) == 6: # no icodes
33 | line.extend([' ', ' '])
34 | elif len(line) == 7:
35 | line.extend([' '])
36 | pdb = line[0][:4]
37 |             ids_chains[pdb].append(tuple(line))
38 | return [(k, ids_chains[k]) for k in ids_chains.keys()]
39 |
40 |
41 | def map_to_bins(chi):
42 | # map rotamer angles to discretized bins
43 | binned_pwd = np.digitize(chi, CHI_BINS)
44 | if len(binned_pwd[binned_pwd == 0]) > 0:
45 | binned_pwd[binned_pwd == 0] = 1 # in case chi == -np.pi
46 | return binned_pwd
47 |
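
A quick standalone check of this binning: with 25 edges there are 24 bins, so in-range chi angles map to indices 1..24; index 0 stays free as the ignore index (see `chi_angles_binned[chi_mask == 0] = 0` further down), and the guard above bumps any below-range value from 0 to 1.

    import numpy as np

    CHI_BINS = np.linspace(-np.pi, np.pi, num=25)   # 25 edges -> 24 bins
    print(np.digitize([-3.0, 0.1, 3.0], CHI_BINS))  # [ 1 13 24]
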
48 |
49 | def download_pdb(pdb, data_dir, assembly=1):
50 |
51 | """Function to download pdb -- either biological assembly or if that
52 | is not available/specified -- download default pdb structure
53 | Uses biological assembly as default, otherwise gets default pdb.
54 |
55 | Args:
56 | pdb (str): pdb ID.
57 | data_dir (str): path to pdb directory
58 |
59 | Returns:
60 | f (str): path to downloaded pdb
61 |
62 | """
63 |
64 | if assembly:
65 | f = data_dir + "/" + pdb + ".pdb1"
66 | if not os.path.isfile(f):
67 | try:
68 | os.system("wget -O {}.gz https://files.rcsb.org/download/{}.pdb1.gz".format(f, pdb.upper()))
69 | os.system("gunzip {}.gz".format(f))
70 |
71 | except:
72 | f = data_dir + "/" + pdb + ".pdb"
73 | if not os.path.isfile(f):
74 | os.system("wget -O {} https://files.rcsb.org/download/{}.pdb".format(f, pdb.upper()))
75 | else:
76 | f = data_dir + "/" + pdb + ".pdb"
77 |
78 | if not os.path.isfile(f):
79 | os.system("wget -O {} https://files.rcsb.org/download/{}.pdb".format(f, pdb.upper()))
80 |
81 | return f
82 |
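
For instance (a usage sketch assuming network access and a writable `pdbs/` directory):

    from seq_des.util.data import download_pdb

    f = download_pdb("1cc8", data_dir="pdbs", assembly=1)
    # -> "pdbs/1cc8.pdb1" (biological assembly), falling back to "pdbs/1cc8.pdb"
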
83 |
84 | def get_pdb_chains(pdb, data_dir, assembly=1, skip_download=0):
85 |
86 | """Function to load pdb structure via Biopython and extract all chains.
87 | Uses biological assembly as default, otherwise gets default pdb.
88 |
89 | Args:
90 | pdb (str): pdb ID.
91 | data_dir (str): path to pdb directory
92 |
93 | Returns:
94 | chains (list of (chain, chain_id)): all pdb chains
95 |
96 | """
97 | if not skip_download:
98 | f = download_pdb(pdb, data_dir, assembly=assembly)
99 |
100 | if assembly:
101 | f = data_dir + "/" + pdb + ".pdb1"
102 | if not os.path.isfile(f):
103 | f = data_dir + "/" + pdb + ".pdb"
104 | else:
105 | f = data_dir + "/" + pdb + ".pdb"
106 |
107 | assert os.path.isfile(f)
108 | structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb, f)
109 |
110 | assert len(structure) > 0, pdb
111 |
112 | # for assemblies -- sometimes chains are represented as different structures
113 | if len(structure) > 1:
114 | model = structure[0]
115 | count = 0
116 | for i in range(len(structure)):
117 | for c in structure[i].get_chains():
118 | try:
119 | c.id = common.atoms.rename_chains[count]
120 | except:
121 | continue
122 | count += 1
123 | try:
124 | model.add(c)
125 | except Bio.PDB.PDBExceptions.PDBConstructionException:
126 | continue
127 | else:
128 | model = structure[0]
129 |
130 | # special hard-coded case with very large assembly -- not necessary to train on all
131 | if "2y26" in pdb:
132 | return [(c, c.id) for c in model.get_chains() if c.id in ["B", "A", "E", "C", "D"]]
133 |
134 | return [(c, c.id) for c in model.get_chains()]
135 |
136 |
137 | def get_pdb_data(pdb, data_dir="", assembly=1, skip_download=0):
138 |
139 | """Function to get atom coordinates and atom/residue metadata from pdb structures.
140 |
141 | Args:
142 | pdb (str): pdb ID
143 | data_dir (str): path to pdb directory
144 | assembly (int): 0/1 indicator of whether to use biological assembly or default pdb
145 | skip_download (int): 0/1 indicator of whether to skip attempt to download pdb from remote server
146 |
147 | Returns:
148 | atom_coords (np.array): num_atoms x 3 coordinates of all retained atoms in structure
149 | atom_data (np.array): num_atoms x 4 data for atoms -- [residue idx, BB ind, atom type, res type]
150 | residue_bb_index_list (np.array): num_res x 4 mapping from residue idx to atom indices for backbone atoms (N, CA, C, CB) used for canonicalization
151 | res_data (dict of list of lists): dictionary {chain ID: [ [residue ID, residue icode, residue index, residue type], ...]}
152 | res_label (np.array): num_res x 1 residue type labels (amino acid type) for all residues (to be included in training)
153 |
154 | """
155 |
156 | # get pdb chain data
157 | pdb_chains = get_pdb_chains(pdb, data_dir, assembly=assembly, skip_download=skip_download)
158 |
159 | res_idx = 0
160 | res_data = {}
161 | atom_coords = []
162 | atom_data = []
163 | residue_bb_index = {}
164 | residue_bb_index_list = []
165 | res_label = []
166 | chis = []
167 | # iterate over chains
168 | for pdb_chain, chain_id in pdb_chains:
169 | # iterate over residues
170 | res_data[chain_id] = []
171 | for res in pdb_chain.get_residues():
172 | skip_res = False # whether to skip training directly on this residue
173 |
174 | res_name = res.get_resname()
175 | het, res_id, res_icode = res.id
176 |
177 | # skip waters, metal ions, pre-specified ligands, unknown ligands
178 | if res_name in common.atoms.skip_res_list:
179 | continue
180 |
181 | res_atoms = [atom for atom in res.get_atoms()]
182 |
183 | # skip training on residues where all BB atoms are not present -- this will break canonicalization
184 | if res_name in common.atoms.res_label_dict.keys() and len(res_atoms) < 4:
185 | skip_res = True
186 |
187 | # if residue is an amino acid, add to label and save residue ID
188 | if (not skip_res) and (res_name in common.atoms.res_label_dict.keys()):
189 | res_type = common.atoms.res_label_dict[res_name]
190 | res_data[chain_id].append((res_id, res_icode, res_idx, res_type))
191 | res_label.append(res_type)
192 | residue_bb_index[res_idx] = {}
193 |
194 | # iterate over atoms -- get coordinate data
195 | for atom in res.get_atoms():
196 |
197 | if atom.element in common.atoms.skip_atoms:
198 | continue
199 | elif atom.element not in common.atoms.atoms:
200 | if res_name == "MSE" and atom.element == "SE":
201 | elem_name = "S" # swap MET for MSE
202 | else:
203 | elem_name = "other" # all other atoms are labeled 'other'
204 | else:
205 | elem_name = atom.element
206 |
207 | # get atomic coordinate
208 | c = np.array(list(atom.get_coord()))[None].astype(np.float32)
209 |
210 | # get atom type index
211 | assert elem_name in common.atoms.atoms
212 | atom_type = common.atoms.atoms.index(elem_name)
213 |
214 | # get whether atom is a BB atom
215 | bb = int(res_name in common.atoms.res_label_dict.keys() and atom.name in ["N", "CA", "C", "O", "OXT"])
216 |
217 | if res_name in common.atoms.res_label_dict.keys():
218 | res_type_idx = common.atoms.res_label_dict[res_name]
219 | else:
220 | res_type_idx = 20 # 'other' type (ligand, ion)
221 |
222 | # index -- residue idx, bb?, atom index, residue type (AA)
223 | index = np.array([res_idx, bb, atom_type, res_type_idx])
224 | atom_coords.append(c)
225 | atom_data.append(index[None])
226 | # if atom is BB atom, add to residue_bb_index dictionary
227 | if (not skip_res) and ((res_name in common.atoms.res_label_dict.keys())):
228 | # map from residue index to atom coordinate
229 | residue_bb_index[res_idx][atom.name] = len(atom_coords) - 1
230 |
231 | # get rotamer chi angles
232 | if (not skip_res) and (res_name in common.atoms.res_label_dict.keys()):
233 | if res_name == "GLY" or res_name == "ALA":
234 | chi = [0, 0, 0, 0]
235 | mask = [0, 0, 0, 0]
236 |
237 | else:
238 | chi = []
239 | mask = []
240 | if "N" in residue_bb_index[res_idx].keys() and "CA" in residue_bb_index[res_idx].keys():
241 | n = Bio.PDB.vectors.Vector(list(atom_coords[residue_bb_index[res_idx]["N"]][0]))
242 | ca = Bio.PDB.vectors.Vector(list(atom_coords[residue_bb_index[res_idx]["CA"]][0]))
243 | if (
244 | "chi_1" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys()
245 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_1"] in residue_bb_index[res_idx].keys()
246 | and "CB" in residue_bb_index[res_idx].keys()
247 | ):
248 | cb = Bio.PDB.vectors.Vector(list(atom_coords[residue_bb_index[res_idx]["CB"]][0]))
249 | cg = Bio.PDB.vectors.Vector(
250 | atom_coords[residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_1"]]][0]
251 | )
252 | chi_1 = Bio.PDB.vectors.calc_dihedral(n, ca, cb, cg)
253 | chi.append(chi_1)
254 | mask.append(1)
255 |
256 | if (
257 | "chi_2" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys()
258 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_2"] in residue_bb_index[res_idx].keys()
259 | ):
260 | cd = Bio.PDB.vectors.Vector(
261 | atom_coords[residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_2"]]][0]
262 | )
263 | chi_2 = Bio.PDB.vectors.calc_dihedral(ca, cb, cg, cd)
264 | chi.append(chi_2)
265 | mask.append(1)
266 |
267 | if (
268 | "chi_3" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys()
269 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_3"] in residue_bb_index[res_idx].keys()
270 | ):
271 | ce = Bio.PDB.vectors.Vector(
272 | atom_coords[residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_3"]]][
273 | 0
274 | ]
275 | )
276 | chi_3 = Bio.PDB.vectors.calc_dihedral(cb, cg, cd, ce)
277 | chi.append(chi_3)
278 | mask.append(1)
279 |
280 | if (
281 | "chi_4" in common.atoms.chi_dict[common.atoms.label_res_dict[res_type]].keys()
282 | and common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_4"] in residue_bb_index[res_idx].keys()
283 | ):
284 | cz = Bio.PDB.vectors.Vector(
285 | atom_coords[
286 | residue_bb_index[res_idx][common.atoms.chi_dict[common.atoms.label_res_dict[res_type]]["chi_4"]]
287 | ][0]
288 | )
289 | chi_4 = Bio.PDB.vectors.calc_dihedral(cg, cd, ce, cz)
290 | chi.append(chi_4)
291 | mask.append(1)
292 | else:
293 | chi.append(0)
294 | mask.append(0)
295 | else:
296 | chi.extend([0, 0])
297 | mask.extend([0, 0])
298 |
299 | else:
300 | chi.extend([0, 0, 0])
301 | mask.extend([0, 0, 0])
302 | else:
303 | chi = [0, 0, 0, 0]
304 | mask = [0, 0, 0, 0]
305 | else:
306 | chi = [0, 0, 0, 0]
307 | mask = [0, 0, 0, 0]
308 | chi = np.array(chi)
309 | mask = np.array(mask)
310 | chis.append(np.concatenate([chi[None], mask[None]], axis=0))
311 |
312 | # add bb atom indices in residue_list to residue_bb_index dict
313 | if (not skip_res) and res_name in common.atoms.res_label_dict.keys():
314 | residue_bb_index[res_idx]["list"] = []
315 | for atom in ["N", "CA", "C", "CB"]:
316 | if atom in residue_bb_index[res_idx]:
317 | residue_bb_index[res_idx]["list"].append(residue_bb_index[res_idx][atom])
318 | else:
319 | # GLY handling for CB
320 | residue_bb_index[res_idx]["list"].append(-1)
321 |
322 | residue_bb_index_list.append(residue_bb_index[res_idx]["list"])
323 | if not skip_res and (res_name in common.atoms.res_label_dict.keys()):
324 | res_idx += 1
325 |
326 | assert len(atom_coords) == len(atom_data)
327 | assert len(residue_bb_index_list) == len(res_label)
328 | assert len(chis) == len(residue_bb_index_list)
329 |
330 | return np.array(atom_coords), np.array(atom_data), np.array(residue_bb_index_list), res_data, np.array(res_label), np.array(chis)
331 |
332 |
333 |
334 | def get_domain_envs(pdb_id, domains_list, pdb_dir="/data/drive2tb/protein/pdb", num_return=400, bb_only=0):
335 | """ Get domain specific residues and local environments by first getting full biological assembly for
336 | pdb of interest -- selecting domain specific residues.
337 |
338 | Args:
339 | pdb_id (str): pdb structure ID
340 | domains_list (list of list of tuples of str): for each domain within pdb of interest -- list of domain start, stop residue IDs and icodes
341 |
342 | Returns:
343 | atom_coords_canonicalized (np.array): n_res x n_atoms x 3 array with canonicalized local
344 | environment atom coordinates
345 | atom_data_canonicalized (np.array): n_res x n_atoms x 4 with metadata for local env atoms
346 | [residue idx, BB ind, atom type, res type]
347 | res_data (dict of list of lists): dictionary with residue metadata -- {chain ID: [ [residue ID, residue icode, residue index, residue type], ...]}
348 | res_label (np.array): num_res x 1 residue type labels (amino acid type) for all residues (to be included in training)
349 |
350 | """
351 |
352 | atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis = get_pdb_data(pdb_id, data_dir=pdb_dir)
353 | atom_coords_def = None
354 |
355 | assert len(res_label) > 0
356 |
357 | ind_assembly = []
358 | res_idx_list_domains = []
359 | # iterate over domains for PDB of interest
360 | for domain_split in domains_list:
361 | domain_id = domain_split[0]
362 | domain_split = domain_split[-1]
363 | chain_id, domains = get_domain(domain_split)
364 | res_idx_list = []
365 | # iterate over start/end cutpoints for domain
366 | if chain_id in res_data.keys():
367 | ind_assembly.append(1)
368 | else:
369 | if atom_coords_def is None:
370 | atom_coords_def, atom_data_def, residue_bb_index_list_def, res_data_def, res_label_def, chis_def = get_pdb_data(pdb_id, data_dir=pdb_dir, assembly=0)
371 |
372 | if chain_id in res_data_def.keys():
373 | ind_assembly.append(0)
374 | if atom_coords_def is None:
375 | atom_coords_def, atom_data_def, residue_bb_index_list_def, res_data_def, res_label_def, chis_def = get_pdb_data(pdb_id, data_dir=pdb_dir, assembly=0)
376 | else:
377 | print("chain not found", chain_id, res_data.keys(), res_data_def.keys())
378 | continue
379 | for ds, de in domains:
380 | start = False
381 | end = False
382 | if chain_id not in res_data.keys():
383 | for res_id, res_icode, res_idx, res_type in res_data_def[chain_id]:
384 | assert res_idx < len(res_label_def)
385 | if (res_id != ds) and not start:
386 | continue
387 | elif res_id == ds:
388 | start = True
389 | if res_id == de:
390 | end = True
391 | if start and not end:
392 | res_idx_list.append(res_idx)
393 | if end:
394 | break
395 | else:
396 | # parse chain_res_data to get res_idx for domain of interest
397 | for res_id, res_icode, res_idx, res_type in res_data[chain_id]:
398 | assert res_idx < len(res_label)
399 | if (res_id != ds) and not start:
400 | continue
401 | elif res_id == ds:
402 | start = True
403 | if res_id == de:
404 | end = True
405 | if start and not end:
406 | res_idx_list.append(res_idx)
407 | if end:
408 | break
409 | res_idx_list_domains.append(res_idx_list)
410 |
411 | assert len(res_idx_list_domains) == len(ind_assembly)
412 |
413 | atom_coords_out = []
414 | atom_data_out = []
415 | res_label_out = []
416 | domain_ids_out = []
417 | chis_out = []
418 |
419 | for i in range(len(res_idx_list_domains)):
420 | # canonicalize -- subset of residues
421 | if len(res_idx_list_domains[i]) == 0:
422 | continue
423 | if ind_assembly[i] == 1:
424 | # pull data from biological assembly
425 | atom_coords_canonicalized, atom_data_canonicalized = canonicalize.batch_canonicalize_coords(atom_coords, atom_data, residue_bb_index_list, res_idx=np.array(res_idx_list_domains[i]), num_return=num_return, bb_only=bb_only)
426 | else:
427 | # pull data from default structure
428 | atom_coords_canonicalized, atom_data_canonicalized = canonicalize.batch_canonicalize_coords(
429 | atom_coords_def, atom_data_def, residue_bb_index_list_def, res_idx=np.array(res_idx_list_domains[i]), num_return=num_return, bb_only=bb_only
430 | )
431 |
432 | atom_coords_out.append(atom_coords_canonicalized)
433 | atom_data_out.append(atom_data_canonicalized)
434 | if ind_assembly[i] == 1:
435 | res_label_out.append(res_label[res_idx_list_domains[i]])
436 | assert len(atom_coords_canonicalized) == len(res_label[res_idx_list_domains[i]])
437 | chis_out.append(chis[res_idx_list_domains[i]])
438 | else:
439 | res_label_out.append(res_label_def[res_idx_list_domains[i]])
440 | assert len(atom_coords_canonicalized) == len(res_label_def[res_idx_list_domains[i]])
441 | chis_out.append(chis_def[res_idx_list_domains[i]])
442 | domain_ids_out.append(domains_list[i][0])
443 |
444 | return atom_coords_out, atom_data_out, res_label_out, domain_ids_out, chis_out
445 |
446 |
447 | def get_domain(domain_split):
448 | # function to parse CATH domain info from txt -- returns chain and domain residue IDs
449 | chain = domain_split[-1]
450 |
451 | domains = domain_split.split(",")
452 | domains = [d[: d.rfind(":")] for d in domains]
453 |
454 | domains = [(d[: d.rfind("-")], d[d.rfind("-") + 1 :]) for d in domains]
455 |     domains = [(int(re.findall(r"\D*\d+", ds)[0]), int(re.findall(r"\D*\d+", de)[0])) for ds, de in domains]
456 |
457 | return chain, np.array(domains)
458 |
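
An illustrative parse, assuming a CATH-style domain string of the form `start-end:chain`, comma-separated for multi-segment domains:

    from seq_des.util.data import get_domain

    chain, domains = get_domain("12-134:A")
    print(chain)    # 'A'
    print(domains)  # [[ 12 134]]
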
459 |
460 |
461 | class PDB_domain_spitter(data.Dataset):
462 | def __init__(self, txt_file="data/052320_cath-b-newest-all.txt", pdb_path="/data/drive2tb/protein/pdb", num_return=400, bb_only=0):
463 | self.domains = read_domain_ids_per_chain_from_txt(txt_file)
464 | self.pdb_path = pdb_path
465 | self.num_return = num_return
466 | self.bb_only = bb_only
467 |
468 | def __len__(self):
469 | return len(self.domains)
470 |
471 | def __getitem__(self, index):
472 | pdb_id, domain_list = self.domains[index]
473 | return self.get_data(pdb_id, domain_list)
474 |
475 | def get_and_download_pdb(self, index):
476 | pdb_id, domain_list = self.domains[index]
477 | f = download_pdb(pdb_id, data_dir=self.pdb_path)
478 | return f
479 |
480 | def get_data(self, pdb, domain_list):
481 | try:
482 | atom_coords, atom_data, res_label, domain_id, chis = get_domain_envs(pdb, domain_list, pdb_dir=self.pdb_path, num_return=self.num_return, bb_only=self.bb_only)
483 | return atom_coords, atom_data, res_label, domain_id, chis
484 | except:
485 | return []
486 |
487 |
488 | class PDB_data_spitter(data.Dataset):
489 | def __init__(self, data_dir="/data/simdev_2tb/protein/sequence_design/data/coords/test_s95_chi/", n=20, dist=10, datalen=1000):
490 | self.files = glob.glob("%s/data*pt" % (data_dir))
491 | self.cached_pt = -1
492 |         self.chunk_size = 10000  # NOTE -- must match the chunk size used when the data .pt files were saved
493 | self.datalen = datalen
494 | self.data_dir = data_dir
495 | self.n = n
496 | self.dist = dist
497 | self.c = len(common.atoms.atoms)
498 | self.len = 0
499 |
500 | def __len__(self):
501 | if self.len == 0:
502 | return len(self.files) * self.chunk_size
503 | else:
504 | return self.len
505 |
506 | def get_data(self, index):
507 | if self.cached_pt != index // self.chunk_size:
508 | self.cached_pt = int(index // self.chunk_size)
509 | self.xs, self.x_data, self.ys, self.domain_ids, self.chis = torch.load("%s/data_%0.4d.pt" % (self.data_dir, self.cached_pt))
510 |
511 | index = index % self.chunk_size
512 | x, x_data, y, domain_id, chis = self.xs[index], self.x_data[index], self.ys[index], self.domain_ids[index], self.chis[index]
513 | return x, x_data, y, domain_id, chis
514 |
515 |     def __getitem__(self, index):
516 | x, x_data, y, domain_id, chis = self.get_data(index)
517 | ## voxelize coordinates and atom metadata
518 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type = voxelize.get_voxel_idx(x[None], x_data[None], n=self.n, c=self.c, dist=self.dist)
519 | # map chi angles to bins
520 | chi_angles = chis[0]
521 | chi_mask = chis[1]
522 | chi_angles_binned = map_to_bins(chi_angles)
523 | chi_angles_binned[chi_mask == 0] = 0 # ignore index
524 |
525 | # return domain_id, x, x_data, y, chi_angles, chi_angles_binned
526 | return bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles, chi_angles_binned
527 |
528 |
529 | def collate_wrapper(data, crop=True):
530 | max_n = 0
531 | for i in range(len(data)):
532 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles, chi_angles_binned = data[i][0], data[i][1], data[i][2], data[i][3], data[i][4], data[i][5], data[i][6], data[i][7], data[i][8], data[i][9]
533 |         # track the largest environment size across the batch for padding
534 |         n_i = x_atom.shape[-1]
535 | 
536 | if n_i > max_n:
537 | max_n = n_i
538 |
539 |     # pad per-environment index arrays and coords to a common length
540 | out_bs_idx = []
541 | out_y = []
542 | out_x_atom = []
543 | out_x_bb = []
544 | out_x_b = []
545 | out_y_b = []
546 | out_z_b = []
547 | out_x_res_type = []
548 | out_chi_angles = []
549 | out_chi_angles_binned = []
550 | padding = False
551 | for i in range(len(data)):
552 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles, chi_angles_binned = data[i][0], data[i][1], data[i][2], data[i][3], data[i][4], data[i][5], data[i][6], data[i][7], data[i][8], data[i][9]
553 | n_i = x_atom.shape[-1]
554 |
555 | if n_i < max_n:
556 | padding = True
557 | # zero pad all --> x, y, z indexing will be omitted
558 | x_atom = np.pad(x_atom, ((0, max_n - n_i)), mode='constant')
559 | x_b = np.pad(x_b, ((0, max_n - n_i)), mode='constant')
560 | y_b = np.pad(y_b, ((0, max_n - n_i)), mode='constant')
561 | z_b = np.pad(z_b, ((0, max_n - n_i)), mode='constant')
562 | x_bb = np.pad(x_bb, ((0, max_n - n_i)), mode='constant')
563 | x_res_type = np.pad(x_res_type, ((0, max_n - n_i)), mode='constant')
564 |
565 | # handle batch indexing correctly
566 | out_bs_idx.append(torch.Tensor([i for j in range(len(x_b))])[None])
567 | out_y.append(torch.Tensor([y])) # [None])
568 | out_x_atom.append(torch.Tensor(x_atom)[None])
569 | out_x_bb.append(torch.Tensor(x_bb)[None])
570 | out_x_b.append(torch.Tensor(x_b)[None])
571 | out_y_b.append(torch.Tensor(y_b)[None])
572 | out_z_b.append(torch.Tensor(z_b)[None])
573 | out_x_res_type.append(torch.Tensor(x_res_type)[None])
574 | out_chi_angles.append(torch.Tensor(chi_angles)[None])
575 | out_chi_angles_binned.append(torch.Tensor(chi_angles_binned)[None])
576 |
577 | out_bs_idx = torch.cat(out_bs_idx, 0)
578 | out_y = torch.cat(out_y, 0)
579 | out_x_atom = torch.cat(out_x_atom, 0)
580 | out_x_bb = torch.cat(out_x_bb, 0)
581 | out_x_b = torch.cat(out_x_b, 0)
582 | out_y_b = torch.cat(out_y_b, 0)
583 | out_z_b = torch.cat(out_z_b, 0)
584 | out_x_res_type = torch.cat(out_x_res_type, 0)
585 | out_chi_angles = torch.cat(out_chi_angles, 0)
586 | out_chi_angles_binned = torch.cat(out_chi_angles_binned, 0)
587 | return out_bs_idx.long(), out_x_atom.long(), out_x_bb.long(), out_x_b.long(), out_y_b.long(), out_z_b.long(), out_x_res_type.long(), out_y.long(), out_chi_angles, out_chi_angles_binned.long()
588 |
589 |
590 |
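591 | # Example (sketch): parsing a CATH-style domain string with get_domain. The
592 | # "start-end:chain" format (e.g. "2-89:A") is assumed from the parsing above.
593 | def _example_get_domain():
594 |     chain, domains = get_domain("2-89:A")
595 |     assert chain == "A"
596 |     assert (domains == np.array([[2, 89]])).all()
597 |     return chain, domains
598 | 
599 | 
600 | # Usage sketch for the datasets above (data_dir is hypothetical):
601 | # dataset = PDB_data_spitter(data_dir="/path/to/coords/train_s95_chi")
602 | # loader = data.DataLoader(dataset, batch_size=64, collate_fn=collate_wrapper)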
--------------------------------------------------------------------------------
/seq_des/util/pyrosetta_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import common.atoms
3 |
4 | from rosetta import *
5 | from pyrosetta import *
6 | init("-mute basic -mute core -mute protocols -ex1 -ex2 -constant_seed")
7 |
8 | 
9 | from pyrosetta.rosetta.protocols.simple_moves import MutateResidue
10 |
11 | from pyrosetta.rosetta.core import conformation
12 | from pyrosetta.rosetta.core import chemical
13 | from pyrosetta.rosetta.protocols.minimization_packing import PackRotamersMover
14 |
15 | score_manager = pyrosetta.rosetta.core.scoring.ScoreTypeManager()
16 | scorefxn = get_fa_scorefxn()
17 | from pyrosetta.rosetta.core.chemical import aa_from_oneletter_code
18 |
19 |
20 | def get_seq_delta(s1, s2):
21 | count = 0
22 | for i in range(len(s1)):
23 | if s1[i] != s2[i]:
24 | count += 1
25 | return count
26 |
27 | def score_pose(pose):
28 | return scorefxn(pose)
29 |
30 | def randomize_sequence(new_seq, pose, pack_radius=5.0, fixed_idx=[], var_idx=[], ala=False, val=False, resfile_init_seq=False, enforce=False, repack_rotamers=0):
31 |     for idx in range(len(pose.residues)):
32 | # do not mutate fixed indices / only mutate var indices
33 | if idx in fixed_idx:
34 | continue
35 | elif len(var_idx) > 0 and idx not in var_idx:
36 | continue
37 |
38 | res = pose.residue(idx + 1)
39 | ref_res_name = res.name()
40 |
41 | if ":" in ref_res_name:
42 | ref_res_name = ref_res_name[: ref_res_name.find(":")]
43 | if "_" in ref_res_name:
44 | ref_res_name = ref_res_name[: ref_res_name.find("_")]
45 |
46 | if ref_res_name not in common.atoms.res_label_dict.keys():
47 | continue
48 |
49 | if ala:
50 | r = common.atoms.res_label_dict["ALA"]
51 | elif val:
52 | r = common.atoms.res_label_dict["VAL"]
53 | else:
54 | r = new_seq[idx]
55 |
56 | res_aa = common.atoms.aa_map[r]
57 |
58 |         # resfile handling: e.g. "5 TPIKAA C" sets the initial sequence at residue 5 to 'C'
59 |         if resfile_init_seq and idx in resfile_init_seq:
60 |             res_aa = resfile_init_seq[idx]
61 |
62 | pose = handle_disulfide(pose, idx)
63 | mutate_residue(pose, idx + 1, res_aa, pack_radius=pack_radius, repack_rotamers=repack_rotamers)
64 |
65 |     return pose, len(pose.residues)
66 |
67 |
68 | # from https://github.com/barricklab/mutant-protein-stability/blob/master/PyRosetta_TACC_MPI.py
69 | def handle_disulfide(pose, idx):
70 | res = pose.residue(idx + 1)
71 | if (res.name() == "CYS:disulfide") or (res.name() == "CYD"):
72 | disulfide_partner = None
73 | try:
74 | disulfide_partner = res.residue_connection_partner(res.n_residue_connections())
75 | except AttributeError:
76 | disulfide_partner = res.residue_connection_partner(res.n_current_residue_connections())
77 | temp_pose = pyrosetta.Pose()
78 | temp_pose.assign(pose)
79 | # (Packing causes seg fault if current CYS residue is not
80 | # also converted before mutating.)
81 | conformation.change_cys_state(idx + 1, "CYS", temp_pose.conformation())
82 | conformation.change_cys_state(disulfide_partner, "CYS", temp_pose.conformation())
83 | pose = temp_pose
84 | return pose
85 |
86 |
87 | def mutate(pose, idx, res, pack_radius=5.0, fixed_idx=[], var_idx=[], repack_rotamers=0):
88 | if idx in fixed_idx:
89 | return pose
90 | elif len(var_idx) > 0 and idx not in var_idx:
91 | return pose
92 | pose = handle_disulfide(pose, idx)
93 | pose = mutate_residue(pose, idx + 1, res, pack_radius=pack_radius, repack_rotamers=repack_rotamers)
94 | return pose
95 |
96 |
97 | def mutate_list(pose, idx_list, res_list, pack_radius=5.0, fixed_idx=[], var_idx=[], repack_rotamers=0):
98 | assert len(idx_list) == len(res_list), (len(idx_list), len(res_list))
99 | for i in range(len(idx_list)):
100 | idx, res = idx_list[i], res_list[i]
101 | if len(fixed_idx) > 0 and idx in fixed_idx:
102 | continue
103 | if len(var_idx) > 0 and idx not in var_idx:
104 | continue
105 | sequence = pose.sequence()
106 | pose = mutate(pose, idx, res, pack_radius=pack_radius, fixed_idx=fixed_idx, var_idx=var_idx, repack_rotamers=repack_rotamers)
107 | new_sequence = pose.sequence()
108 | assert get_seq_delta(sequence, new_sequence) <= 1, get_seq_delta(sequence, new_sequence)
109 | assert res == pose.sequence()[idx], (res, pose.sequence()[idx])
110 | return pose
111 |
112 |
113 | def get_pose(pdb):
114 | return pose_from_pdb(pdb)
115 |
116 |
117 | # from PyRosetta toolbox
118 | def restrict_non_nbrs_from_repacking(pose, res, task, pack_radius, repack_rotamers=0):
119 | """Configure a `PackerTask` to only repack neighboring residues and
120 | return the task.
121 |
122 | Args:
123 |         pose (pyrosetta.Pose): The `Pose` to operate on.
124 | res (int): Pose-numbered residue position to exclude.
125 | task (pyrosetta.rosetta.core.pack.task.PackerTask): `PackerTask` to modify.
126 | pack_radius (float): Radius used to define neighboring residues.
127 |
128 | Returns:
129 | pyrosetta.rosetta.core.pack.task.PackerTask: Configured `PackerTask`.
130 | """
131 |
132 | if not repack_rotamers:
133 | assert pack_radius == 0, "pack radius must be 0 if you don't want to repack rotamers"
134 |
135 | def representative_coordinate(resNo):
136 | return pose.residue(resNo).xyz(pose.residue(resNo).nbr_atom())
137 |
138 | center = representative_coordinate(res)
139 | for i in range(1, len(pose.residues) + 1):
140 | # only pack the mutating residue and any within the pack_radius
141 | if i == res:
142 | # comment out this block to reproduce biorxiv results
143 | #if not repack_rotamers:
144 | # task.nonconst_residue_task(i).prevent_repacking()
145 | continue
146 | if center.distance(representative_coordinate(i)) > pack_radius:
147 | task.nonconst_residue_task(i).prevent_repacking()
148 | else:
149 | if repack_rotamers:
150 | task.nonconst_residue_task(i).restrict_to_repacking()
151 | else:
152 | task.nonconst_residue_task(i).prevent_repacking()
153 |
154 | return task
155 |
156 |
157 | # modified from PyRosetta toolbox
158 | def mutate_residue(pose, mutant_position, mutant_aa, pack_radius=0.0, pack_scorefxn=None, repack_rotamers=0):
159 | """Replace the residue at a single position in a Pose with a new amino acid
160 |     and repack any residues within a user-defined radius of the selected
161 |     residue's center.
162 |
163 | Args:
164 | pose (pyrosetta.rosetta.core.pose.Pose):
165 | mutant_position (int): Pose-numbered position of the residue to mutate.
166 | mutant_aa (str): The single letter name for the desired amino acid.
167 | pack_radius (float): Radius used to define neighboring residues.
168 | pack_scorefxn (pyrosetta.ScoreFunction): `ScoreFunction` to use when repacking the `Pose`.
169 | Defaults to the standard `ScoreFunction`.
170 | """
171 |
172 | wpose = pose
173 |
174 | if not wpose.is_fullatom():
175 | raise IOError("mutate_residue only works with fullatom poses")
176 |
177 | # create a standard scorefxn by default
178 | if not pack_scorefxn:
179 | pack_scorefxn = pyrosetta.get_score_function()
180 |
181 | # forces mutation
182 | mut = MutateResidue(mutant_position, common.atoms.aa_inv[mutant_aa])
183 | mut.apply(wpose)
184 |
185 |     # the numbers 1-20 correspond to the 20 proteinogenic amino acids
186 | mutant_aa = int(aa_from_oneletter_code(mutant_aa))
187 | aa_bool = pyrosetta.Vector1([aa == mutant_aa for aa in range(1, 21)])
188 | # mutation is performed by using a PackerTask with only the mutant
189 | # amino acid available during design
190 |
191 | task = pyrosetta.standard_packer_task(wpose)
192 | task.nonconst_residue_task(mutant_position).restrict_absent_canonical_aas(aa_bool)
193 |
194 | # prevent residues from packing by setting the per-residue "options" of the PackerTask
195 | task = restrict_non_nbrs_from_repacking(wpose, mutant_position, task, pack_radius, repack_rotamers=repack_rotamers)
196 |
197 | # apply the mutation and pack nearby residues
198 |
199 | packer = PackRotamersMover(pack_scorefxn, task)
200 | packer.apply(wpose)
201 | # return pack_or_pose
202 | return wpose
203 |
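204 | 
205 | # Usage sketch: mutate residue 1 of one of this repo's example structures to
206 | # alanine and score it. PyRosetta is initialized at import time above.
207 | def _example_mutate(pdb_path="pdbs/1acf_gt.pdb"):
208 |     pose = get_pose(pdb_path)
209 |     pose = mutate(pose, idx=0, res="A", pack_radius=5.0, repack_rotamers=1)
210 |     return score_pose(pose)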
--------------------------------------------------------------------------------
/seq_des/util/resfile_util.py:
--------------------------------------------------------------------------------
1 | # developed by Damir Temir | github.com/dtemir | as a part of the RosettaCommons Summer Internship
2 |
3 | import common
4 | import re
5 |
6 | def read_resfile(filename):
7 | """
8 | read a resfile and return a dictionary of constraints for each residue id
9 |
10 | the constraints is a dictionary where the keys are residue ids and values are the amino acids to restrict
11 |     (residue ids from the resfile are decremented by 1 because numbering in PDBs starts from 1,
12 |     while indexing into the logits starts from 0)
13 |
14 | example:
15 | 65 ALLAA # allow all amino acids at residue id 65 (64 in the tensor)
16 | 54 ALLAAxc # allow all amino acids except cysteine at residue id 54 (53 in the tensor)
17 | 30 POLAR # allow only polar amino acids at residue id 30 (29 in the tensor)
18 | 31 - 33 NOTAA CFYG # disallow the specified amino acids at residue ids 31 to 33 (30 to 32 in the tensor)
19 | 43 TPIKAA C # allow only cysteine when initializing the sequence (same logic for TNOTAA)
20 |
21 |     results in a dictionary:
22 | {64: {}, 53: {'C'}, 29: {'T', 'R', 'K', 'Q', 'D', 'E', 'S', 'N', 'H'},
23 | 30: {'C', 'F', 'Y', 'G'}, 31: {'C', 'F', 'Y', 'G'}, 32: {'C', 'F', 'Y', 'G'}}
24 |
25 | plus it returns a header from check_for_header():
26 | {"DEFAULT": {}}
27 |
28 |     plus it returns a dictionary mapping residue ids to the amino acid chosen to initialize the sequence (picked from the complement of the restricted set):
29 | {42: 'C'}
30 | """
31 | def place_constraints(constraint, init_seq):
32 | """
33 | places the constraints in the appropriate dicts
34 | -initial_seq for building the initial sequence with TPIKAA and TNOTAA
35 | -constraints for restricting the conditional model with PIKAA, NOTAA, ALLAA, POLAR, etc.
36 | """
37 | if not init_seq:
38 | constraints[res_id] = constraint
39 | else:
40 | initial_seq[res_id] = constraint
41 |
42 | constraints = dict() # amino acids to restrict in the design
43 | header, start_id = check_for_header(filename) # amino acids to use as default for those not specified in constraints
44 | initial_seq = dict() # amino acids to use when initializing the sequence
45 |
46 | with open(filename, "r") as f:
47 | # iterate over the lines and extract arguments (residue id, command)
48 | lines = f.readlines()
49 | for line in lines[start_id + 1:]:
50 |             args = line.split()  # split on any whitespace
51 | is_integer(args[0]) # the res id needs to be an integer
52 | assert isinstance(args[1], str), "the resfile command needs to be a string"
53 |
54 | res_id = int(args[0]) - 1
55 | if args[1] == "-": # if given a range of residue ids (ex. 31 - 33 NOTAA)
56 | is_integer(args[2]) # the res id needs to be an integer
57 | for res_id in range(res_id, int(args[2])):
58 | constraint, init_seq = check_for_commands(args, 3, 4)
59 | place_constraints(constraint, init_seq)
60 | else: # if not given a range (ex. 31 NOTAA CA)
61 | constraint, init_seq = check_for_commands(args, 1, 2)
62 | place_constraints(constraint, init_seq)
63 |
64 | # update the initial seq dictionary to only have one element per residue id (at random)
65 | initial_seq = {res_id : (common.atoms.resfile_commands["ALLAAwc"] - restricted_aa).pop() for res_id, restricted_aa in initial_seq.items()}
66 |
67 | return constraints, header, initial_seq
68 |
69 | def check_for_header(filename):
70 | """
71 | read a resfile and return the header if present
72 |
73 | the header is commands that should be applied by default
74 | to all residues that are not specified after the 'start' keyword
75 |
76 | example of a header:
77 |     ALLAA # allows all amino acids for residues that are not specified in the body
78 | START # divides the body and header
79 | # ... the body starts here, see read_resfile()
80 | """
81 | header = {}
82 | start_id = -1
83 | with open(filename, "r") as f:
84 |         start = re.compile(r"\b(START|start)\b")  # match 'START' or 'start' as a whole word
85 | # if the file has the keyword start, extract header
86 | if bool(start.search(f.read())):
87 | f.seek(0) # set the cursor back to the beginning
88 | lines = f.readlines()
89 | for i, line in enumerate(lines):
90 | if start.match(line):
91 | start_id = i # the line number where start is used (divides header and body)
92 | break
93 | args = line.split()
94 | args.insert(0, "") # check_for_commands only handles the second argument (first is usually res_id)
95 | header['DEFAULT'] = check_for_commands(args, 1, 2)
96 |
97 | return header, start_id
98 |
99 |
100 | def check_for_commands(args, command_id, list_id):
101 | """
102 | converts given commands into sets of amino acids to restrict in the logits
103 |
104 | so far, it handles these commands: ALLAA, ALLAAxc, POLAR, APOLAR, NOTAA, PIKAA
105 |
106 | command_id - the index where the command is within the args
107 | list_id - the index where the possible list of AA is within the args (only for NOTAA and PIKAA)
108 | """
109 | constraint = set()
110 | command = args[command_id].upper()
111 | init_seq = False # reflect if it's TPIKAA or TNOTAA
112 | if command in common.atoms.resfile_commands.keys():
113 | constraint = common.atoms.resfile_commands["ALLAAwc"] - common.atoms.resfile_commands[command]
114 | elif "PIKAA" in command: # allow only the specified amino acids
115 | constraint = common.atoms.resfile_commands["ALLAAwc"] - set(args[list_id].strip())
116 | elif "NOTAA" in command: # disallow only the specified amino acids
117 | constraint = set(args[list_id].strip())
118 |
119 | if command == "TPIKAA" or command == "TNOTAA":
120 | init_seq = True
121 |
122 | return constraint, init_seq
123 |
124 | def get_natro(filename):
125 | """
126 |     provides a list of indices whose input rotamers and identities need to be preserved (Native Rotamer - NATRO)
127 |
128 | overrides the sampler.py's self.fixed_idx attribute with a list of the NATRO residues to be skipped in the
129 | self.get_blocks() function that picks sampling blocks
130 |
131 | if ALL residues in the resfile are NATRO, the sampler.py's self.step() skips running the neural network for
132 | amino acid prediction AND rotamer prediction
133 | """
134 | fixed_idx = set()
135 | with open(filename, "r") as f:
136 | lines = f.readlines()
137 | for line in lines:
138 |             args = [arg.upper() for arg in line.split()]
139 | if "NATRO" in args:
140 | is_integer(args[0])
141 | if args[1] == "-": # provided a range of NATRO residues
142 | is_integer(args[2])
143 | fixed_idx.update(range(int(args[0]) - 1, int(args[2])))
144 | else: # provided a single NATRO residue
145 | fixed_idx.add(int(args[0]) - 1)
146 |
147 | return list(fixed_idx)
148 |
149 | def is_integer(n):
150 | try:
151 | int(n)
152 | except ValueError:
153 |         raise ValueError("Incorrect residue index in the resfile: %s" % n)
154 |
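155 | 
156 | # Usage sketch: parse one of this repo's example resfiles (path from txt/resfiles/,
157 | # assumed to follow the format documented in read_resfile above).
158 | def _example_read_resfile(path="txt/resfiles/full_example.txt"):
159 |     constraints, header, initial_seq = read_resfile(path)
160 |     fixed_idx = get_natro(path)
161 |     return constraints, header, initial_seq, fixed_idx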
--------------------------------------------------------------------------------
/seq_des/util/sampler_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import seq_des.util.data as data
5 | import seq_des.util.canonicalize as canonicalize
6 | import seq_des.util.voxelize as voxelize
7 | import common.atoms
8 |
9 | import torch.nn.functional as F
10 | from torch.distributions.categorical import Categorical
11 |
12 |
13 | def get_idx(filename):
14 | # get variable or fixed indices from list
15 | with open(filename, "r") as f:
16 | lines = list(f)
17 | idx = [int(line.strip("\n").split()[0]) for line in lines]
18 | return idx
19 |
20 |
21 | def get_CB_distance(x, x_data):
22 | # get CB-CB pairwise distances
23 | A = []
24 | for k in range(x_data.shape[0]):
25 | idx_CA, idx_CB = x_data[k, 1], x_data[k, -1]
26 | if idx_CB >= 0:
27 | A.append(x[idx_CB])
28 | else:
29 | A.append(x[idx_CA])
30 | A = np.array(A)[:, 0, :3]
31 | D = np.sqrt(np.sum((A[:, None].repeat(len(A), axis=1) - A[None].repeat(len(A), axis=0)) ** 2, -1))
32 | return D
33 |
34 |
35 | def get_graph_from_D(D, threshold):
36 | A = np.zeros_like(D)
37 | A[D < threshold] = 1
38 | return A
39 |
40 |
41 | def make_onehot(bs, dim, scatter_tensor, use_cuda=1):
42 | onehot = torch.FloatTensor(bs, dim)
43 | onehot.zero_()
44 | onehot.scatter_(1, scatter_tensor, 1)
45 | if use_cuda:
46 | return onehot.cuda()
47 | else:
48 | return onehot
49 |
50 |
51 | def get_energy_from_logits(logits, res_idx, mask=None, baseline=0):
52 | # get negative log prob from logits
53 | log_p = -F.log_softmax(logits, -1).gather(1, res_idx[:, None])
54 | if mask is not None:
55 | log_p[mask == 1] = baseline
56 | log_p_mean = log_p.mean()
57 | return log_p, log_p_mean
58 |
59 |
60 | def get_conv_feat(
61 | curr_models, atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis, bb_only=0, return_chi=0, use_cuda=1
62 | ):
63 | atom_coords_canonicalized, atom_data_canonicalized = canonicalize.batch_canonicalize_coords(
64 | atom_coords, atom_data, residue_bb_index_list, bb_only=bb_only
65 | )
66 |
67 | x = atom_coords_canonicalized
68 | y = res_label
69 | x_data = atom_data_canonicalized
70 |
71 | voxels = voxelize.voxelize(x, x_data, n=20, c=len(common.atoms.atoms), dist=10, bb_only=bb_only)
72 | voxels = torch.FloatTensor(voxels)
73 | bs_i = voxels.size()[0]
74 | if use_cuda:
75 | voxels = voxels.cuda()
76 |
77 | # map chi angles to bins
78 | chi_angles = chis[:, 0]
79 | chi_mask = chis[:, 1]
80 | chi_angles_binned = data.map_to_bins(chi_angles)
81 | chi_angles_binned[chi_mask == 0] = 0
82 | chi_angles_binned = torch.LongTensor(chi_angles_binned)
83 |
84 | chi_1 = chi_angles_binned[..., 0]
85 | chi_2 = chi_angles_binned[..., 1]
86 | chi_3 = chi_angles_binned[..., 2]
87 | chi_4 = chi_angles_binned[..., 3]
88 |
89 | # get chi onehot vectors -- NOTE can make this faster by precomputing, saving zero tensors
90 | chi_1_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_1[:, None], use_cuda=use_cuda)
91 | chi_2_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_2[:, None], use_cuda=use_cuda)
92 | chi_3_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_3[:, None], use_cuda=use_cuda)
93 | chi_4_onehot = make_onehot(bs_i, len(data.CHI_BINS), chi_4[:, None], use_cuda=use_cuda)
94 |
95 | y = torch.LongTensor(y)
96 | y_onehot = make_onehot(bs_i, 20, y[:, None], use_cuda=use_cuda)
97 | if use_cuda:
98 | y = y.cuda()
99 |
100 | # ensemble prediction over all models -- average logits
101 | logits_out = []
102 | chi_feat_out = []
103 | chi_1_out = []
104 | chi_2_out = []
105 | chi_3_out = []
106 | chi_4_out = []
107 |
108 | with torch.no_grad():
109 | for model in curr_models:
110 | feat, res_pred_logits, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model.get_feat(
111 | voxels, y_onehot, chi_1_onehot[:, 1:], chi_2_onehot[:, 1:], chi_3_onehot[:, 1:]
112 | )
113 | logits_out.append(res_pred_logits[None])
114 | chi_feat_out.append(feat[None])
115 | chi_1_out.append(chi_1_pred[None])
116 | chi_2_out.append(chi_2_pred[None])
117 | chi_3_out.append(chi_3_pred[None])
118 | chi_4_out.append(chi_4_pred[None])
119 |
120 | logits_out = torch.cat(logits_out, 0).mean(0)
121 | chi_feat_out = torch.cat(chi_feat_out, 0).mean(0)
122 | chi_1_logits = torch.cat(chi_1_out, 0).mean(0)
123 | chi_2_logits = torch.cat(chi_2_out, 0).mean(0)
124 | chi_3_logits = torch.cat(chi_3_out, 0).mean(0)
125 | chi_4_logits = torch.cat(chi_4_out, 0).mean(0)
126 |
127 | chi_1 = chi_1 - 1
128 | chi_2 = chi_2 - 1
129 | chi_3 = chi_3 - 1
130 | chi_4 = chi_4 - 1
131 |
132 | if use_cuda:
133 | chi_1 = (chi_1).cuda()
134 | chi_2 = (chi_2).cuda()
135 | chi_3 = (chi_3).cuda()
136 | chi_4 = (chi_4).cuda()
137 |
138 | return (
139 | logits_out,
140 | chi_feat_out,
141 | y,
142 | chi_1_logits,
143 | chi_2_logits,
144 | chi_3_logits,
145 | chi_4_logits,
146 | chi_1,
147 | chi_2,
148 | chi_3,
149 | chi_4,
150 | chi_angles,
151 | chi_mask,
152 | )
153 |
154 |
155 | def get_energy_from_feat(
156 | models,
157 | logits,
158 | chi_feat,
159 | y,
160 | chi_1_logits,
161 | chi_2_logits,
162 | chi_3_logits,
163 | chi_4_logits,
164 | chi_1,
165 | chi_2,
166 | chi_3,
167 | chi_4,
168 | chi_angles,
169 | chi_mask,
170 | include_rotamer_probs=0,
171 | return_log_ps=0,
172 | use_cuda=True,
173 | ):
174 | # get residue log probs
175 | # energy, energy_per_res,
176 | log_p_per_res, log_p_mean = get_energy_from_logits(logits, y)
177 |
178 | # get rotamer log_probs
179 | chi_1_mask = torch.zeros_like(chi_1)
180 | chi_2_mask = torch.zeros_like(chi_2)
181 | chi_3_mask = torch.zeros_like(chi_3)
182 | chi_4_mask = torch.zeros_like(chi_4)
183 |
184 | if use_cuda:
185 | chi_1_mask = chi_1_mask.cuda()
186 | chi_2_mask = chi_2_mask.cuda()
187 | chi_3_mask = chi_3_mask.cuda()
188 | chi_4_mask = chi_4_mask.cuda()
189 |
190 | chi_1_mask[chi_1 < 0] = 1
191 | chi_2_mask[chi_2 < 0] = 1
192 | chi_3_mask[chi_3 < 0] = 1
193 | chi_4_mask[chi_4 < 0] = 1
194 |
195 | chi_1[chi_1 < 0] = 0
196 | chi_2[chi_2 < 0] = 0
197 | chi_3[chi_3 < 0] = 0
198 | chi_4[chi_4 < 0] = 0
199 |
200 | log_p_per_res_chi_1, log_p_per_res_chi_1_mean = get_energy_from_logits(chi_1_logits, chi_1, mask=chi_1_mask, baseline=1.3183412514892)
201 | log_p_per_res_chi_2, log_p_per_res_chi_2_mean = get_energy_from_logits(chi_2_logits, chi_2, mask=chi_2_mask, baseline=1.5970909799808386)
202 | log_p_per_res_chi_3, log_p_per_res_chi_3_mean = get_energy_from_logits(chi_3_logits, chi_3, mask=chi_3_mask, baseline=2.231545756901711)
203 | log_p_per_res_chi_4, log_p_per_res_chi_4_mean = get_energy_from_logits(chi_4_logits, chi_4, mask=chi_4_mask, baseline=2.084356748355477)
204 |
205 | if return_log_ps:
206 | return log_p_mean, log_p_per_res_chi_1_mean, log_p_per_res_chi_2_mean, log_p_per_res_chi_3_mean, log_p_per_res_chi_4_mean
207 |
208 | if include_rotamer_probs:
209 | # get per residue log probs (autoregressive)
210 | log_p_per_res = log_p_per_res + log_p_per_res_chi_1 + log_p_per_res_chi_2 + log_p_per_res_chi_3 + log_p_per_res_chi_4
211 | # optimize mean log prob across residues
212 | log_p_mean = log_p_per_res.mean()
213 |
214 | return log_p_per_res, log_p_mean
215 |
216 |
217 | def get_energy(models, pose=None, pdb=None, chain="A", bb_only=0, return_chi=0, use_cuda=1, log_path="./", include_rotamer_probs=0):
218 | if pdb is not None:
219 | atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis = data.get_pdb_data(
220 | pdb[pdb.rfind("/") + 1 : -4], data_dir=pdb[: pdb.rfind("/")], skip_download=1, assembly=0
221 | )
222 | else:
223 | assert pose is not None, "need to specify pose to calc energy"
224 | pose.dump_pdb(log_path + "/" + "curr.pdb")
225 | atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis = data.get_pdb_data(
226 | "curr", data_dir=log_path, skip_download=1, assembly=0
227 | )
228 |
229 | # get residue and rotamer logits
230 | logits, chi_feat, y, chi_1_logits, chi_2_logits, chi_3_logits, chi_4_logits, chi_1, chi_2, chi_3, chi_4, chi_angles, chi_mask = get_conv_feat(
231 | models, atom_coords, atom_data, residue_bb_index_list, res_data, res_label, chis, bb_only=bb_only, return_chi=return_chi, use_cuda=use_cuda
232 | )
233 |
234 | # get model negative log probs (model energy)
235 | log_p_per_res, log_p_mean = get_energy_from_feat(
236 | models,
237 | logits,
238 | chi_feat,
239 | y,
240 | chi_1_logits,
241 | chi_2_logits,
242 | chi_3_logits,
243 | chi_4_logits,
244 | chi_1,
245 | chi_2,
246 | chi_3,
247 | chi_4,
248 | chi_angles,
249 | chi_mask,
250 | include_rotamer_probs=include_rotamer_probs,
251 | use_cuda=use_cuda,
252 | )
253 |
254 | if return_chi:
255 | return res_label, log_p_per_res, log_p_mean, logits, chi_feat, chi_angles, chi_mask, [chi_1, chi_2, chi_3, chi_4]
256 | return res_label, log_p_per_res, log_p_mean, logits, chi_feat, chi_angles, chi_mask
257 |
258 |
259 | def get_chi_init_feat(curr_models, feat, res_onehot):
260 | chi_feat_out = []
261 | with torch.no_grad():
262 | for model in curr_models:
263 | chi_feat = model.get_chi_init_feat(feat, res_onehot)
264 | chi_feat_out.append(chi_feat[None])
265 | chi_feat = torch.cat(chi_feat_out, 0).mean(0)
266 | return chi_feat
267 |
268 |
269 | def get_chi_1_logits(curr_models, chi_feat):
270 | chi_1_pred_out = []
271 | with torch.no_grad():
272 | for model in curr_models:
273 | chi_1_pred = model.get_chi_1(chi_feat)
274 | chi_1_pred_out.append(chi_1_pred[None])
275 | chi_1_pred_out = torch.cat(chi_1_pred_out, 0).mean(0)
276 | return chi_1_pred_out
277 |
278 |
279 | def get_chi_2_logits(curr_models, chi_feat, chi_1_onehot):
280 | chi_2_pred_out = []
281 | with torch.no_grad():
282 | for model in curr_models:
283 | chi_2_pred = model.get_chi_2(chi_feat, chi_1_onehot)
284 | chi_2_pred_out.append(chi_2_pred[None])
285 | chi_2_pred_out = torch.cat(chi_2_pred_out, 0).mean(0)
286 | return chi_2_pred_out
287 |
288 |
289 | def get_chi_3_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot):
290 | chi_3_pred_out = []
291 | with torch.no_grad():
292 | for model in curr_models:
293 | chi_3_pred = model.get_chi_3(chi_feat, chi_1_onehot, chi_2_onehot)
294 | chi_3_pred_out.append(chi_3_pred[None])
295 | chi_3_pred_out = torch.cat(chi_3_pred_out, 0).mean(0)
296 | return chi_3_pred_out
297 |
298 |
299 | def get_chi_4_logits(curr_models, chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot):
300 | chi_4_pred_out = []
301 | with torch.no_grad():
302 | for model in curr_models:
303 | chi_4_pred = model.get_chi_4(chi_feat, chi_1_onehot, chi_2_onehot, chi_3_onehot)
304 | chi_4_pred_out.append(chi_4_pred[None])
305 | chi_4_pred_out = torch.cat(chi_4_pred_out, 0).mean(0)
306 | return chi_4_pred_out
307 |
308 |
309 | def sample_chi(chi_logits, use_cuda=True):
310 | # sample chi bin from predicted distribution
311 | chi_dist = Categorical(logits=chi_logits)
312 | chi_idx = chi_dist.sample().cpu().data.numpy()
313 | chi = torch.LongTensor(chi_idx)
314 | # get one-hot encoding of sampled bin for autoregressive unroll
315 | chi_onehot = make_onehot(chi_logits.size()[0], len(data.CHI_BINS) - 1, chi[:, None], use_cuda=use_cuda)
316 | # sample chi angle (real) uniformly within bin
317 | chi_real = np.random.uniform(low=data.CHI_BINS[chi_idx], high=data.CHI_BINS[chi_idx + 1])
318 | return chi, chi_real, chi_onehot
319 |
320 |
321 | def get_symm_chi(chi_pred_out, symm_idx_ptr, use_cuda=True, debug=False):
322 | chi_pred_out_symm = []
323 | for i, ptr in enumerate(symm_idx_ptr):
324 | chi_pred_out_symm.append(chi_pred_out[ptr].mean(0)[None])
325 | chi_pred_out = torch.cat(chi_pred_out_symm, 0)
326 | chi, chi_real, chi_onehot = sample_chi(chi_pred_out, use_cuda=use_cuda)
327 | if debug:
328 | # sample uniformly again from predicted bin. small bug for TIM-barrel symmetry experiments. ¯\_(ツ)_/¯
329 | chi, chi_real, chi_onehot = sample_chi(chi_pred_out, use_cuda=use_cuda)
330 |
331 | chi_real_out = []
332 | for i, ptr in enumerate(symm_idx_ptr):
333 |         chi_real_out.append([chi_real[i][None] for j in range(len(ptr))])
334 | chi_real = np.concatenate(chi_real_out, axis=0)
335 |
336 | chi_onehot_out = []
337 | for i, ptr in enumerate(symm_idx_ptr):
338 | chi_onehot_out.append(torch.cat([chi_onehot[i][None] for j in range(len(ptr))], 0))
339 | chi_onehot = torch.cat(chi_onehot_out, 0)
340 | return chi_real, chi_onehot
341 |
342 |
343 | # from https://codereview.stackexchange.com/questions/203319/greedy-graph-coloring-in-python
344 | def color_nodes(graph, nodes):
345 | color_map = {}
346 |     # consider nodes in the given order (pass nodes sorted by descending degree for the classic greedy heuristic)
347 |     for node in nodes:
348 | neighbor_colors = set(color_map.get(neigh) for neigh in graph[node])
349 | color_map[node] = next(color for color in range(len(graph)) if color not in neighbor_colors)
350 | return color_map
351 |
352 |
353 | ################################
354 |
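355 | # Examples (sketches) for the helpers above, on CPU tensors (use_cuda=0).
356 | def _example_energy_and_onehot():
357 |     logits = torch.randn(4, 20)
358 |     res_idx = torch.LongTensor([0, 5, 5, 19])
359 |     # negative log probability of each label under the softmaxed logits
360 |     log_p, log_p_mean = get_energy_from_logits(logits, res_idx)
361 |     onehot = make_onehot(4, 20, res_idx[:, None], use_cuda=0)
362 |     assert onehot.sum().item() == 4.0
363 |     return log_p, log_p_mean, onehot
364 | 
365 | 
366 | def _example_cb_graph():
367 |     # adjacency graph from a random pairwise-distance matrix, thresholded at 8 Angstroms
368 |     D = np.random.uniform(0, 20, size=(5, 5))
369 |     A = get_graph_from_D(D, threshold=8.0)
370 |     assert A.shape == (5, 5)
371 |     return A
372 | 
373 | 
374 | def _example_color_nodes():
375 |     # greedy coloring of a 3-node path graph: each node gets the smallest color
376 |     # not used by its neighbors, visiting nodes in the given order
377 |     graph = {0: [1], 1: [0, 2], 2: [1]}
378 |     assert color_nodes(graph, nodes=[0, 1, 2]) == {0: 0, 1: 1, 2: 0}
379 |     return color_nodes(graph, nodes=[0, 1, 2])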
--------------------------------------------------------------------------------
/seq_des/util/voxelize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def voxelize(x, x_data, n=20, c=13, dist=10, plot=False, bb_only=0):
5 | """Function to voxelize atom coordinate data ahead of training. Could be sped up on GPU
6 |
7 | Args:
8 |         x (np.array): num_res x num_return x 1 x 3 -- canonicalized coordinates about each residue with center residue masked
9 |         x_data (np.array): num_res x num_return x 1 x 4 -- metadata for canonicalized atoms for each environment
10 |     Returns:
11 |         output (np.array): num_res x C x n x n x n -- 3D environments centered at each residue; C = c atom-type channels, plus a BB-indicator channel and 21 residue-type channels unless bb_only
12 | """
13 |
14 | bins = np.linspace(-dist, dist, n + 1)
15 | bs, nres, _, x_dim = x.shape
16 | x_data_dim = x_data.shape[-1]
17 | x = x.reshape(bs * nres, -1, x_dim)
18 | x_data = x_data.reshape(bs * nres, -1, x_data_dim)
19 | x_atom = x_data[..., 2].astype(np.int64)
20 | x_res_type = x_data[..., -1].astype(np.int64)
21 | x_bb = x_data[..., 1].astype(np.int64)
22 |
23 | bs_idx = np.tile(np.arange(bs)[:, None], (1, nres)).reshape(-1)
24 | # coordinates to voxels
25 |     x_b = np.digitize(x[..., 0], bins)
26 |     y_b = np.digitize(x[..., 1], bins)
27 |     z_b = np.digitize(x[..., 2], bins)
28 |
29 |     # map any non-listed atoms into the 'other' category (channel c, dropped below)
30 |     x_atom[x_atom > c - 1] = c
31 |
32 | # this step can possibly be moved to GPU
33 | output_atom = np.zeros((bs, c + 1, n + 2, n + 2, n + 2))
34 | output_atom[bs_idx, x_atom[:, 0], x_b[:, 0], y_b[:, 0], z_b[:, 0]] = 1 # atom type
35 | if not bb_only:
36 | output_bb = np.zeros((bs, 2, n + 2, n + 2, n + 2))
37 | output_bb[bs_idx, x_bb[:, 0], x_b[:, 0], y_b[:, 0], z_b[:, 0]] = 1 # BB indicator
38 | output_res = np.zeros((bs, 22, n + 2, n + 2, n + 2))
39 | output_res[bs_idx, x_res_type[:, 0], x_b[:, 0], y_b[:, 0], z_b[:, 0]] = 1 # res type for each atom
40 | # eliminate last channel for output_atom ('other' atom type), output_bb, and output_res (res type for current side chain)
41 | output = np.concatenate([output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1)
42 | else:
43 | output = output_atom[:, :c]
44 |
45 | output = output[:, :, 1:-1, 1:-1, 1:-1]
46 |
47 | return output
48 |
49 |
50 | def get_voxel_idx(x, x_data, n=20, c=13, dist=10, plot=False):
51 | """Function to get indices for voxelized atom coordinate data ahead of training.
52 |
53 | Args:
54 |         x (np.array): num_res x num_return x 1 x 3 -- canonicalized coordinates about each residue with center residue masked
55 |         x_data (np.array): num_res x num_return x 1 x 4 -- metadata for canonicalized atoms for each environment
56 |     Returns:
57 |         bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type (np.array): per-atom environment index, atom type,
58 |         BB indicator, x/y/z voxel bin indices, and residue type -- used to scatter atoms into voxel grids
59 | """
60 |
61 | bins = np.linspace(-dist, dist, n + 1)
62 | bs, nres, _, x_dim = x.shape
63 | x_data_dim = x_data.shape[-1]
64 | x = x.reshape(bs * nres, -1, x_dim)
65 | x_data = x_data.reshape(bs * nres, -1, x_data_dim)
66 | x_atom = x_data[..., 2].astype(np.int64)
67 |     x_res_type = x_data[..., -1].astype(np.int64)
68 | x_bb = x_data[..., 1].astype(np.int64)
69 |
70 | bs_idx = np.tile(np.arange(bs)[:, None], (1, nres)).reshape(-1)
71 |
72 | # coordinates to voxels
73 |     x_b = np.digitize(x[..., 0], bins)
74 |     y_b = np.digitize(x[..., 1], bins)
75 |     z_b = np.digitize(x[..., 2], bins)
76 |
77 |     # map any non-listed atoms into the 'other' category (channel c, dropped by the caller)
78 |     x_atom[x_atom > c - 1] = c
79 | 
80 |
81 | return bs_idx, x_atom[..., 0], x_bb[..., 0], x_b[..., 0], y_b[..., 0], z_b[..., 0], x_res_type[..., 0]
82 |
83 |
84 |
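85 | # Example (sketch): voxelizing a tiny random environment. Shapes follow the
86 | # docstrings above; x_data columns (res index, BB flag, atom type, res type)
87 | # are assumed from the indexing in voxelize().
88 | def _example_voxelize():
89 |     num_res, num_return = 2, 5
90 |     x = np.random.uniform(-8, 8, size=(num_res, num_return, 1, 3))
91 |     x_data = np.zeros((num_res, num_return, 1, 4))
92 |     x_data[..., 1] = 1  # mark every atom as backbone
93 |     x_data[..., 2] = np.random.randint(0, 13, size=(num_res, num_return, 1))  # atom type
94 |     out = voxelize(x, x_data, n=20, c=13, dist=10)
95 |     assert out.shape == (num_res, 13 + 1 + 21, 20, 20, 20)
96 |     return out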
--------------------------------------------------------------------------------
/seq_des_info.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProteinDesignLab/protein_seq_des/bb1e5a968f84a2db189f6a7ce400b96c5eaff691/seq_des_info.pdf
--------------------------------------------------------------------------------
/train_autoreg_chi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | import common.run_manager
10 | import seq_des.models as models
11 | import seq_des.util.voxelize as voxelize
12 | import glob
13 | import seq_des.util.canonicalize as canonicalize
14 | import pickle
15 | import seq_des.util.data as datasets
16 | from torch.utils import data
17 | import common.atoms
18 | import seq_des.util.acc_util as acc_util
19 | import subprocess as sp
20 | import time
21 | import torch.nn.functional as F
22 |
23 | """ script to train 3D CNN on local residue-centered environments -- with autoregressive rotamer chi angle prediction"""
24 |
25 | dist = 10
26 | n = 20
27 | c = len(common.atoms.atoms)
28 |
29 |
30 | def test(model, gen, dataloader, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, max_it=1e6, desc="test", batch_size=64, n_iters=500, k=3, return_cm=False, use_cuda=True):
31 | n_iters = min(max_it, n_iters)
32 | model = model.eval()
33 | gen = iter(dataloader)
34 | losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss = ([] for i in range(13))
35 |
36 | with torch.no_grad():
37 |
38 | for i in tqdm(range(n_iters), desc=desc):
39 | try:
40 |                 out = next(gen)
41 | except StopIteration:
42 | gen = iter(dataloader)
43 |                 out = next(gen)
44 |
45 | out = step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, use_cuda=use_cuda)
46 |
47 | if out is None:
48 | continue
49 | loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc = out
50 |
51 | # append losses, accs to lists
52 | for x, y in zip(
53 | [losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss],
54 | [loss.item(), acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item()],
55 | ):
56 | x.append(y)
57 |
58 | del loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc
59 |
60 | print("\nloss", np.mean(losses), "acc", np.mean(avg_acc), "top3", np.mean(avg_top_k_acc), "coarse", np.mean(avg_coarse_acc), "polar", np.mean(avg_polar_acc))
61 |
62 | return (
63 | gen,
64 | np.mean(losses),
65 | np.mean(avg_chi_1_loss),
66 | np.mean(avg_chi_2_loss),
67 | np.mean(avg_chi_3_loss),
68 | np.mean(avg_chi_4_loss),
69 | np.mean(avg_acc),
70 | np.mean(avg_top_k_acc),
71 | np.mean(avg_coarse_acc),
72 | np.mean(avg_polar_acc),
73 | np.mean(avg_chi_1_acc),
74 | np.mean(avg_chi_2_acc),
75 | np.mean(avg_chi_3_acc),
76 | np.mean(avg_chi_4_acc),
77 | )
78 |
79 |
80 | def step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, k=3, use_cuda=True):
81 |
82 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles = out
83 |
84 | bs = len(bs_idx)
85 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
86 | output_bb = torch.zeros((bs, 2, n + 2, n + 2, n + 2))
87 | output_res = torch.zeros((bs, 22, n + 2, n + 2, n + 2))
88 |
89 | if use_cuda:
90 | output_atom, output_bb, output_res = map(lambda x: x.cuda(), [output_atom, output_bb, output_res])
91 |
92 |
93 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 # atom type
94 | output_bb.zero_()
95 | output_bb[bs_idx, x_bb, x_b, y_b, z_b] = 1 # BB indicator
96 | output_res.zero_()
97 | output_res[bs_idx, x_res_type, x_b, y_b, z_b] = 1 # res type
98 | output = torch.cat([output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1)
99 | X = output[:, :, 1:-1, 1:-1, 1:-1]
100 |
101 | X, y = X.float(), y.long()
102 | chi_angles = chi_angles.long()
103 |
104 | chi_1 = chi_angles[:, 0]
105 | chi_2 = chi_angles[:, 1]
106 | chi_3 = chi_angles[:, 2]
107 | chi_4 = chi_angles[:, 3]
108 |
109 | y_onehot = torch.FloatTensor(y.size()[0], 20)
110 | y_onehot.zero_()
111 | y_onehot.scatter_(1, y[:, None], 1)
112 |
113 | chi_1_onehot = torch.FloatTensor(chi_1.size()[0], len(datasets.CHI_BINS))
114 | chi_1_onehot.zero_()
115 | chi_1_onehot.scatter_(1, chi_1[:, None], 1)
116 |
117 | chi_2_onehot = torch.FloatTensor(chi_2.size()[0], len(datasets.CHI_BINS))
118 | chi_2_onehot.zero_()
119 | chi_2_onehot.scatter_(1, chi_2[:, None], 1)
120 |
121 | chi_3_onehot = torch.FloatTensor(chi_3.size()[0], len(datasets.CHI_BINS))
122 | chi_3_onehot.zero_()
123 | chi_3_onehot.scatter_(1, chi_3[:, None], 1)
124 |
125 | if use_cuda:
126 | X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4 = map(lambda x: x.cuda(), [X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4])
127 |
128 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X, y_onehot, chi_1_onehot[:, 1:], chi_2_onehot[:, 1:], chi_3_onehot[:, 1:])
129 | # loss
130 | loss = criterion(out, y)
131 | chi_1_loss = chi_1_criterion(chi_1_pred, chi_1 - 1) # [:, 1:])
132 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2 - 1) # [:, 1:])
133 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3 - 1) # [:, 1:])
134 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4 - 1) # [:, 1:])
135 |
136 | # acc
137 | acc, _ = acc_util.get_acc(out, y)
138 | top_k_acc = acc_util.get_top_k_acc(out, y, k=k)
139 | coarse_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_coarse)
140 | polar_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_polar)
141 | chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1 - 1, ignore_idx=-1)
142 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2 - 1, ignore_idx=-1)
143 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3 - 1, ignore_idx=-1)
144 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4 - 1, ignore_idx=-1)
145 |
146 | return loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc
147 |
148 |
149 | def step_iter(gen, dataloader):
150 | try:
151 |         out = next(gen)
152 | except StopIteration:
153 | gen = iter(dataloader)
154 |         out = next(gen)
155 | return gen, out
156 |
157 |
158 | def main():
159 |
160 | manager = common.run_manager.RunManager()
161 |
162 | manager.parse_args()
163 | args = manager.args
164 | log = manager.log
165 |
166 | use_cuda = torch.cuda.is_available() and args.cuda
167 |
168 | # set up model
169 | model = models.seqPred(nic=len(common.atoms.atoms) + 1 + 21, nf=args.nf, momentum=0.01)
170 | model.apply(models.init_ortho_weights)
171 |
172 | if use_cuda:
173 | model.cuda()
174 | else:
175 | print("Training model on CPU")
176 |
177 | if args.model != "":
178 | # load pretrained model
179 | model.load_state_dict(torch.load(args.model))
180 | print("loaded pretrained model")
181 |
182 | # parallelize over available GPUs
183 | if torch.cuda.device_count() > 1 and args.cuda:
184 | print("using", torch.cuda.device_count(), "GPUs")
185 | model = nn.DataParallel(model)
186 |
187 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999), weight_decay=args.reg)
188 |
189 | if args.optimizer != "":
190 | # load pretrained optimizer
191 | optimizer.load_state_dict(torch.load(args.optimizer))
192 | print("loaded pretrained optimizer")
193 |
194 | 
195 |
196 | chi_1_criterion = nn.CrossEntropyLoss(ignore_index=-1)
197 | chi_2_criterion = nn.CrossEntropyLoss(ignore_index=-1)
198 | chi_3_criterion = nn.CrossEntropyLoss(ignore_index=-1)
199 | chi_4_criterion = nn.CrossEntropyLoss(ignore_index=-1)
200 | criterion = nn.CrossEntropyLoss()
201 | if use_cuda:
202 | criterion.cuda()
203 | chi_1_criterion.cuda()
204 | chi_2_criterion.cuda()
205 | chi_3_criterion.cuda()
206 | chi_4_criterion.cuda()
207 |
208 | train_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/train_s95_chi")
209 | train_dataset.len = 8145448 # NOTE -- need to update this if underlying data changes
210 |
211 | test_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/test_s95_chi")
212 | test_dataset.len = 574267 # NOTE -- need to update this if underlying data changes
213 |
214 | train_dataloader = data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper)
215 | test_dataloader = data.DataLoader(test_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper)
216 |
217 | # training params
218 | validation_frequency = args.validation_frequency
219 | save_frequency = args.save_frequency
220 |
221 | """ TRAIN """
222 |
223 | model.train()
224 | gen = iter(train_dataloader)
225 | test_gen = iter(test_dataloader)
226 | bs = args.batchSize
227 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
228 | output_bb = torch.zeros((bs, 2, n + 2, n + 2, n + 2))
229 | output_res = torch.zeros((bs, 22, n + 2, n + 2, n + 2))
230 | y_onehot = torch.FloatTensor(bs, 20)
231 | chi_1_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
232 | chi_2_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
233 | chi_3_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
234 |
235 | if use_cuda:
236 | output_atom, output_bb, output_res, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot = map(lambda x: x.cuda(), [output_atom, output_bb, output_res, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot])
237 | for epoch in range(args.epochs):
238 | for it in tqdm(range(len(train_dataloader)), desc="training epoch %0.2d" % epoch):
239 |
240 | gen, out = step_iter(gen, train_dataloader)
241 | bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles = out
242 | bs_i = len(bs_idx)
243 | output_atom.zero_()
244 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 # atom type
245 | output_bb.zero_()
246 | output_bb[bs_idx, x_bb, x_b, y_b, z_b] = 1 # BB indicator
247 | output_res.zero_()
248 | output_res[bs_idx, x_res_type, x_b, y_b, z_b] = 1 # res type
249 | output = torch.cat([output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1)
250 |
251 | X = output[:, :, 1:-1, 1:-1, 1:-1]
252 |
253 | X, y = X.float(), y.long()
254 | chi_angles = chi_angles.long()
255 |
256 | chi_1 = chi_angles[:, 0]
257 | chi_2 = chi_angles[:, 1]
258 | chi_3 = chi_angles[:, 2]
259 | chi_4 = chi_angles[:, 3]
260 |
261 | if use_cuda:
262 | y, y_onehot, chi_1, chi_2, chi_3, chi_4 = map(lambda x: x.cuda(), [y, y_onehot, chi_1, chi_2, chi_3, chi_4])
263 |
264 | if bs_i < bs:
265 | y = F.pad(y, (0, bs - bs_i))
266 | chi_1 = F.pad(chi_1, (0, bs - bs_i))
267 | chi_2 = F.pad(chi_2, (0, bs - bs_i))
268 | chi_3 = F.pad(chi_3, (0, bs - bs_i))
269 |
270 | y_onehot.zero_()
271 | y_onehot.scatter_(1, y[:, None], 1)
272 |
273 | chi_1_onehot.zero_()
274 | chi_1_onehot.scatter_(1, chi_1[:, None], 1)
275 |
276 | chi_2_onehot.zero_()
277 | chi_2_onehot.scatter_(1, chi_2[:, None], 1)
278 |
279 | chi_3_onehot.zero_()
280 | chi_3_onehot.scatter_(1, chi_3[:, None], 1)
281 |
282 | # 0 index for chi indicates that chi is masked
283 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X[:bs_i], y_onehot[:bs_i], chi_1_onehot[:bs_i, 1:], chi_2_onehot[:bs_i, 1:], chi_3_onehot[:bs_i, 1:])
284 | res_loss = criterion(out, y[:bs_i])
285 | chi_1_loss = chi_1_criterion(chi_1_pred, chi_1[:bs_i] - 1) # , 1:])
286 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2[:bs_i] - 1) # , 1:])
287 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3[:bs_i] - 1) # , 1:])
288 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4[:bs_i] - 1) # , 1:])
289 |             train_loss = res_loss + chi_1_loss + chi_2_loss + chi_3_loss + chi_4_loss
290 |             optimizer.zero_grad()  # clear gradients accumulated from the previous step
291 |             train_loss.backward()
292 |             optimizer.step()
293 |
294 | # acc
295 | train_acc, _ = acc_util.get_acc(out, y[:bs_i], cm=None)
296 | train_top_k_acc = acc_util.get_top_k_acc(out, y[:bs_i], k=3)
297 | train_coarse_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_coarse)
298 | train_polar_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_polar)
299 |
300 | chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1[:bs_i] - 1, ignore_idx=-1)
301 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2[:bs_i] - 1, ignore_idx=-1)
302 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3[:bs_i] - 1, ignore_idx=-1)
303 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4[:bs_i] - 1, ignore_idx=-1)
304 |
305 |             # tensorboard logging (explicit loop -- map() is lazy in Python 3)
306 |             for name, val in zip(
307 |                 ["res_loss", "chi_1_loss", "chi_2_loss", "chi_3_loss", "chi_4_loss", "train_acc", "chi_1_acc", "chi_2_acc", "chi_3_acc", "chi_4_acc", "train_top3_acc", "train_coarse_acc", "train_polar_acc"],
308 |                 [res_loss.item(), chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item(), train_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, train_top_k_acc, train_coarse_acc, train_polar_acc],
309 |             ):
310 |                 log.log_scalar("seq_chi_pred/%s" % name, val)
311 | 
312 | 
313 |
314 | if it % validation_frequency == 0 or it == len(train_dataloader) - 1:
315 |
316 | if it > 0:
317 | if torch.cuda.device_count() > 1 and args.cuda:
318 | torch.save(model.module.state_dict(), log.log_path + "/seq_chi_pred_curr_weights.pt")
319 | else:
320 | torch.save(model.state_dict(), log.log_path + "/seq_chi_pred_curr_weights.pt")
321 | torch.save(optimizer.state_dict(), log.log_path + "/seq_chi_pred_curr_optimizer.pt")
322 |
323 | # NOTE -- saving models for each validation step
324 | if it > 0 and (it % save_frequency == 0 or it == len(train_dataloader) - 1):
325 | if torch.cuda.device_count() > 1 and args.cuda:
326 | torch.save(model.module.state_dict(), log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" % (epoch, it))
327 | else:
328 | torch.save(model.state_dict(), log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" % (epoch, it))
329 |
330 | torch.save(optimizer.state_dict(), log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_optimizer.pt" % (epoch, it))
331 |
332 |                     # switch to eval mode for validation (model.train() is restored below)
333 | model.eval()
334 | # eval on the test set
335 | test_gen, curr_test_loss, test_chi_1_loss, test_chi_2_loss, test_chi_3_loss, test_chi_4_loss, curr_test_acc, curr_test_top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc = test(
336 | model,
337 | test_gen,
338 | test_dataloader,
339 | criterion,
340 | chi_1_criterion,
341 | chi_2_criterion,
342 | chi_3_criterion,
343 | chi_4_criterion,
344 | max_it=len(test_dataloader),
345 | n_iters=min(10, len(test_dataloader)),
346 | desc="test",
347 | batch_size=args.batchSize,
348 | use_cuda=use_cuda,
349 | )
350 |
351 |                     # tensorboard logging (explicit loop -- map() is lazy in Python 3);
352 |                     # log the test metrics returned by test() above
353 |                     for name, val in zip(
354 |                         [
355 |                             "test_loss",
356 |                             "test_chi_1_loss",
357 |                             "test_chi_2_loss",
358 |                             "test_chi_3_loss",
359 |                             "test_chi_4_loss",
360 |                             "test_acc",
361 |                             "test_chi_1_acc",
362 |                             "test_chi_2_acc",
363 |                             "test_chi_3_acc",
364 |                             "test_chi_4_acc",
365 |                             "test_acc_top3",
366 |                             "test_coarse_acc",
367 |                             "test_polar_acc",
368 |                         ],
369 |                         [
370 |                             curr_test_loss,
371 |                             test_chi_1_loss,
372 |                             test_chi_2_loss,
373 |                             test_chi_3_loss,
374 |                             test_chi_4_loss,
375 |                             curr_test_acc,
376 |                             chi_1_acc,
377 |                             chi_2_acc,
378 |                             chi_3_acc,
379 |                             chi_4_acc,
380 |                             curr_test_top_k_acc,
381 |                             coarse_acc,
382 |                             polar_acc,
383 |                         ],
384 |                     ):
385 |                         log.log_scalar("seq_chi_pred/%s" % name, val)
386 |
387 | model.train()
388 |
389 | log.advance_iteration()
390 |
391 |
392 | if __name__ == "__main__":
393 | main()
394 |
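395 | 
396 | # Sketch of the chi-bin convention used in step() above: bin 0 means "masked",
397 | # so targets are shifted down by one and CrossEntropyLoss(ignore_index=-1)
398 | # skips masked chi angles entirely.
399 | def _example_chi_mask_loss():
400 |     logits = torch.randn(2, len(datasets.CHI_BINS) - 1)
401 |     chi = torch.LongTensor([0, 1])  # first entry is a masked chi; second falls in the first real bin
402 |     criterion = nn.CrossEntropyLoss(ignore_index=-1)
403 |     return criterion(logits, chi - 1)  # loss computed over the second entry only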
--------------------------------------------------------------------------------
/train_autoreg_chi_baseline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | import common.run_manager
10 | import seq_des.models as models
11 | import seq_des.util.voxelize as voxelize
12 | import glob
13 | import seq_des.util.canonicalize as canonicalize
14 | import pickle
15 | import seq_des.util.data as datasets
16 | from torch.utils import data
17 | import common.atoms
18 | import seq_des.util.acc_util as acc_util
19 | import subprocess as sp
20 | import time
21 | import torch.nn.functional as F
22 |
23 | """ script to train 3D CNN on local residue-centered environments -- BB only -- with autoregressive rotamer chi angle prediction"""
24 |
25 | dist = 10
26 | n = 20
27 | c = len(common.atoms.atoms)
28 |
29 |
30 | def test(
31 | model, gen, dataloader, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, max_it=1e6, desc="test", batch_size=64, n_iters=500, k=3, use_cuda=True,
32 | ):
33 | n_iters = min(max_it, n_iters)
34 | model = model.eval()
35 | gen = iter(dataloader)
36 | (losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss,) = ([] for i in range(13))
37 | with torch.no_grad():
38 |
39 | for i in tqdm(range(n_iters), desc=desc):
40 | try:
41 |                     out = next(gen)
42 | except StopIteration:
43 | gen = iter(dataloader)
44 |                     out = next(gen)
45 |
46 | out = step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, use_cuda=use_cuda)
47 |
48 | if out is None:
49 | continue
50 | (loss, chi_1_loss, chi_2_loss, chi_3_loss, chi_4_loss, out, y, acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc,) = out
51 |
52 | # append losses, accs to lists
53 | for x, y in zip(
54 | [losses, avg_acc, avg_top_k_acc, avg_coarse_acc, avg_polar_acc, avg_chi_1_acc, avg_chi_2_acc, avg_chi_3_acc, avg_chi_4_acc, avg_chi_1_loss, avg_chi_2_loss, avg_chi_3_loss, avg_chi_4_loss,],
55 | [loss.item(), acc, top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item(),],
56 | ):
57 | x.append(y)
58 |
59 | del (
60 | loss,
61 | chi_1_loss,
62 | chi_2_loss,
63 | chi_3_loss,
64 | chi_4_loss,
65 | out,
66 | y,
67 | acc,
68 | top_k_acc,
69 | coarse_acc,
70 | polar_acc,
71 | chi_1_acc,
72 | chi_2_acc,
73 | chi_3_acc,
74 | chi_4_acc,
75 | )
76 |
77 | print(
78 | "\nloss", np.mean(losses), "acc", np.mean(avg_acc), "top3", np.mean(avg_top_k_acc), "coarse", np.mean(avg_coarse_acc), "polar", np.mean(avg_polar_acc),
79 | )
80 |
81 | return (
82 | gen,
83 | np.mean(losses),
84 | np.mean(avg_chi_1_loss),
85 | np.mean(avg_chi_2_loss),
86 | np.mean(avg_chi_3_loss),
87 | np.mean(avg_chi_4_loss),
88 | np.mean(avg_acc),
89 | np.mean(avg_top_k_acc),
90 | np.mean(avg_coarse_acc),
91 | np.mean(avg_polar_acc),
92 | np.mean(avg_chi_1_acc),
93 | np.mean(avg_chi_2_acc),
94 | np.mean(avg_chi_3_acc),
95 | np.mean(avg_chi_4_acc),
96 | )
97 |
98 |
99 | def step(model, out, criterion, chi_1_criterion, chi_2_criterion, chi_3_criterion, chi_4_criterion, k=3, use_cuda=True):
100 |
101 | (bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles,) = out
102 |
103 | bs = len(bs_idx)
104 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
105 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1
106 |
107 | if use_cuda:
108 | output_atom = output_atom.cuda()
109 |
110 | X = output_atom[:, :c, 1:-1, 1:-1, 1:-1]
111 |
112 | if X is None:
113 | return None
114 |
115 | X, y = X.float(), y.long()
116 | chi_angles = chi_angles.long()
117 |
118 | chi_1 = chi_angles[:, 0]
119 | chi_2 = chi_angles[:, 1]
120 | chi_3 = chi_angles[:, 2]
121 | chi_4 = chi_angles[:, 3]
122 |
123 | y_onehot = torch.FloatTensor(y.size()[0], 20)
124 | y_onehot.zero_()
125 | y_onehot.scatter_(1, y[:, None], 1)
126 |
127 | chi_1_onehot = torch.FloatTensor(chi_1.size()[0], len(datasets.CHI_BINS))
128 | chi_1_onehot.zero_()
129 | chi_1_onehot.scatter_(1, chi_1[:, None], 1)
130 |
131 | chi_2_onehot = torch.FloatTensor(chi_2.size()[0], len(datasets.CHI_BINS))
132 | chi_2_onehot.zero_()
133 | chi_2_onehot.scatter_(1, chi_2[:, None], 1)
134 |
135 | chi_3_onehot = torch.FloatTensor(chi_3.size()[0], len(datasets.CHI_BINS))
136 | chi_3_onehot.zero_()
137 | chi_3_onehot.scatter_(1, chi_3[:, None], 1)
138 |
139 | if use_cuda:
140 | (X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4,) = map(lambda x: x.cuda(), [X, y, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot, chi_1, chi_2, chi_3, chi_4,],)
141 |
142 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X, y_onehot, chi_1_onehot[:, 1:], chi_2_onehot[:, 1:], chi_3_onehot[:, 1:])
143 | # loss
144 | loss = criterion(out, y)
145 | chi_1_loss = chi_1_criterion(chi_1_pred, chi_1 - 1)
146 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2 - 1)
147 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3 - 1)
148 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4 - 1)
149 |
150 | # acc
151 | acc, _ = acc_util.get_acc(out, y)
152 | top_k_acc = acc_util.get_top_k_acc(out, y, k=k)
153 | coarse_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_coarse)
154 | polar_acc, _ = acc_util.get_acc(out, y, label_dict=acc_util.label_polar)
155 | chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1 - 1, ignore_idx=-1)
156 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2 - 1, ignore_idx=-1)
157 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3 - 1, ignore_idx=-1)
158 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4 - 1, ignore_idx=-1)
159 |
160 | return (
161 | loss,
162 | chi_1_loss,
163 | chi_2_loss,
164 | chi_3_loss,
165 | chi_4_loss,
166 | out,
167 | y,
168 | acc,
169 | top_k_acc,
170 | coarse_acc,
171 | polar_acc,
172 | chi_1_acc,
173 | chi_2_acc,
174 | chi_3_acc,
175 | chi_4_acc,
176 | )
177 |
178 |
179 | def step_iter(gen, dataloader):
180 | try:
181 |         out = next(gen)
182 | except StopIteration:
183 | gen = iter(dataloader)
184 |         out = next(gen)
185 | return gen, out
186 |
187 |
188 | def main():
189 |
190 | manager = common.run_manager.RunManager()
191 |
192 | manager.parse_args()
193 | args = manager.args
194 | log = manager.log
195 |
196 | use_cuda = torch.cuda.is_available() and args.cuda
197 |
198 | # set up model
199 | model = models.seqPred(nic=len(common.atoms.atoms), nf=args.nf, momentum=args.momentum)
200 | model.apply(models.init_ortho_weights)
201 | if use_cuda:
202 | model.cuda()
203 | else:
204 | print("Training model on CPU")
205 |
206 | # parallelize over available GPUs
207 | if torch.cuda.device_count() > 1 and args.cuda:
208 | print("using", torch.cuda.device_count(), "GPUs")
209 | model = nn.DataParallel(model)
210 |
211 | if args.model != "":
212 | # load pretrained model
213 | model.load_state_dict(torch.load(args.model))
214 | print("loaded pretrained model")
215 |
216 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999), weight_decay=args.reg)
217 |
218 | if args.optimizer != "":
219 | # load pretrained optimizer
220 | optimizer.load_state_dict(torch.load(args.optimizer))
221 | print("loaded pretrained optimizer")
222 |
223 |     # classification losses: residue identity plus one per chi angle (ignore_index=-1 skips undefined chis)
224 |
225 | chi_1_criterion = nn.CrossEntropyLoss(ignore_index=-1)
226 | chi_2_criterion = nn.CrossEntropyLoss(ignore_index=-1)
227 | chi_3_criterion = nn.CrossEntropyLoss(ignore_index=-1)
228 | chi_4_criterion = nn.CrossEntropyLoss(ignore_index=-1)
229 | criterion = nn.CrossEntropyLoss()
230 | if use_cuda:
231 | criterion.cuda()
232 | chi_1_criterion.cuda()
233 | chi_2_criterion.cuda()
234 | chi_3_criterion.cuda()
235 | chi_4_criterion.cuda()
236 |
237 | train_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/train_s95_chi_bb")
238 | train_dataset.len = 8145448 # NOTE -- need to update this if underlying data changes
239 |
240 | test_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/test_s95_chi_bb")
241 | test_dataset.len = 574267 # NOTE -- need to update this if underlying data changes
242 |
243 | train_dataloader = data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper,)
244 | test_dataloader = data.DataLoader(test_dataset, batch_size=args.batchSize, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=datasets.collate_wrapper,)
245 |
246 | # training params
247 | validation_frequency = args.validation_frequency
248 | save_frequency = args.save_frequency
249 |
250 | """ TRAIN """
251 |
252 | model.train()
253 | gen = iter(train_dataloader)
254 | test_gen = iter(test_dataloader)
255 | bs = args.batchSize
256 | output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
257 | y_onehot = torch.FloatTensor(bs, 20)
258 | chi_1_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
259 | chi_2_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
260 | chi_3_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
261 |
262 | if use_cuda:
263 | output_atom, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot = map(lambda x: x.cuda(), [output_atom, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot])
264 |
265 | for epoch in range(args.epochs):
266 | for it in tqdm(range(len(train_dataloader)), desc="training epoch %0.2d" % epoch):
267 |
268 | gen, out = step_iter(gen, train_dataloader)
269 | (bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles,) = out
270 | bs_i = len(bs_idx)
271 | output_atom.zero_()
272 | output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1 # atom type
273 | X = output_atom[:, :c, 1:-1, 1:-1, 1:-1]
274 |
275 | X, y = X.float(), y.long()
276 | chi_angles = chi_angles.long()
277 |
278 | chi_1 = chi_angles[:, 0]
279 | chi_2 = chi_angles[:, 1]
280 | chi_3 = chi_angles[:, 2]
281 | chi_4 = chi_angles[:, 3]
282 |
283 | if use_cuda:
284 | y, y_onehot, chi_1, chi_2, chi_3, chi_4 = map(lambda x: x.cuda(), [y, y_onehot, chi_1, chi_2, chi_3, chi_4])
285 |
286 | if bs_i < bs:
287 | y = F.pad(y, (0, bs - bs_i))
288 | chi_1 = F.pad(chi_1, (0, bs - bs_i))
289 | chi_2 = F.pad(chi_2, (0, bs - bs_i))
290 | chi_3 = F.pad(chi_3, (0, bs - bs_i))
291 |
292 | y_onehot.zero_()
293 | y_onehot.scatter_(1, y[:, None], 1)
294 |
295 | chi_1_onehot.zero_()
296 | chi_1_onehot.scatter_(1, chi_1[:, None], 1)
297 |
298 | chi_2_onehot.zero_()
299 | chi_2_onehot.scatter_(1, chi_2[:, None], 1)
300 |
301 | chi_3_onehot.zero_()
302 | chi_3_onehot.scatter_(1, chi_3[:, None], 1)
303 |
304 | out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(X[:bs_i], y_onehot[:bs_i], chi_1_onehot[:bs_i, 1:], chi_2_onehot[:bs_i, 1:], chi_3_onehot[:bs_i, 1:])
305 | res_loss = criterion(out, y[:bs_i])
306 | chi_1_loss = chi_1_criterion(chi_1_pred, chi_1[:bs_i] - 1)
307 | chi_2_loss = chi_2_criterion(chi_2_pred, chi_2[:bs_i] - 1)
308 | chi_3_loss = chi_3_criterion(chi_3_pred, chi_3[:bs_i] - 1)
309 | chi_4_loss = chi_4_criterion(chi_4_pred, chi_4[:bs_i] - 1)
310 |             optimizer.zero_grad()  # clear gradients from the previous iteration before backprop
311 | train_loss = res_loss + chi_1_loss + chi_2_loss + chi_3_loss + chi_4_loss
312 | train_loss.backward()
313 | optimizer.step()
314 |
315 | # acc
316 | train_acc, _ = acc_util.get_acc(out, y[:bs_i], cm=None)
317 | train_top_k_acc = acc_util.get_top_k_acc(out, y[:bs_i], k=3)
318 | train_coarse_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_coarse)
319 | train_polar_acc, _ = acc_util.get_acc(out, y[:bs_i], label_dict=acc_util.label_polar)
320 |
321 | chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1[:bs_i] - 1, ignore_idx=-1)
322 | chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2[:bs_i] - 1, ignore_idx=-1)
323 | chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3[:bs_i] - 1, ignore_idx=-1)
324 | chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4[:bs_i] - 1, ignore_idx=-1)
325 |
326 |             # tensorboard logging (map() is lazy in Python 3, so iterate explicitly)
327 |             for name, val in zip(
328 |                 ["res_loss", "chi_1_loss", "chi_2_loss", "chi_3_loss", "chi_4_loss", "train_acc", "chi_1_acc", "chi_2_acc", "chi_3_acc", "chi_4_acc", "train_top3_acc", "train_coarse_acc", "train_polar_acc",],
329 |                 [res_loss.item(), chi_1_loss.item(), chi_2_loss.item(), chi_3_loss.item(), chi_4_loss.item(), train_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc, train_top_k_acc, train_coarse_acc, train_polar_acc,],
330 |             ):
331 |                 log.log_scalar(
332 |                     "seq_chi_pred/%s" % name, val,
333 |                 )
334 |
335 | if it % validation_frequency == 0 or it == len(train_dataloader) - 1:
336 |
337 | if it > 0:
338 | if torch.cuda.device_count() > 1 and args.cuda:
339 | torch.save(
340 | model.module.state_dict(), log.log_path + "/seq_chi_pred_baseline_curr_weights.pt",
341 | )
342 | else:
343 | torch.save(
344 | model.state_dict(), log.log_path + "/seq_chi_pred_baseline_curr_weights.pt",
345 | )
346 | torch.save(
347 | optimizer.state_dict(), log.log_path + "/seq_chi_pred_baseline_curr_optimizer.pt",
348 | )
349 |
350 |                 # NOTE -- checkpointing model and optimizer every save_frequency iterations (and at the end of each epoch)
351 | if it > 0 and (it % save_frequency == 0 or it == len(train_dataloader) - 1):
352 | if torch.cuda.device_count() > 1 and args.cuda:
353 | torch.save(
354 | model.module.state_dict(), log.log_path + "/seq_chi_pred_baseline_epoch_%0.3d_%s_weights.pt" % (epoch, it),
355 | )
356 | else:
357 | torch.save(
358 | model.state_dict(), log.log_path + "/seq_chi_pred_baseline_epoch_%0.3d_%s_weights.pt" % (epoch, it),
359 | )
360 |
361 | torch.save(
362 | optimizer.state_dict(), log.log_path + "/seq_chi_pred_baseline_epoch_%0.3d_%s_optimizer.pt" % (epoch, it),
363 | )
364 |
365 |                 # switch to eval mode for validation (test() also calls model.eval())
366 | model.eval()
367 | # eval on the test set
368 | (test_gen, curr_test_loss, test_chi_1_loss, test_chi_2_loss, test_chi_3_loss, test_chi_4_loss, curr_test_acc, curr_test_top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc,) = test(
369 | model,
370 | test_gen,
371 | test_dataloader,
372 | criterion,
373 | chi_1_criterion,
374 | chi_2_criterion,
375 | chi_3_criterion,
376 | chi_4_criterion,
377 | max_it=len(test_dataloader),
378 | n_iters=min(10, len(test_dataloader)),
379 | desc="test",
380 | batch_size=args.batchSize,
381 | use_cuda=use_cuda,
382 | )
383 |
384 |                 # tensorboard logging of validation metrics (explicit loop instead of a
385 |                 # lazy map(), and using the values returned by test())
386 |                 for name, val in zip(
387 |                     [
388 |                         "test_loss",
389 |                         "test_chi_1_loss",
390 |                         "test_chi_2_loss",
391 |                         "test_chi_3_loss",
392 |                         "test_chi_4_loss",
393 |                         "test_acc",
394 |                         "test_chi_1_acc",
395 |                         "test_chi_2_acc",
396 |                         "test_chi_3_acc",
397 |                         "test_chi_4_acc",
398 |                         "test_acc_top3",
399 |                         "test_coarse_acc",
400 |                         "test_polar_acc",
401 |                     ],
402 |                     [
403 |                         curr_test_loss,
404 |                         test_chi_1_loss,
405 |                         test_chi_2_loss,
406 |                         test_chi_3_loss,
407 |                         test_chi_4_loss,
408 |                         curr_test_acc,
409 |                         chi_1_acc,
410 |                         chi_2_acc,
411 |                         chi_3_acc,
412 |                         chi_4_acc,
413 |                         curr_test_top_k_acc,
414 |                         coarse_acc,
415 |                         polar_acc,
416 |                     ],
417 |                 ):
418 |                     log.log_scalar("seq_chi_pred/%s" % name, val)
419 |
420 | model.train()
421 |
422 | log.advance_iteration()
423 |
424 |
425 | if __name__ == "__main__":
426 | main()
427 |
--------------------------------------------------------------------------------
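
Note: the training loop and step() above build the CNN input without any explicit Python loop over atoms -- a single advanced-indexing assignment scatters 1s into a dense (batch, channel, x, y, z) grid, and the extra channel and one-voxel border (the `c + 1` and `n + 2` sizes) appear to serve as an overflow bin that the final crop discards. Below is a minimal, self-contained sketch of that pattern with toy dimensions; all names and sizes are illustrative, not the script's real data:

```python
import torch

# Toy stand-ins for the script's constants: c atom-type channels and an
# (n + 2)^3 grid with a one-voxel border.
c, n, bs = 4, 6, 2

# One entry per atom: which example it belongs to, its atom-type channel,
# and its (already shifted) integer voxel coordinates.
bs_idx = torch.tensor([0, 0, 1])
x_atom = torch.tensor([1, 3, 2])
x_b = torch.tensor([2, 3, 4])
y_b = torch.tensor([1, 1, 5])
z_b = torch.tensor([4, 2, 3])

# Single vectorized scatter: one write per atom, no Python loop.
output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1

# Crop the spare channel and the border to get the network input.
X = output_atom[:, :c, 1:-1, 1:-1, 1:-1]
print(X.shape)  # torch.Size([2, 4, 6, 6, 6])
```

In main() the same buffer is preallocated once and cleared with output_atom.zero_() each iteration, which avoids reallocating the large 5-D tensor per batch.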
/txt/resfiles/NATRO_all.txt:
--------------------------------------------------------------------------------
1 | 1 - 90 NATRO # keep all residues in 3mx7_gt.pdb at their NATIVE ROTAMERS (skips design entirely)
2 |
--------------------------------------------------------------------------------
/txt/resfiles/PIKAA_all_one_AA.txt:
--------------------------------------------------------------------------------
1 | 1 - 90 PIKAA C
2 |
--------------------------------------------------------------------------------
/txt/resfiles/full_example.txt:
--------------------------------------------------------------------------------
1 | APOLAR
2 | start
3 | 65 POLAR
4 | 34 PIKAA EHKNRQDST
5 | 35 - 40 PIKAA EHK
6 | 41 - 45 POLAR
7 | 80 - 85 NATRO
8 | 90 NATRO
9 |
--------------------------------------------------------------------------------
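
full_example.txt above illustrates the resfile dialect used in this repo: an optional default command before `start` (here APOLAR, applied to residues not listed explicitly), per-residue lines (`65 POLAR`), inclusive ranges (`35 - 40 PIKAA EHK`), and `#` comments. The repository's actual parser lives in seq_des/util/resfile_util.py; the snippet below is only an independent sketch of how such lines can be expanded, not that implementation:

```python
def parse_resfile(path):
    """Expand a resfile into (default_command, {residue_index: (cmd, args)}).

    Illustrative parser only: handles '<i> CMD [args]' lines,
    '<i> - <j> CMD [args]' ranges, '#' comments, and one optional
    default command given before the 'start' line.
    """
    default, rules, in_body = None, {}, False
    with open(path) as f:
        for raw in f:
            line = raw.split("#")[0].strip()  # strip inline comments
            if not line:
                continue
            if line.lower() == "start":
                in_body = True
                continue
            tokens = line.split()
            if not in_body and not tokens[0].isdigit():
                default = (tokens[0], tokens[1:])  # e.g. the APOLAR header
                continue
            if len(tokens) >= 4 and tokens[1] == "-":
                for i in range(int(tokens[0]), int(tokens[2]) + 1):
                    rules[i] = (tokens[3], tokens[4:])
            else:
                rules[int(tokens[0])] = (tokens[1], tokens[2:])
    return default, rules
```

Run on full_example.txt, this yields the default ('APOLAR', []) plus rules for residues 34, 35-45, 65, 80-85, and 90.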
/txt/resfiles/generate_resfile.py:
--------------------------------------------------------------------------------
1 | # import sys
2 |
3 | # if len(sys.argv) < 2:
4 | # print("Please provide a path to a .fasta file with a sequence")
5 | # sys.exit()
6 |
7 | # def get_sequence(path):
8 | # """
9 | # Get a sequence from a FASTA file with an initial sequence for the Protein Sequence Design Algorithm
10 | # """
11 | # sequence = ""
12 | # with open(path, "r") as f:
13 | # lines = f.readlines()
14 | # print(lines)
15 | #         sequence = lines[1].strip()
16 |
17 | # print(sequence)
18 |
19 | sequence = "TMPSTYAFKLPIQTETGVARVRSVIKKVSLTLSAYQVDYLLNTATVTSPVAWADMVDGVQAAGVEIQYGQFF"
20 | sequence = list(sequence)
21 |
22 | with open("init_seq_1cc8_gt.txt", "w") as file1:
23 | for i in range(1, len(sequence) + 1):
24 |         # command = " ".join([str(i), "TPIKAA", sequence[i - 1], "\n"])
25 | file1.write("{} TPIKAA {} \n".format(str(i), sequence[i-1]))
26 |
27 | # get_sequence("../../../sequenced_results/1bkr_gt_sequenced/init_seq.fasta")
28 |
--------------------------------------------------------------------------------
/txt/resfiles/init_seq_1acf_gt.txt:
--------------------------------------------------------------------------------
1 | 1 TPIKAA A
2 | 2 TPIKAA R
3 | 3 TPIKAA E
4 | 4 TPIKAA T
5 | 5 TPIKAA W
6 | 6 TPIKAA V
7 | 7 TPIKAA D
8 | 8 TPIKAA D
9 | 9 TPIKAA L
10 | 10 TPIKAA M
11 | 11 TPIKAA C
12 | 12 TPIKAA S
13 | 13 TPIKAA T
14 | 14 TPIKAA G
15 | 15 TPIKAA A
16 | 16 TPIKAA V
17 | 17 TPIKAA R
18 | 18 TPIKAA K
19 | 19 TPIKAA C
20 | 20 TPIKAA A
21 | 21 TPIKAA L
22 | 22 TPIKAA V
23 | 23 TPIKAA G
24 | 24 TPIKAA P
25 | 25 TPIKAA A
26 | 26 TPIKAA G
27 | 27 TPIKAA N
28 | 28 TPIKAA V
29 | 29 TPIKAA Y
30 | 30 TPIKAA A
31 | 31 TPIKAA Q
32 | 32 TPIKAA A
33 | 33 TPIKAA P
34 | 34 TPIKAA G
35 | 35 TPIKAA Y
36 | 36 TPIKAA E
37 | 37 TPIKAA V
38 | 38 TPIKAA S
39 | 39 TPIKAA D
40 | 40 TPIKAA R
41 | 41 TPIKAA Q
42 | 42 TPIKAA G
43 | 43 TPIKAA E
44 | 44 TPIKAA L
45 | 45 TPIKAA V
46 | 46 TPIKAA A
47 | 47 TPIKAA D
48 | 48 TPIKAA G
49 | 49 TPIKAA L
50 | 50 TPIKAA K
51 | 51 TPIKAA K
52 | 52 TPIKAA P
53 | 53 TPIKAA R
54 | 54 TPIKAA G
55 | 55 TPIKAA V
56 | 56 TPIKAA S
57 | 57 TPIKAA S
58 | 58 TPIKAA S
59 | 59 TPIKAA T
60 | 60 TPIKAA F
61 | 61 TPIKAA G
62 | 62 TPIKAA L
63 | 63 TPIKAA D
64 | 64 TPIKAA G
65 | 65 TPIKAA M
66 | 66 TPIKAA R
67 | 67 TPIKAA F
68 | 68 TPIKAA D
69 | 69 TPIKAA V
70 | 70 TPIKAA L
71 | 71 TPIKAA D
72 | 72 TPIKAA T
73 | 73 TPIKAA S
74 | 74 TPIKAA D
75 | 75 TPIKAA R
76 | 76 TPIKAA S
77 | 77 TPIKAA L
78 | 78 TPIKAA F
79 | 79 TPIKAA A
80 | 80 TPIKAA N
81 | 81 TPIKAA L
82 | 82 TPIKAA D
83 | 83 TPIKAA L
84 | 84 TPIKAA H
85 | 85 TPIKAA G
86 | 86 TPIKAA V
87 | 87 TPIKAA L
88 | 88 TPIKAA C
89 | 89 TPIKAA V
90 | 90 TPIKAA F
91 | 91 TPIKAA T
92 | 92 TPIKAA L
93 | 93 TPIKAA K
94 | 94 TPIKAA S
95 | 95 TPIKAA I
96 | 96 TPIKAA I
97 | 97 TPIKAA V
98 | 98 TPIKAA G
99 | 99 TPIKAA S
100 | 100 TPIKAA L
101 | 101 TPIKAA S
102 | 102 TPIKAA G
103 | 103 TPIKAA D
104 | 104 TPIKAA M
105 | 105 TPIKAA A
106 | 106 TPIKAA A
107 | 107 TPIKAA A
108 | 108 TPIKAA M
109 | 109 TPIKAA A
110 | 110 TPIKAA A
111 | 111 TPIKAA Q
112 | 112 TPIKAA L
113 | 113 TPIKAA V
114 | 114 TPIKAA E
115 | 115 TPIKAA G
116 | 116 TPIKAA L
117 | 117 TPIKAA A
118 | 118 TPIKAA E
119 | 119 TPIKAA A
120 | 120 TPIKAA L
121 | 121 TPIKAA M
122 | 122 TPIKAA V
123 | 123 TPIKAA Y
124 | 124 TPIKAA G
125 | 125 TPIKAA E
126 |
--------------------------------------------------------------------------------
/txt/resfiles/init_seq_1bkr_gt.txt:
--------------------------------------------------------------------------------
1 | 1 TPIKAA E
2 | 2 TPIKAA I
3 | 3 TPIKAA R
4 | 4 TPIKAA K
5 | 5 TPIKAA Q
6 | 6 TPIKAA R
7 | 7 TPIKAA F
8 | 8 TPIKAA F
9 | 9 TPIKAA D
10 | 10 TPIKAA F
11 | 11 TPIKAA C
12 | 12 TPIKAA R
13 | 13 TPIKAA K
14 | 14 TPIKAA V
15 | 15 TPIKAA T
16 | 16 TPIKAA A
17 | 17 TPIKAA G
18 | 18 TPIKAA W
19 | 19 TPIKAA Q
20 | 20 TPIKAA N
21 | 21 TPIKAA V
22 | 22 TPIKAA N
23 | 23 TPIKAA L
24 | 24 TPIKAA T
25 | 25 TPIKAA D
26 | 26 TPIKAA F
27 | 27 TPIKAA A
28 | 28 TPIKAA S
29 | 29 TPIKAA N
30 | 30 TPIKAA F
31 | 31 TPIKAA R
32 | 32 TPIKAA H
33 | 33 TPIKAA G
34 | 34 TPIKAA F
35 | 35 TPIKAA C
36 | 36 TPIKAA F
37 | 37 TPIKAA Q
38 | 38 TPIKAA A
39 | 39 TPIKAA L
40 | 40 TPIKAA I
41 | 41 TPIKAA Q
42 | 42 TPIKAA K
43 | 43 TPIKAA V
44 | 44 TPIKAA V
45 | 45 TPIKAA P
46 | 46 TPIKAA E
47 | 47 TPIKAA L
48 | 48 TPIKAA F
49 | 49 TPIKAA N
50 | 50 TPIKAA F
51 | 51 TPIKAA S
52 | 52 TPIKAA D
53 | 53 TPIKAA M
54 | 54 TPIKAA K
55 | 55 TPIKAA K
56 | 56 TPIKAA E
57 | 57 TPIKAA E
58 | 58 TPIKAA P
59 | 59 TPIKAA K
60 | 60 TPIKAA T
61 | 61 TPIKAA N
62 | 62 TPIKAA L
63 | 63 TPIKAA E
64 | 64 TPIKAA N
65 | 65 TPIKAA A
66 | 66 TPIKAA F
67 | 67 TPIKAA K
68 | 68 TPIKAA Y
69 | 69 TPIKAA A
70 | 70 TPIKAA Q
71 | 71 TPIKAA R
72 | 72 TPIKAA K
73 | 73 TPIKAA L
74 | 74 TPIKAA G
75 | 75 TPIKAA I
76 | 76 TPIKAA P
77 | 77 TPIKAA E
78 | 78 TPIKAA I
79 | 79 TPIKAA I
80 | 80 TPIKAA K
81 | 81 TPIKAA P
82 | 82 TPIKAA A
83 | 83 TPIKAA E
84 | 84 TPIKAA V
85 | 85 TPIKAA A
86 | 86 TPIKAA Q
87 | 87 TPIKAA E
88 | 88 TPIKAA G
89 | 89 TPIKAA P
90 | 90 TPIKAA S
91 | 91 TPIKAA E
92 | 92 TPIKAA A
93 | 93 TPIKAA D
94 | 94 TPIKAA V
95 | 95 TPIKAA L
96 | 96 TPIKAA Q
97 | 97 TPIKAA W
98 | 98 TPIKAA V
99 | 99 TPIKAA M
100 | 100 TPIKAA T
101 | 101 TPIKAA F
102 | 102 TPIKAA L
103 | 103 TPIKAA Q
104 | 104 TPIKAA Y
105 | 105 TPIKAA L
106 | 106 TPIKAA A
107 | 107 TPIKAA S
108 | 108 TPIKAA M
109 |
--------------------------------------------------------------------------------
/txt/resfiles/init_seq_1cc8_gt.txt:
--------------------------------------------------------------------------------
1 | 1 TPIKAA T
2 | 2 TPIKAA M
3 | 3 TPIKAA P
4 | 4 TPIKAA S
5 | 5 TPIKAA T
6 | 6 TPIKAA Y
7 | 7 TPIKAA A
8 | 8 TPIKAA F
9 | 9 TPIKAA K
10 | 10 TPIKAA L
11 | 11 TPIKAA P
12 | 12 TPIKAA I
13 | 13 TPIKAA Q
14 | 14 TPIKAA T
15 | 15 TPIKAA E
16 | 16 TPIKAA T
17 | 17 TPIKAA G
18 | 18 TPIKAA V
19 | 19 TPIKAA A
20 | 20 TPIKAA R
21 | 21 TPIKAA V
22 | 22 TPIKAA R
23 | 23 TPIKAA S
24 | 24 TPIKAA V
25 | 25 TPIKAA I
26 | 26 TPIKAA K
27 | 27 TPIKAA K
28 | 28 TPIKAA V
29 | 29 TPIKAA S
30 | 30 TPIKAA L
31 | 31 TPIKAA T
32 | 32 TPIKAA L
33 | 33 TPIKAA S
34 | 34 TPIKAA A
35 | 35 TPIKAA Y
36 | 36 TPIKAA Q
37 | 37 TPIKAA V
38 | 38 TPIKAA D
39 | 39 TPIKAA Y
40 | 40 TPIKAA L
41 | 41 TPIKAA L
42 | 42 TPIKAA N
43 | 43 TPIKAA T
44 | 44 TPIKAA A
45 | 45 TPIKAA T
46 | 46 TPIKAA V
47 | 47 TPIKAA T
48 | 48 TPIKAA S
49 | 49 TPIKAA P
50 | 50 TPIKAA V
51 | 51 TPIKAA A
52 | 52 TPIKAA W
53 | 53 TPIKAA A
54 | 54 TPIKAA D
55 | 55 TPIKAA M
56 | 56 TPIKAA V
57 | 57 TPIKAA D
58 | 58 TPIKAA G
59 | 59 TPIKAA V
60 | 60 TPIKAA Q
61 | 61 TPIKAA A
62 | 62 TPIKAA A
63 | 63 TPIKAA G
64 | 64 TPIKAA V
65 | 65 TPIKAA E
66 | 66 TPIKAA I
67 | 67 TPIKAA Q
68 | 68 TPIKAA Y
69 | 69 TPIKAA G
70 | 70 TPIKAA Q
71 | 71 TPIKAA F
72 | 72 TPIKAA F
73 |
--------------------------------------------------------------------------------
/txt/resfiles/init_seq_3mx7_gt.txt:
--------------------------------------------------------------------------------
1 | 1 TPIKAA F
2 | 2 TPIKAA F
3 | 3 TPIKAA N
4 | 4 TPIKAA L
5 | 5 TPIKAA V
6 | 6 TPIKAA G
7 | 7 TPIKAA V
8 | 8 TPIKAA W
9 | 9 TPIKAA E
10 | 10 TPIKAA V
11 | 11 TPIKAA D
12 | 12 TPIKAA L
13 | 13 TPIKAA S
14 | 14 TPIKAA D
15 | 15 TPIKAA G
16 | 16 TPIKAA S
17 | 17 TPIKAA H
18 | 18 TPIKAA R
19 | 19 TPIKAA I
20 | 20 TPIKAA V
21 | 21 TPIKAA F
22 | 22 TPIKAA Q
23 | 23 TPIKAA E
24 | 24 TPIKAA E
25 | 25 TPIKAA E
26 | 26 TPIKAA A
27 | 27 TPIKAA A
28 | 28 TPIKAA G
29 | 29 TPIKAA R
30 | 30 TPIKAA R
31 | 31 TPIKAA S
32 | 32 TPIKAA I
33 | 33 TPIKAA Y
34 | 34 TPIKAA C
35 | 35 TPIKAA D
36 | 36 TPIKAA D
37 | 37 TPIKAA H
38 | 38 TPIKAA E
39 | 39 TPIKAA I
40 | 40 TPIKAA Y
41 | 41 TPIKAA R
42 | 42 TPIKAA Q
43 | 43 TPIKAA D
44 | 44 TPIKAA N
45 | 45 TPIKAA V
46 | 46 TPIKAA P
47 | 47 TPIKAA L
48 | 48 TPIKAA L
49 | 49 TPIKAA R
50 | 50 TPIKAA S
51 | 51 TPIKAA Y
52 | 52 TPIKAA Q
53 | 53 TPIKAA V
54 | 54 TPIKAA L
55 | 55 TPIKAA P
56 | 56 TPIKAA L
57 | 57 TPIKAA S
58 | 58 TPIKAA K
59 | 59 TPIKAA G
60 | 60 TPIKAA R
61 | 61 TPIKAA V
62 | 62 TPIKAA S
63 | 63 TPIKAA G
64 | 64 TPIKAA F
65 | 65 TPIKAA M
66 | 66 TPIKAA E
67 | 67 TPIKAA I
68 | 68 TPIKAA T
69 | 69 TPIKAA P
70 | 70 TPIKAA Q
71 | 71 TPIKAA K
72 | 72 TPIKAA A
73 | 73 TPIKAA G
74 | 74 TPIKAA D
75 | 75 TPIKAA Y
76 | 76 TPIKAA R
77 | 77 TPIKAA Y
78 | 78 TPIKAA S
79 | 79 TPIKAA F
80 | 80 TPIKAA C
81 | 81 TPIKAA I
82 | 82 TPIKAA N
83 | 83 TPIKAA G
84 | 84 TPIKAA Q
85 | 85 TPIKAA Q
86 | 86 TPIKAA R
87 | 87 TPIKAA I
88 | 88 TPIKAA I
89 | 89 TPIKAA G
90 | 90 TPIKAA K
91 |
--------------------------------------------------------------------------------
/txt/resfiles/resfile_1acf_gt_ex8.txt:
--------------------------------------------------------------------------------
1 | 97 POLAR
2 | 89 POLAR
3 | 22 POLAR
4 | 79 POLAR
5 |
--------------------------------------------------------------------------------
/txt/resfiles/resfile_1bkr_gt_ex6.txt:
--------------------------------------------------------------------------------
1 | 37 POLAR
2 | 98 POLAR
3 | 31 POLAR
4 | 40 POLAR
5 |
--------------------------------------------------------------------------------
/txt/resfiles/resfile_3mx7_gt_ex1.txt:
--------------------------------------------------------------------------------
1 | 65 POLAR
2 | 21 POLAR
3 | 32 POLAR
4 |
--------------------------------------------------------------------------------
/txt/resfiles/resfile_3mx7_gt_ex2.txt:
--------------------------------------------------------------------------------
1 | 65 POLAR
2 | 21 POLAR
3 | 32 POLAR
4 | 52 POLAR # previously APOLAR (V)
5 | 79 POLAR # previously APOLAR (Y)
6 |
--------------------------------------------------------------------------------
/txt/resfiles/some_PIKAA_one.txt:
--------------------------------------------------------------------------------
1 | 34 - 36 PIKAA C
2 | 30 - 33 POLAR
3 |
--------------------------------------------------------------------------------
/txt/resfiles/testing_TPIKAA_TNOTAA.txt:
--------------------------------------------------------------------------------
1 | 30 - 40 PIKAA CDAK
2 | 41 PIKAA A
3 | 41 TPIKAA K
4 | 45 TPIKAA D
5 | 46 TNOTAA HKRDESTNQAVLIMFYWPG # has to be C
6 |
--------------------------------------------------------------------------------
/txt/test_idx.txt:
--------------------------------------------------------------------------------
1 | 0
2 | 1
3 | 2
4 | 3
5 |
--------------------------------------------------------------------------------