├── .gitignore ├── LICENSE ├── README.md ├── affinity_search ├── NOTES ├── addrequests.py ├── cleanparams.py ├── do1request.py ├── evaluate.py ├── evaluate_cross.py ├── fix.py ├── ga_addrequests.py ├── getres.py ├── getresults.py ├── incremental_addrequests.py ├── makebesty.py ├── makemodel.py ├── makemodels.py ├── makemodels1.py ├── makemodels2.py ├── makemodels3.py ├── makemodels4.5.py ├── makemodels4.6.py ├── makemodels4.py ├── makemodels5.py ├── makemodels6.py ├── makereduced.py ├── outputjson.py ├── outputparams.py ├── outputsql.py ├── populatedefaults.py ├── populaterequests.py ├── populatesql.py ├── results.ipynb ├── reval.py ├── runline.py ├── single_axis_grid_search.py ├── summaries.ipynb └── trees.ipynb ├── bootstrap.py ├── calccenters.py ├── calctop.py ├── cd2020_pockets.txt ├── cgo_arrow.py ├── clean_kept_models.py ├── clustering.py ├── combine_fold_results.py ├── combine_rows.py ├── combine_rows_lowmem.py ├── compute_row.py ├── compute_seqs.py ├── counterexample_generation_jobs.py ├── create_caches.py ├── create_caches2.py ├── docs ├── README.md └── Using_the_Cluster.md ├── generate_counterexample_typeslines.py ├── generate_unique_lig_poses.py ├── gly_gly_gly.pdb ├── grid_visualization.py ├── models └── lenet1.template ├── pdbbind2017_affs.txt ├── predict.py ├── pymol_arrows.py ├── reduce_data.py ├── show_xyz_arrows.py ├── simple_grid_visualization.py ├── timemodel.py ├── train.py ├── types2xyz.py └── types_extender.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, gnina 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of scripts nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /affinity_search/addrequests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Check the sql database to see if the number of pending jobs is below 4 | a threshold. If so, download the table and run spearmint twice, once for 5 | top and once for R, each time generating N*3 new jobs. For each of the 6 | N configurations there are 3 variants with different splits and seeds. 7 | 8 | ''' 9 | 10 | import sys, re, MySQLdb, argparse, os, json, subprocess 11 | import pandas as pd 12 | import makemodel 13 | import numpy as np 14 | from MySQLdb.cursors import DictCursor 15 | from outputjson import makejson 16 | from populaterequests import addrows 17 | 18 | def getcursor(): 19 | '''create a connection and return a cursor; 20 | doing this guards against dropped connections''' 21 | conn = MySQLdb.connect (host = args.host,user = "opter",passwd=args.password,db=args.db) 22 | conn.autocommit(True) 23 | cursor = conn.cursor(DictCursor) 24 | return cursor 25 | 26 | 27 | parser = argparse.ArgumentParser(description='Generate more configurations if needed') 28 | parser.add_argument('--host',type=str,help='Database host',required=True) 29 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 30 | parser.add_argument('--db',type=str,help='Database name',default='opt1') 31 | parser.add_argument('--pending_threshold',type=int,default=12,help='Number of pending jobs that triggers an update') 32 | parser.add_argument('-n','--num_configs',type=int,default=4,help='Number of configs to generate - will add 6X as many jobs') 33 | parser.add_argument('-s','--spearmint',type=str,help='Location of spearmint-lite.py',required=True) 34 | 35 | args = parser.parse_args() 36 | 37 | opts = makemodel.getoptions() 38 | 39 | 40 | # first see how many id=REQUESTED jobs there are 41 | cursor = getcursor() 42 | cursor.execute('SELECT COUNT(*) FROM params WHERE id = "REQUESTED"') 43 | rows = cursor.fetchone() 44 | pending = list(rows.values())[0] 45 | cursor.close() 46 | 47 | print("Pending jobs:",pending) 48 | 49 | #if more than pending_threshold, quit 50 | if pending > args.pending_threshold: 51 | sys.exit(0) 52 | 53 | #create gnina-spearmint directory if it doesn't exist already 54 | if not os.path.exists('gnina-spearmint'): 55 | os.makedirs('gnina-spearmint') 56 | 57 | #create config.json 58 | cout = open('gnina-spearmint/config.json','w') 59 | cout.write(json.dumps(makejson(), indent=4)+'\n') 60 | cout.close() 61 | 62 | #for each of top and R 63 | for metric in ['top','R']: 64 | #get the whole database 65 | cursor = getcursor() 66 | cursor.execute('SELECT * FROM params') 67 | rows = cursor.fetchall() 68 | resf = open('gnina-spearmint/results.dat','w') 69 | #write out a results.dat file, P for NULL metric, negated for real 70 | uniqconfigs = set() 71 | for row in rows: 72 | config = [] 73 | for (name,vals) in sorted(opts.items()): 74 | if name == 'resolution': 75 | val = str(float(row[name])) #gets returned as 1 instead of 1.0 76 | else: 77 | val = str(row[name]) 78 | config.append(val) 79 | uniqconfigs.add(tuple(config)) 80 | if row[metric]: #not null 81 | resf.write('%f 0 '%-row[metric]) #spearmint tries to _minimize_ so negate 82 | else: 83 | resf.write('P P ') 84 | resf.write(' '.join(config)) 85 | resf.write('\n') 86 | resf.close() 87 | 88 | gseed = len(uniqconfigs) 89 | # run spearmint-light, set the seed to the number of unique 
configurations 90 | subprocess.call(['python',args.spearmint, '--method=GPEIOptChooser', '--grid-size=20000', 91 | 'gnina-spearmint', '--n=%d'%args.num_configs, '--grid-seed=%d' % gseed]) 92 | print(['python',args.spearmint, '--method=GPEIOptChooser', '--grid-size=20000', 93 | 'gnina-spearmint', '--n=%d'%args.num_configs, '--grid-seed=%d' % gseed]) 94 | #get the generated lines from the file 95 | lines = open('gnina-spearmint/results.dat').readlines() 96 | newlines = lines[len(rows):] 97 | print(len(newlines),args.num_configs) 98 | assert(len(newlines) > 0) 99 | print(newlines) 100 | #add to database as REQUESTED jobs 101 | addrows('gnina-spearmint/results.dat',args.host,args.db,args.password,start=len(rows)) 102 | 103 | -------------------------------------------------------------------------------- /affinity_search/cleanparams.py: -------------------------------------------------------------------------------- 1 | import makemodel 2 | 3 | modeldefaults = makemodel.getdefaults() 4 | 5 | def cleanparams(p): 6 | '''standardize params that do not matter''' 7 | for i in range(1,6): 8 | if p['conv%d_width'%i] == 0: 9 | for suffix in ['func', 'init', 'norm', 'size', 'stride', 'width']: 10 | name = 'conv%d_%s'%(i,suffix) 11 | p[name] = modeldefaults[name] 12 | if p['pool%d_size'%i] == 0: 13 | name = 'pool%d_type'%i 14 | p[name] = modeldefaults[name] 15 | 16 | if p['fc_pose_hidden'] == 0: 17 | p['fc_pose_func'] = modeldefaults['fc_pose_func'] 18 | p['fc_pose_hidden2'] = modeldefaults['fc_pose_hidden2'] 19 | p['fc_pose_func2'] = modeldefaults['fc_pose_func2'] 20 | elif p['fc_pose_hidden2'] == 0: 21 | p['fc_pose_hidden2'] = modeldefaults['fc_pose_hidden2'] 22 | p['fc_pose_func2'] = modeldefaults['fc_pose_func2'] 23 | 24 | if p['fc_affinity_hidden'] == 0: 25 | p['fc_affinity_func'] = modeldefaults['fc_affinity_func'] 26 | p['fc_affinity_hidden2'] = modeldefaults['fc_affinity_hidden2'] 27 | p['fc_affinity_func2'] = modeldefaults['fc_affinity_func2'] 28 | elif p['fc_affinity_hidden2'] == 0: 29 | p['fc_affinity_hidden2'] = modeldefaults['fc_affinity_hidden2'] 30 | p['fc_affinity_func2'] = modeldefaults['fc_affinity_func2'] 31 | 32 | for (name,val) in p.items(): 33 | if 'item' in dir(val): 34 | p[name] = np.asscalar(val) 35 | if type(p[name]) == int: 36 | p[name] = float(p[name]) 37 | return p 38 | -------------------------------------------------------------------------------- /affinity_search/do1request.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | '''Connects to sql database. 3 | Checks atomically to see if there are an configurations that should be run 4 | because they are requested (R). 
If so, runs one 5 | ''' 6 | 7 | #https://hyperopt-186617.appspot.com 8 | 9 | import sys, re, MySQLdb, argparse, socket, tempfile 10 | import pandas as pd 11 | import numpy as np 12 | import makemodel 13 | import subprocess, os, json 14 | from MySQLdb.cursors import DictCursor 15 | 16 | parser = argparse.ArgumentParser(description='Run a configuration as part of a search') 17 | parser.add_argument('--data_root',type=str,help='Location of gninatypes directory',default='') 18 | parser.add_argument('--prefix',type=str,help='gninatypes prefix, needs to be absolute',required=True) 19 | parser.add_argument('--host',type=str,help='Database host',required=True) 20 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 21 | parser.add_argument('--db',type=str,help='Database name',required=True) 22 | parser.add_argument('--ligmap',type=str,help="Ligand atom typing map to use",default='') 23 | parser.add_argument('--recmap',type=str,help="Receptor atom typing map to use",default='') 24 | 25 | args = parser.parse_args() 26 | 27 | def get_script_path(): 28 | return os.path.dirname(os.path.realpath(sys.argv[0])) 29 | 30 | def rm(inprogressname): 31 | try: 32 | print("Removing",inprogressname) 33 | os.remove(inprogressname) 34 | except OSError: 35 | print("Error removing",inprogressname) 36 | pass 37 | 38 | sys.path.append(get_script_path()) 39 | sys.path.append(get_script_path()+'/..') #train 40 | 41 | def getcursor(): 42 | '''create a connection and return a cursor; 43 | doing this guards against dropped connections''' 44 | conn = MySQLdb.connect (host = args.host,user = "opter",passwd=args.password,db=args.db) 45 | conn.autocommit(True) 46 | cursor = conn.cursor(DictCursor) 47 | return cursor 48 | 49 | def getgpuid(): 50 | '''return unique id of gpu 0''' 51 | gpuid = '0000' 52 | try: 53 | output = subprocess.check_output('nvidia-smi',shell=True,stderr=subprocess.STDOUT) 54 | m = re.search(r'00000:(\S\S:\S\S.\S) ',output) 55 | if m: 56 | gpuid = m.group(1) 57 | except Exception as e: 58 | print(e.output) 59 | print(e) 60 | print("Error accessing gpu") 61 | sys.exit(1) 62 | return gpuid 63 | 64 | opts = makemodel.getoptions() 65 | cursor = getcursor() 66 | 67 | host = socket.gethostname() 68 | 69 | # determine a configuration to run 70 | configs = None #map from name to value 71 | 72 | # check for an in progress file 73 | inprogressname = '%s-%s-INPROGRESS' % (host,getgpuid()) 74 | print(inprogressname) 75 | 76 | if os.path.isfile(inprogressname): 77 | config = json.load(open(inprogressname)) 78 | d = config['msg'] 79 | print("Retrying with config: %s" % json.dumps(config)) 80 | else: 81 | #are there any requested configurations? if so select one 82 | cursor.execute('SELECT * FROM params WHERE id = "REQUESTED"') 83 | rows = cursor.fetchall() 84 | config = None 85 | for row in rows: 86 | # need to atomically update id 87 | ret = cursor.execute('UPDATE params SET id = "INPROGRESS", msg = "%s" WHERE serial = %s AND id = "REQUESTED"',[inprogressname,row['serial']]) 88 | if ret: # success! 89 | #set config 90 | config = row 91 | break 92 | 93 | if config: #write out what we're doing 94 | d = tempfile.mkdtemp(prefix=socket.gethostname() +'-',dir='.') 95 | config['msg'] = d 96 | progout = open(inprogressname,'w') 97 | if 'time' in config: 98 | del config['time'] 99 | progout.write(json.dumps(config)) 100 | progout.close() 101 | 102 | 103 | if not config: 104 | print("Nothing requested") 105 | sys.exit(2) # there was nothing to do, perhaps we should shutdown? 
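# Exit-status convention used by this wrapper (readable from the calls above and below):
#   1 = nvidia-smi could not be queried, 2 = no REQUESTED rows were available to claim,
#   0 = a configuration was claimed and attempted, even if training itself later errored.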
106 | 107 | #at this point have a configuration 108 | values = ['0','0'] 109 | for (name,val) in sorted(opts.items()): 110 | values.append(str(config[name])) 111 | 112 | cmdline = '%s/runline.py --prefix %s --data_root "%s" --seed %d --split %d --dir %s --line "%s"' % \ 113 | (get_script_path(), args.prefix,args.data_root,config['seed'],config['split'], config['msg'], ' '.join(values)) 114 | if(args.ligmap): cmdline += " --ligmap %s"%args.ligmap 115 | if(args.recmap): cmdline += " --recmap %s"%args.recmap 116 | print(cmdline) 117 | 118 | #call runline to insulate ourselves from catestrophic failure (caffe) 119 | try: 120 | output = subprocess.check_output(cmdline,shell=True,stderr=subprocess.STDOUT) 121 | d, R, rmse, auc, top = output.rstrip().split('\n')[-1].split() 122 | pid = os.getpid() 123 | out = open('output.%s.%d'%(host,pid),'w') 124 | out.write(output) 125 | except Exception as e: 126 | pid = os.getpid() 127 | out = open('output.%s.%d'%(host,pid),'w') 128 | if isinstance(e, subprocess.CalledProcessError): 129 | output = e.output 130 | out.write(output) 131 | cursor = getcursor() 132 | cursor.execute('UPDATE params SET id = "ERROR", msg = %s WHERE serial = %s',(str(pid),config['serial'])) 133 | print("Error") 134 | print(output) 135 | if re.search(r'out of memory',output) and (host.startswith('gnina') ): 136 | #host migration restarts don't seem to bring the gpu up in agood state 137 | print("REBOOTING") 138 | os.system("sudo reboot") 139 | rm(inprogressname) 140 | sys.exit(0) #we tried 141 | 142 | 143 | #if successful, store in database 144 | 145 | config['rmse'] = float(rmse) 146 | config['R'] = float(R) 147 | config['top'] = float(top) 148 | config['auc'] = float(auc) 149 | config['id'] = d 150 | config['msg'] = 'SUCCESS' 151 | 152 | serial = config['serial'] 153 | del config['serial'] 154 | sql = 'UPDATE params SET {} WHERE serial = {}'.format(', '.join('{}=%s'.format(k) for k in config),serial) 155 | cursor = getcursor() 156 | cursor.execute(sql, list(config.values())) 157 | 158 | rm(inprogressname) 159 | 160 | -------------------------------------------------------------------------------- /affinity_search/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Take a prefix and model name run predictions, and generate evaluations for crystal, bestonly, 4 | and all test sets (take max affinity; if pose score is available also consider 5 | max pose score). 6 | Generates graphs and overall CV results. Takes the prefix and (for now) assumes trial 0. 7 | Will evaluate 100k model and best model prior to 100k, 50k and 25k 8 | ''' 9 | 10 | import numpy as np 11 | import os, sys 12 | #os.environ["GLOG_minloglevel"] = "10" 13 | sys.path.append("/home/dkoes/git/gninascripts/") 14 | sys.path.append("/net/pulsar/home/koes/dkoes/git/gninascripts/") 15 | 16 | import train, predict 17 | import matplotlib, caffe 18 | import matplotlib.pyplot as plt 19 | import glob, re, sklearn, collections, argparse, sys 20 | import sklearn.metrics 21 | import scipy.stats 22 | 23 | 24 | def evaluate_fold(testfile, caffemodel, modelname, datadir='../..',hasrmsd=False): 25 | '''Evaluate the passed model and the specified test set. 
26 | Returns tuple: 27 | (correct, prediction, receptor, ligand, label (optional), posescore (optional)) 28 | label and posescore are only provided is trained on pose data 29 | ''' 30 | if not os.path.exists(modelname): 31 | print(modelname,"does not exist") 32 | 33 | caffe.set_mode_gpu() 34 | test_model = 'predict.%d.prototxt' % os.getpid() 35 | train.write_model_file(test_model, modelname, testfile, testfile, datadir) 36 | test_net = caffe.Net(test_model, caffemodel, caffe.TEST) 37 | lines = open(testfile).readlines() 38 | res = None 39 | i = 0 #index in batch 40 | correct = 0 41 | prediction = 0 42 | receptor = '' 43 | ligand = '' 44 | label = 0 45 | posescore = -1 46 | ret = [] 47 | for line in lines: 48 | #check if we need a new batch of results 49 | if not res or i >= batch_size: 50 | res = test_net.forward() 51 | if 'output' in res: 52 | batch_size = res['output'].shape[0] 53 | else: 54 | batch_size = res['affout'].shape[0] 55 | i = 0 56 | 57 | if 'labelout' in res: 58 | label = float(res['labelout'][i]) 59 | 60 | if 'output' in res: 61 | posescore = float(res['output'][i][1]) 62 | 63 | if 'affout' in res: 64 | correct = float(res['affout'][i]) 65 | 66 | if 'predaff' in res: 67 | prediction = float(res['predaff'][i]) 68 | if not np.isfinite(prediction).all(): 69 | os.remove(test_model) 70 | return [] #gracefully handle nan? 71 | 72 | #extract ligand/receptor for input file 73 | tokens = line.split() 74 | rmsd = -1 75 | for t in range(len(tokens)): 76 | if tokens[t].lower()=='none': 77 | #Flag that none as the receptor file, for ligand-only models 78 | ligand=tokens[t+1] 79 | 80 | #we assume that ligand is rec/ 81 | #set if correct, bail if not. 82 | m=re.search(r'(\S+)/(\S+)gninatypes',ligand) 83 | 84 | #Check that the match is not none, and that ligand ends in gninatypes 85 | if m is not None: 86 | receptor=m.group(1) 87 | else: 88 | print('Error: none receptor detected and ligand is improperly formatted.') 89 | print('Ligand must be formatted: /.gninatypes') 90 | print('Bailing.') 91 | sys.exit(1) 92 | break 93 | 94 | elif tokens[t].endswith('gninatypes'): 95 | receptor = tokens[t] 96 | ligand = tokens[t+1] 97 | break 98 | if hasrmsd: 99 | rmsd = float(tokens[2]) 100 | #(correct, prediction, receptor, ligand, label (optional), posescore (optional)) 101 | if posescore < 0: 102 | ret.append((correct, prediction, receptor, ligand)) 103 | elif hasrmsd: 104 | ret.append((correct, prediction, receptor, ligand, label, posescore, rmsd)) 105 | else: 106 | ret.append((correct, prediction, receptor, ligand, label, posescore)) 107 | 108 | i += 1 #batch index 109 | 110 | os.remove(test_model) 111 | return ret 112 | 113 | 114 | def reduce_results(results, index): 115 | '''Return results with only one tuple for every receptor value, 116 | taking the one with the max value at index in the tuple (predicted affinity or pose score) 117 | ''' 118 | res = dict() #indexed by receptor 119 | for r in results: 120 | name = r[2] 121 | if name not in res: 122 | res[name] = r 123 | elif res[name][index] < r[index]: 124 | res[name] = r 125 | return list(res.values()) 126 | 127 | def analyze_results(results, outname, uniquify=None): 128 | '''Compute error metrics from resuls. RMSE, Pearson, Spearman. 129 | If uniquify is set, AUC and top-1 percentage are also computed, 130 | uniquify can be None, 'affinity', or 'pose' and is set with 131 | the all training set to select the pose used for scoring. 
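    (Here "the all training set" means the dataset prefix named 'all', i.e. the set that
    contains every docked pose, as opposed to the single-pose-per-receptor 'bestonly'
    and 'crystal' sets.)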
132 | Returns tuple: 133 | (RMSE, Pearson, Spearman, AUCpose, AUCaffinity, top-1) 134 | Writes (correct,prediction) pairs to outname.predictions 135 | ''' 136 | 137 | #calc auc before reduction 138 | if uniquify and len(results[0]) > 5: 139 | labels = np.array([r[4] for r in results]) 140 | posescores = np.array([r[5] for r in results]) 141 | predictions = np.array([r[1] for r in results]) 142 | aucpose = sklearn.metrics.roc_auc_score(labels, posescores) 143 | aucaff = sklearn.metrics.roc_auc_score(labels, predictions) 144 | 145 | if uniquify == 'affinity': 146 | results = reduce_results(results, 1) 147 | elif uniquify == 'pose': 148 | results = reduce_results(results, 5) 149 | 150 | predictions = np.array([r[1] for r in results]) 151 | correctaff = np.array([abs(r[0]) for r in results]) 152 | #(correct, prediction, receptor, ligand, label (optional), posescore (optional)) 153 | 154 | rmse = np.sqrt(sklearn.metrics.mean_squared_error(correctaff, predictions)) 155 | R = scipy.stats.pearsonr(correctaff, predictions)[0] 156 | S = scipy.stats.spearmanr(correctaff, predictions)[0] 157 | out = open('%s.predictions'%outname,'w') 158 | for (c,p) in zip(correctaff,predictions): 159 | out.write('%f %f\n' % (c,p)) 160 | out.write('#RMSD %f\n'%rmse) 161 | out.write('#R %f\n'%R) 162 | 163 | if uniquify and len(results[0]) > 5: 164 | labels = np.array([r[4] for r in results]) 165 | top = np.count_nonzero(labels > 0)/float(len(labels)) 166 | return (rmse, R, S, aucpose, aucaff, top) 167 | else: 168 | return (rmse, R, S) 169 | 170 | 171 | if __name__ == '__main__': 172 | if len(sys.argv) <= 4: 173 | print("Need caffemodel prefix, modelname, output name and test prefixes (which should include __ at end)") 174 | sys.exit(1) 175 | 176 | name = sys.argv[1] 177 | modelname = sys.argv[2] 178 | out = open(sys.argv[3],'w') 179 | 180 | allresults = [] 181 | last = None 182 | #for each test dataset 183 | for testprefix in sys.argv[4:]: 184 | m = re.search('([^/ ]*)_(\d+)_$', testprefix) 185 | print(m,testprefix) 186 | if not m: 187 | print(testprefix,"does not end in slicenum") 188 | slicenum = int(m.group(2)) 189 | testname = m.group(1) 190 | #find the relevant models for each fold 191 | testresults = {'best25': [], 'best50': [], 'best100': [], 'last': [], 'best250': [] } 192 | for fold in [0,1,2]: 193 | best25 = 0 194 | best50 = 0 195 | best100 = 0 196 | best250 = 0 197 | lastm = 0 198 | #identify best iteration models at each cut point for this fold 199 | for model in glob.glob('%s.%d_iter_*.caffemodel'%(name,fold)): 200 | m = re.search(r'_iter_(\d+).caffemodel', model) 201 | inum = int(m.group(1)) 202 | if inum < 25000 and inum > best25: 203 | best25 = inum 204 | if inum < 50000 and inum > best50: 205 | best50 = inum 206 | if inum < 100000 and inum > best100: 207 | best100 = inum 208 | if inum < 250000 and inum > best250: 209 | best250 = inum 210 | if inum > lastm: 211 | lastm = inum 212 | #evalute this fold 213 | testfile = '../types/%stest%d.types' % (testprefix,fold) 214 | #todo, avoid redundant repetitions 215 | if best25 > 0: testresults['best25'] += evaluate_fold(testfile, '%s.%d_iter_%d.caffemodel' % (name,fold,best25), modelname) 216 | if best50 > 0: testresults['best50'] += evaluate_fold(testfile, '%s.%d_iter_%d.caffemodel' % (name,fold,best50), modelname) 217 | if best100 > 0: testresults['best100'] += evaluate_fold(testfile, '%s.%d_iter_%d.caffemodel' % (name,fold,best100), modelname) 218 | if best250 > 0: testresults['best250'] += evaluate_fold(testfile, '%s.%d_iter_%d.caffemodel' % 
(name,fold,best250), modelname) 219 | if lastm > 0: testresults['last'] += evaluate_fold(testfile, '%s.%d_iter_%d.caffemodel' % (name,fold,lastm), modelname) 220 | 221 | 222 | for n in list(testresults.keys()): 223 | if len(testresults[n]) == 0: 224 | continue 225 | if len(testresults[n][0]) == 6: 226 | allresults.append( ('%s_pose'%testname, n) + analyze_results(testresults[n],('%s_pose_'%testname)+name+'_'+n,'pose')) 227 | allresults.append( ('%s_affinity'%testname, n) + analyze_results(testresults[n],('%s_affinity_'%testname)+name+'_'+n,'affinity')) 228 | 229 | 230 | for a in allresults: 231 | out.write(' '.join(map(str,a))+'\n') 232 | -------------------------------------------------------------------------------- /affinity_search/evaluate_cross.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Take a prefix, model name and output and generate predictions for 4 | a cross-docked formatted dataset. Will use the last model 5 | 6 | ''' 7 | 8 | import numpy as np 9 | import os, sys, argparse 10 | #os.environ["GLOG_minloglevel"] = "10" 11 | sys.path.append("/home/dkoes/git/gninascripts/") 12 | sys.path.append("/net/pulsar/home/koes/dkoes/git/gninascripts/") 13 | sys.path.append("/home/dkoes/git/gninascripts/affinity_search") 14 | sys.path.append("/net/pulsar/home/koes/dkoes/git/gninascripts/affinity_search") 15 | import train, predict 16 | import matplotlib, caffe 17 | import matplotlib.pyplot as plt 18 | import glob, re, sklearn, collections, argparse, sys 19 | import sklearn.metrics 20 | import scipy.stats 21 | 22 | from evaluate import evaluate_fold 23 | 24 | 25 | def reduce_results(results, index, which): 26 | '''Return results with only one tuple for every pocket-ligand value, 27 | taking the one with the max value at index in the tuple (predicted affinity or pose score) 28 | ''' 29 | res = dict() #indexed by pocketligand 30 | for r in results: 31 | lname = r[3] 32 | m = re.search(r'(\S+)/...._._rec_...._(\S+)_(lig|uff_min)',lname) 33 | pocket = m.group(1) 34 | lig = m.group(2) 35 | key = pocket+':'+lig 36 | if key not in res: 37 | res[key] = r 38 | else: 39 | if which == 'small': #select smallest by index 40 | if res[key][index] > r[index]: 41 | res[key] = r 42 | elif res[key][index] < r[index]: 43 | res[key] = r 44 | 45 | return list(res.values()) 46 | 47 | def analyze_cross_results(results,outname,uniquify): 48 | '''Compute error metrics from resulst. 49 | results is formated: (correct, prediction, receptor, ligand, label, posescore,rmsd) 50 | This is assumed to be a cross docked input, where receptor filename is 51 | POCKET/PDB_CH_rec_0.gninatypes 52 | and the ligand is 53 | POCKET/PDB1_CH_rec_PDB2_lig_...gninatypes 54 | 55 | 56 | RMSE, Pearson, Spearman, AUC and top-1 percentage are computed. 57 | 58 | select can be pose or rmsd. With pose we use the best pose scoring pose 59 | of all ligand poses in a pocket. For rmsd we use the lowest rmsd line for 60 | that ligand in the pocket. 61 | 62 | AUC is calculated before any reduction. 
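    (In other words, every pose contributes to AUCpose/AUCaffinity, while RMSE, Pearson
    and Spearman below are computed only on the reduced one-pose-per-pocket/ligand results.)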
63 | 64 | Writes the reduced set to outname.predictions 65 | ''' 66 | 67 | #calc auc before reduction 68 | labels = np.array([r[4] for r in results]) 69 | posescores = np.array([r[5] for r in results]) 70 | predictions = np.array([r[1] for r in results]) 71 | aucpose = sklearn.metrics.roc_auc_score(labels, posescores) 72 | aucaff = sklearn.metrics.roc_auc_score(labels, predictions) 73 | 74 | if uniquify == 'rmsd': 75 | results = reduce_results(results, 6, 'small') 76 | elif uniquify == 'affinity': 77 | results = reduce_results(results, 1, 'large') 78 | elif uniquify == 'pose': 79 | results = reduce_results(results, 5, 'large') 80 | 81 | predictions = np.array([r[1] for r in results]) 82 | correctaff = np.array([abs(r[0]) for r in results]) 83 | #(correct, prediction, receptor, ligand, label (optional), posescore (optional)) 84 | #skip cases with no affinity 85 | predictions = predictions[correctaff != 0] 86 | correctaff = np.abs(correctaff[correctaff != 0]) 87 | rmse = np.sqrt(sklearn.metrics.mean_squared_error(correctaff, predictions)) 88 | R = scipy.stats.pearsonr(correctaff, predictions)[0] 89 | S = scipy.stats.spearmanr(correctaff, predictions)[0] 90 | out = open('%s.predictions'%outname,'w') 91 | out.write('aff,pred,rec,lig,lab,score,rmsd\n') 92 | for res in results: 93 | out.write(','.join(map(str,res))+'\n') 94 | out.write('#RMSD %f\n'%rmse) 95 | out.write('#R %f\n'%R) 96 | 97 | labels = np.array([r[4] for r in results]) 98 | top = np.count_nonzero(labels > 0)/float(len(labels)) 99 | return (rmse, R, S, aucpose, aucaff, top) 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | 105 | parser=argparse.ArgumentParser(description='Evaluate 3 fold CV data') 106 | parser.add_argument('-d','--datadir',default='.',help='ROOT folder for files specified in types files. Defaults to current working directory.') 107 | parser.add_argument('-w','--weights_prefix',type=str, required=True, help='Prefix to the weights. Format ._iter_.caffemodel') 108 | parser.add_argument('-m','--model',type=str, required=True, help='Caffe model file.') 109 | parser.add_argument('-o','--outprefix',type=str,required=True, help='Prefix for output files. Generates several *.predictions files, and a .summary file') 110 | parser.add_argument('-t','--testprefix',type=str,required=True, nargs='+',help='Prefix to the test types files. Format test.types') 111 | parser.add_argument('--has_rmsd',action='store_true', default=False, help='Flag that RMSD values are in the types files. Defaults to False') 112 | 113 | args=parser.parse_args() 114 | 115 | datadir = args.datadir 116 | name = args.weights_prefix 117 | modelname = args.model 118 | testname = args.outprefix 119 | out = open(testname+'.summary','w') 120 | 121 | m = re.search(r'(_fn\d)',testname) 122 | if m: #remove fold name 123 | testname = testname.replace(m.group(1),'') 124 | allresults = [] 125 | last = None 126 | #for each test dataset 127 | for testprefix in args.testprefix: 128 | print(testprefix) 129 | #find the relevant models for each fold 130 | 131 | testresults = [] 132 | for fold in [0,1,2]: #blah! 
hard coded 133 | lastm = 0 134 | #identify last iteration model for this fold 135 | for model in glob.glob('%s.%d_iter_*.caffemodel'%(name,fold)): 136 | m = re.search(r'_iter_(\d+).caffemodel', model) 137 | inum = int(m.group(1)) 138 | if inum > lastm: 139 | lastm = inum 140 | 141 | #evalute this fold 142 | testfile = '%stest%d.types' % (testprefix,fold) 143 | testresults += evaluate_fold(testfile, '%s.%d_iter_%d.caffemodel' % (name,fold,lastm), modelname, datadir, args.has_rmsd) 144 | 145 | if len(testresults) == 0: 146 | print("Missing data with",testprefix) 147 | if args.has_rmsd: 148 | assert(len(testresults[0]) == 7) 149 | else: 150 | assert(len(testresults[0]) == 6) 151 | 152 | allresults.append( (testname,'pose') + analyze_cross_results(testresults,testname+'_pose','pose')) 153 | if args.has_rmsd: 154 | allresults.append( (testname,'rmsd') + analyze_cross_results(testresults,testname+'_rmsd','rmsd')) 155 | allresults.append( (testname,'affinity') + analyze_cross_results(testresults,testname+'_affinity','affinity')) 156 | 157 | 158 | for a in allresults: 159 | out.write(' '.join(map(str,a))+'\n') 160 | -------------------------------------------------------------------------------- /affinity_search/fix.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Grab all the "Sucess" examples from the database, look if any of the directories exist. 4 | If they do, reval them and update the database''' 5 | 6 | import sys, re, MySQLdb, os, argparse, subprocess 7 | from MySQLdb.cursors import DictCursor 8 | 9 | 10 | parser = argparse.ArgumentParser(description='Fix evaluation of trained models.') 11 | parser.add_argument('--data_root',type=str,help='Location of gninatypes directory',default='') 12 | parser.add_argument('--prefix',type=str,help='Prefix, not including split',default='../data/refined/all_0.5_') 13 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 14 | parser.add_argument('--reval',type=str,help="reva.py",default='./reval.py') 15 | args = parser.parse_args() 16 | 17 | def getcursor(): 18 | '''create a connection and return a cursor; 19 | doing this guards against dropped connections''' 20 | conn = MySQLdb.connect (host = "35.196.158.205",user = "opter",passwd=args.password,db="opt1") 21 | conn.autocommit(True) 22 | cursor = conn.cursor(DictCursor) 23 | return cursor 24 | 25 | cursor = getcursor() 26 | cursor.execute('SELECT * FROM params WHERE msg = "Sucess"') 27 | rows = cursor.fetchall() 28 | 29 | for row in rows: 30 | if not os.path.isdir(row['id']): 31 | continue 32 | 33 | # need to atomically update msg 34 | ret = cursor.execute('UPDATE params SET msg = "Pending" WHERE serial = %s AND msg = "Sucess"',[row['serial']]) 35 | if not ret: # try next 36 | continue 37 | 38 | print(row['id']) 39 | 40 | cmdline = '%s --prefix %s --data_root "%s" --split %d --dir %s' % \ 41 | (args.reval, args.prefix,args.data_root,row['split'], row['id']) 42 | print(cmdline) 43 | 44 | #call runline to insulate ourselves from catestrophic failure (caffe) 45 | try: 46 | output = subprocess.check_output(cmdline,shell=True,stderr=subprocess.STDOUT) 47 | d, R, rmse, auc, top = output.rstrip().split('\n')[-1].split() 48 | except Exception as e: 49 | print(e.output) 50 | print(e) 51 | print("Problem with",row['id']) 52 | continue 53 | 54 | print(d, R, rmse, auc, top) 55 | sql = 'UPDATE params SET R={},rmse={},msg="SUCCESS" WHERE serial = {}'.format(R,rmse,row['serial']) 56 | cursor = getcursor() 57 | 
cursor.execute(sql) 58 | -------------------------------------------------------------------------------- /affinity_search/ga_addrequests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Train a random forest on model performance from an sql database and then 4 | run a genetic algorithm to propose new, better models to run. 5 | 6 | 7 | ''' 8 | 9 | import sys, re, MySQLdb, argparse, os, json, subprocess 10 | import pandas as pd 11 | import makemodel 12 | import numpy as np 13 | from MySQLdb.cursors import DictCursor 14 | from outputjson import makejson 15 | from MySQLdb.cursors import DictCursor 16 | from frozendict import frozendict 17 | 18 | import sklearn 19 | from sklearn.ensemble import * 20 | from sklearn.preprocessing import * 21 | from sklearn.feature_extraction import * 22 | 23 | import deap 24 | from deap import base, creator, gp, tools 25 | from deap import algorithms 26 | 27 | from deap import * 28 | import multiprocessing 29 | 30 | def getcursor(host,passwd,db): 31 | '''create a connection and return a cursor; 32 | doing this guards against dropped connections''' 33 | conn = MySQLdb.connect (host = host,user = "opter",passwd=passwd,db=db) 34 | conn.autocommit(True) 35 | cursor = conn.cursor(DictCursor) 36 | return cursor 37 | 38 | def cleanparams(p): 39 | '''standardize params that do not matter''' 40 | modeldefaults = makemodel.getdefaults() 41 | for i in range(1,6): 42 | if p['conv%d_width'%i] == 0: 43 | for suffix in ['func', 'init', 'norm', 'size', 'stride', 'width']: 44 | name = 'conv%d_%s'%(i,suffix) 45 | p[name] = modeldefaults[name] 46 | if p['pool%d_size'%i] == 0: 47 | name = 'pool%d_type'%i 48 | p[name] = modeldefaults[name] 49 | if p['fc_pose_hidden'] == 0: 50 | p['fc_pose_func'] = modeldefaults['fc_pose_func'] 51 | p['fc_pose_hidden2'] = modeldefaults['fc_pose_hidden2'] 52 | p['fc_pose_func2'] = modeldefaults['fc_pose_func2'] 53 | p['fc_pose_init'] = modeldefaults['fc_pose_init'] 54 | elif p['fc_pose_hidden2'] == 0: 55 | p['fc_pose_hidden2'] = modeldefaults['fc_pose_hidden2'] 56 | p['fc_pose_func2'] = modeldefaults['fc_pose_func2'] 57 | 58 | if p['fc_affinity_hidden'] == 0: 59 | p['fc_affinity_func'] = modeldefaults['fc_affinity_func'] 60 | p['fc_affinity_hidden2'] = modeldefaults['fc_affinity_hidden2'] 61 | p['fc_affinity_func2'] = modeldefaults['fc_affinity_func2'] 62 | p['fc_affinity_init'] = modeldefaults['fc_affinity_init'] 63 | elif p['fc_affinity_hidden2'] == 0: 64 | p['fc_affinity_hidden2'] = modeldefaults['fc_affinity_hidden2'] 65 | p['fc_affinity_func2'] = modeldefaults['fc_affinity_func2'] 66 | return p 67 | 68 | def randParam(param, choices): 69 | '''randomly select a choice for param''' 70 | if isinstance(choices, makemodel.Range): #discretize 71 | choices = np.linspace(choices.min,choices.max, 9) 72 | return np.asscalar(np.random.choice(choices)) 73 | 74 | def randomIndividual(): 75 | ret = dict() 76 | options = makemodel.getoptions() 77 | for (param,choices) in options.items(): 78 | ret[param] = randParam(param, choices) 79 | 80 | return cleanparams(ret) 81 | 82 | def evaluateIndividual(ind): 83 | x = dictvec.transform(ind) 84 | return [rf.predict(x)[0]] 85 | 86 | def mutateIndividual(ind, indpb=0.05): 87 | '''for each param, with prob indpb randomly sample another choice''' 88 | options = makemodel.getoptions() 89 | for (param,choices) in options.items(): 90 | if np.random.rand() < indpb: 91 | ind[param] = randParam(param, choices) 92 | return (ind,) 93 | 94 | def 
crossover(ind1, ind2, indpdb=0.5): 95 | '''swap choices with probability indpb''' 96 | options = makemodel.getoptions() 97 | for (param,choices) in options.items(): 98 | if np.random.rand() < indpdb: 99 | tmp = ind1[param] 100 | ind1[param] = ind2[param] 101 | ind2[param] = tmp 102 | return (ind1,ind2) 103 | 104 | def runGA(pop): 105 | '''run GA with early stopping if not improving''' 106 | hof = tools.HallOfFame(10) 107 | stats = tools.Statistics(lambda ind: ind.fitness.values) 108 | stats.register("avg", np.mean) 109 | stats.register("std", np.std) 110 | stats.register("min", np.min) 111 | stats.register("max", np.max) 112 | best = 0 113 | pop = toolbox.clone(pop) 114 | for i in range(40): 115 | pop, log = algorithms.eaMuPlusLambda(pop, toolbox, mu=300, lambda_=300, cxpb=0.5, mutpb=0.2, ngen=25, 116 | stats=stats, halloffame=hof, verbose=True) 117 | newmax = log[-1]['max'] 118 | if best == newmax: 119 | break 120 | best = newmax 121 | return pop 122 | 123 | def addrows(config,host,db,password): 124 | '''add rows from fname into database, starting at row start''' 125 | 126 | conn = MySQLdb.connect (host = host,user = "opter",passwd=password,db=db) 127 | cursor = conn.cursor() 128 | 129 | items = list(config.items()) 130 | names = ','.join([str(n) for (n,v) in items]) 131 | values = ','.join(['%s' for (n,v) in items]) 132 | names += ',id' 133 | values += ',"REQUESTED"' 134 | 135 | #do five variations 136 | for split in range(5): 137 | seed = np.random.randint(0,100000) 138 | n = names + ',split,seed' 139 | v = values + ',%d,%d' % (split,seed) 140 | insert = 'INSERT INTO params (%s) VALUES (%s)' % (n,v) 141 | cursor.execute(insert,[v for (n,v) in items]) 142 | 143 | conn.commit() 144 | 145 | 146 | parser = argparse.ArgumentParser(description='Generate more configurations with random forest and genetic algorithms') 147 | parser.add_argument('--host',type=str,help='Database host',required=True) 148 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 149 | parser.add_argument('--db',type=str,help='Database name',default='database') 150 | parser.add_argument('--pending_threshold',type=int,default=0,help='Number of pending jobs that triggers an update') 151 | parser.add_argument('-n','--num_configs',type=int,default=1,help='Number of configs to generate - will add a multiple as many jobs') 152 | args = parser.parse_args() 153 | 154 | 155 | 156 | # first see how many id=REQUESTED jobs there are 157 | cursor = getcursor(args.host,args.password,args.db) 158 | cursor.execute('SELECT COUNT(*) FROM params WHERE id = "REQUESTED"') 159 | rows = cursor.fetchone() 160 | pending = list(rows.values())[0] 161 | #print "Pending jobs:",pending 162 | sys.stdout.write('%d '%pending) 163 | sys.stdout.flush() 164 | 165 | #if more than pending_threshold, quit 166 | if pending > args.pending_threshold: 167 | sys.exit(0) 168 | 169 | 170 | cursor = getcursor(args.host,args.password,args.db) 171 | cursor.execute('SELECT * FROM params WHERE id != "REQUESTED"') 172 | rows = cursor.fetchall() 173 | data = pd.DataFrame(list(rows)) 174 | #make errors zero - appropriate if error is due to parameters 175 | data.loc[data.id == 'ERROR','R'] = 0 176 | data.loc[data.id == 'ERROR','rmse'] = 0 177 | data.loc[data.id == 'ERROR','top'] = 0 178 | data.loc[data.id == 'ERROR','auc'] = 0 179 | 180 | data['Rtop'] = data.R*data.top 181 | data = data.dropna('index').apply(pd.to_numeric, errors='ignore') 182 | 183 | #convert data to be useful for sklearn 184 | notparams = 
['R','auc','Rtop','id','msg','rmse','seed','serial','time','top','split'] 185 | X = data.drop(notparams,axis=1) 186 | y = data.Rtop 187 | 188 | dictvec = DictVectorizer() 189 | #standardize meaningless params 190 | Xv = dictvec.fit_transform(list(map(cleanparams,X.to_dict(orient='records')))) 191 | 192 | print("\nTraining %d\n"%Xv.shape[0]) 193 | #train model 194 | rf = RandomForestRegressor(n_estimators=20) 195 | rf.fit(Xv,y) 196 | 197 | #set up GA 198 | creator.create("FitnessMax", base.Fitness, weights=(1.0,)) 199 | creator.create("Individual", dict, fitness=creator.FitnessMax) 200 | 201 | toolbox = base.Toolbox() 202 | toolbox.register("individual", tools.initIterate, creator.Individual, randomIndividual) 203 | toolbox.register("population", tools.initRepeat, list, toolbox.individual) 204 | toolbox.register("mutate",mutateIndividual) 205 | toolbox.register("mate",crossover) 206 | toolbox.register("select", tools.selTournament, tournsize=3) 207 | toolbox.register("evaluate", evaluateIndividual) 208 | 209 | pool = multiprocessing.Pool() 210 | toolbox.register("map", pool.map) 211 | 212 | #setup initial population 213 | initpop = [ creator.Individual(cleanparams(x)) for x in X.to_dict('records')] 214 | 215 | evals = pool.map(toolbox.evaluate, initpop) 216 | top = sorted([l[0] for l in evals],reverse=True)[0] 217 | 218 | print("Best in training set: %f"%top) 219 | 220 | seen = set(map(frozendict,initpop)) 221 | #include some random individuals 222 | randpop = toolbox.population(n=len(initpop)) 223 | 224 | pop = runGA(initpop+randpop) 225 | 226 | #make sure sorted 227 | pop = sorted(pop,key=lambda x: -x.fitness.values[0]) 228 | #remove already evaluated configs 229 | pop = [p for p in pop if frozendict(p) not in seen] 230 | 231 | print("Best recommended: %f"%pop[0].fitness.values[0]) 232 | 233 | uniquified = [] 234 | for config in pop: 235 | config = cleanparams(config) 236 | fr = frozendict(config) 237 | if fr not in seen: 238 | seen.add(fr) 239 | uniquified.append(config) 240 | 241 | print(len(uniquified),len(pop)) 242 | 243 | for config in uniquified[:args.num_configs]: 244 | addrows(config, args.host,args.db,args.password) 245 | -------------------------------------------------------------------------------- /affinity_search/getres.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Return the top and R statistics for every row of the database that has them''' 4 | 5 | import sys, re, MySQLdb, argparse, os, json, subprocess 6 | import pandas as pd 7 | import makemodel 8 | import numpy as np 9 | 10 | def getcursor(): 11 | '''create a connection and return a cursor; 12 | doing this guards against dropped connections''' 13 | conn = MySQLdb.connect (host = args.host,user = "opter",passwd=args.password,db=args.db) 14 | conn.autocommit(True) 15 | cursor = conn.cursor() 16 | return cursor 17 | 18 | 19 | parser = argparse.ArgumentParser(description='Return top and R statistics for successful rows in database') 20 | parser.add_argument('--host',type=str,help='Database host',required=True) 21 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 22 | parser.add_argument('--db',type=str,help='Database name',default='opt1') 23 | 24 | args = parser.parse_args() 25 | 26 | cursor = getcursor() 27 | cursor.execute('SELECT serial,top,R,auc,rmse FROM params WHERE rmse IS NOT NULL') 28 | rows = cursor.fetchall() 29 | for row in rows: 30 | print('%d %f %f %f %f' % row) 31 | 
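For reference, a minimal, hypothetical sketch of consuming the space-separated
"serial top R auc rmse" rows that getres.py prints; DBHOST, PASSWORD and results.txt
are placeholders and not part of the repository.

#   ./getres.py --host DBHOST -p PASSWORD --db opt1 > results.txt
import pandas as pd

cols = ['serial', 'top', 'R', 'auc', 'rmse']
df = pd.read_csv('results.txt', sep=' ', names=cols)
# rank configurations by Pearson R on the test set
print(df.sort_values('R', ascending=False).head())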
-------------------------------------------------------------------------------- /affinity_search/getresults.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Return aggregated statistics for database''' 4 | 5 | import sys, re, MySQLdb, argparse, os, json, subprocess 6 | import pandas as pd 7 | import makemodel 8 | import numpy as np 9 | from MySQLdb.cursors import DictCursor 10 | 11 | def getcursor(host,passwd,db): 12 | '''create a connection and return a cursor; 13 | doing this guards against dropped connections''' 14 | conn = MySQLdb.connect (host = host,user = "opter",passwd=passwd,db=db) 15 | conn.autocommit(True) 16 | cursor = conn.cursor(DictCursor) 17 | return cursor 18 | 19 | def __my_flatten_cols(self, how="_".join, reset_index=True): 20 | how = (lambda iter: list(iter)[-1]) if how == "last" else how 21 | self.columns = [how([_f for _f in map(str, levels) if _f]) for levels in self.columns.values] \ 22 | if isinstance(self.columns, pd.MultiIndex) else self.columns 23 | return self.reset_index() if reset_index else self 24 | pd.DataFrame.my_flatten_cols = __my_flatten_cols 25 | 26 | def getres(host, password, db, mingroup, priority, selected_params): 27 | '''return dataframe grouped by all params with params selected out''' 28 | cursor = getcursor(host,password,db) 29 | cursor.execute('SELECT * FROM params') 30 | rows = cursor.fetchall() 31 | data = pd.DataFrame(list(rows)) 32 | #make errors zero - appropriate if error is due to parameters 33 | data.loc[data.id == 'ERROR','R'] = 0 34 | data.loc[data.id == 'ERROR','rmse'] = 0 35 | data.loc[data.id == 'ERROR','top'] = 0 36 | data.loc[data.id == 'ERROR','auc'] = 0 37 | 38 | data['Rtop'] = data.R*data.top 39 | nonan = data.dropna('index').apply(pd.to_numeric, errors='ignore') 40 | 41 | #read in prioritized list of parameters 42 | params = open(priority).read().rstrip().split() 43 | 44 | grouped = nonan.groupby(params) 45 | metrics = grouped.agg([np.mean,np.std,np.min,np.max]) 46 | metrics = metrics[grouped.size() >= mingroup] 47 | metrics = metrics.my_flatten_cols() 48 | 49 | metrics = metrics.reset_index() 50 | 51 | sel = ['rmse_mean','top_mean','R_mean','auc_mean','Rtop_mean','rmse_std','top_std','R_std','auc_std','Rtop_std','serial_min','serial_max'] 52 | if selected_params: 53 | sel += selected_params 54 | return metrics.loc[:,sel] 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser(description='Return aggregatedfor successful rows in database') 58 | parser.add_argument('--host',type=str,help='Database host',required=True) 59 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 60 | parser.add_argument('--db',type=str,help='Database name',default='opt1') 61 | parser.add_argument('--mingroup',type=int,help='required number of evaluations of a model for it to count',default=5) 62 | parser.add_argument('--priority',type=str,help='priority order of parameters',required=True,default="priority") 63 | parser.add_argument('-s','--selected_params',nargs='*',help='parameters whose values should be printed with metrics') 64 | 65 | args = parser.parse_args() 66 | 67 | metrics = getres(**vars(args)) 68 | print(metrics.to_csv(sep='\t',index_label='index')) 69 | 70 | 71 | -------------------------------------------------------------------------------- /affinity_search/incremental_addrequests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 
'''Checks an sql database to determine what jobs to run next for 4 | hyperparameter optimization. 5 | 6 | Works incrementally by taking a prioritized order for evaluating 7 | parameters. It is assumed the database is already populated with at least 8 | one good model. The best model is identified according to some metric 9 | (I'm thinking R, or maybe top*R - an average of identical models is taken). The parameters 10 | for this model become the defaults. Then, in priority order, the i'th parameter is considered 11 | We compoute the average metric for all evaluated models. We ignore models that don't 12 | have a required number of minimum evaluations 13 | We check if the metric has improved on the previous last best 14 | If it hasn't: 15 | All paramters > i are set so they only have defaults as options in the spearmint config 16 | Any result rows that do not match the defaults for parameters >i are omitted. 17 | We run spearmint to get new suggestions and add them to the database 18 | Otherwise 19 | We increment i and save information (best value at previous level 20 | 21 | Note this stores ongoing information in a file INCREMENTAL.info 22 | If this file doesn't exist we start from the beginning. 23 | 24 | 25 | ''' 26 | 27 | import sys, re, MySQLdb, argparse, os, json, subprocess 28 | import pandas as pd 29 | import makemodel 30 | import numpy as np 31 | from MySQLdb.cursors import DictCursor 32 | from outputjson import makejson 33 | from populaterequests import addrows 34 | 35 | def getcursor(): 36 | '''create a connection and return a cursor; 37 | doing this guards against dropped connections''' 38 | conn = MySQLdb.connect (host = args.host,user = "opter",passwd=args.password,db=args.db) 39 | conn.autocommit(True) 40 | cursor = conn.cursor(DictCursor) 41 | return cursor 42 | 43 | 44 | parser = argparse.ArgumentParser(description='Generate more configurations if needed') 45 | parser.add_argument('--host',type=str,help='Database host',required=True) 46 | parser.add_argument('-p','--password',type=str,help='Database password',required=True) 47 | parser.add_argument('--db',type=str,help='Database name',default='database') 48 | parser.add_argument('--pending_threshold',type=int,default=12,help='Number of pending jobs that triggers an update') 49 | parser.add_argument('-n','--num_configs',type=int,default=5,help='Number of configs to generate - will add a multiple as many jobs') 50 | parser.add_argument('-s','--spearmint',type=str,help='Location of spearmint-lite.py',required=True) 51 | parser.add_argument('--model_threshold',type=int,default=12,help='Number of unique models to evaluate at a level before giving up and going to the next level') 52 | parser.add_argument('--priority',type=file,help='priority order of parameters',required=True) 53 | parser.add_argument('--info',type=str,help='incremental information file',default='INCREMENTAL.info') 54 | parser.add_argument('--mingroup',type=int,help='required number of evaluations of a model for it to count',default=5) 55 | args = parser.parse_args() 56 | 57 | 58 | 59 | # first see how many id=REQUESTED jobs there are 60 | cursor = getcursor() 61 | cursor.execute('SELECT COUNT(*) FROM params WHERE id = "REQUESTED"') 62 | rows = cursor.fetchone() 63 | pending = list(rows.values())[0] 64 | 65 | #get options 66 | options = sorted(makemodel.getoptions().items()) 67 | 68 | #print "Pending jobs:",pending 69 | sys.stdout.write('%d '%pending) 70 | sys.stdout.flush() 71 | 72 | #if more than pending_threshold, quit 73 | if pending > args.pending_threshold: 
74 | sys.exit(0) 75 | 76 | #create gnina-spearmint directory if it doesn't exist already 77 | if not os.path.exists('gnina-spearmint-incremental'): 78 | os.makedirs('gnina-spearmint-incremental') 79 | 80 | #read in prioritized list of parameters 81 | params = args.priority.read().rstrip().split() 82 | 83 | #get data and compute average metric of each model 84 | cursor.execute('SELECT * FROM params') 85 | rows = cursor.fetchall() 86 | data = pd.DataFrame(list(rows)) 87 | #make errors zero - appropriate if error is due to parameters 88 | data.loc[data.id == 'ERROR','R'] = 0 89 | data.loc[data.id == 'ERROR','rmse'] = 0 90 | data.loc[data.id == 'ERROR','top'] = 0 91 | data.loc[data.id == 'ERROR','auc'] = 0 92 | 93 | nonan = data.dropna('index') 94 | 95 | grouped = nonan.groupby(params) 96 | metrics = grouped.mean()[['R','top']] 97 | metrics = metrics[grouped.size() >= args.mingroup] 98 | metrics['Rtop'] = metrics.R * metrics.top 99 | defaultparams = metrics['Rtop'].idxmax() #this is in priority order 100 | bestRtop = metrics['Rtop'].max() 101 | 102 | 103 | print("Best",bestRtop) 104 | #figure out what param we are on 105 | if os.path.exists(args.info): 106 | #info file has what iteration we are on and the previous best when we moved to that iteration 107 | (level, prevbest) = open(args.info).readlines()[-1].split() 108 | level = int(level) 109 | prevbest = float(prevbest) 110 | else: 111 | #very first time we've run 112 | level = 0 113 | prevbest = bestRtop 114 | info = open(args.info,'w') 115 | info.write('0 %f\n'%bestRtop) 116 | info.close() 117 | 118 | #check to see if we should promote level 119 | if bestRtop > prevbest*1.01: 120 | level += 1 121 | info = open(args.info,'a') 122 | info.write('%d %f\n' % (level,bestRtop)) 123 | info.close() 124 | 125 | 126 | try: #remove pickle file in case number of parameters has changed 127 | if level != 50: os.remove('gnina-spearmint-incremental/chooser.GPEIOptChooser.pkl') 128 | except: 129 | pass 130 | 131 | #create config.json without defaulted parameters 132 | config = makejson() 133 | defaults = dict() 134 | for (i,(name,value)) in enumerate(zip(params,defaultparams)): 135 | if i > level: 136 | defaults[name] = str(value) 137 | del config[name] 138 | 139 | cout = open('gnina-spearmint-incremental/config.json','w') 140 | cout.write(json.dumps(config, indent=4)+'\n') 141 | cout.close() 142 | 143 | #output results.data using top*R 144 | #don't use averages, since in theory spearmint will use the distribution intelligently 145 | #also include rows without values to avoid repetition 146 | 147 | resf = open('gnina-spearmint-incremental/results.dat','w') 148 | uniqconfigs = set() 149 | evalconfigs = set() 150 | validrows = 0 151 | for (i,row) in data.iterrows(): 152 | outrow = [] 153 | for (name,vals) in options: 154 | if name == 'resolution': 155 | val = str(float(row[name])) #gets returned as 1 instead of 1.0 156 | else: 157 | val = str(row[name]) 158 | 159 | if name in defaults: # is this row acceptable 160 | if type(row[name]) == float or type(row[name]) == int: 161 | if np.abs(float(defaults[name])-row[name]) > 0.00001: 162 | break 163 | elif row[name] != defaults[name]: 164 | break 165 | else: 166 | outrow.append(val) 167 | else: #execute if we didn't break 168 | validrows += 1 169 | uniqconfigs.add(tuple(outrow)) 170 | Rtop = row['R']*row['top'] 171 | if np.isfinite(Rtop): 172 | resf.write('%f 0 '% -Rtop) 173 | evalconfigs.add(tuple(outrow)) 174 | else: 175 | resf.write('P P ') 176 | #outrow is in opt order, but with defaults removed 177 | 
resf.write(' '.join(outrow)) 178 | resf.write('\n') 179 | resf.close() 180 | 181 | gseed = len(uniqconfigs) #not clear this actually makes sense in our context.. 182 | print("Uniq configs:",gseed) 183 | print("Evaled configs:",len(evalconfigs)) 184 | 185 | #a very generous threshold - multiply by level rather than keep track of number of uniq models in this level 186 | threshold = (level+1)*args.model_threshold 187 | if len(evalconfigs) > threshold: 188 | #promote level, although this will not effect this invocation 189 | level += 1 190 | info = open(args.info,'a') 191 | info.write('%d %f\n'%(level,bestRtop)) 192 | info.close() 193 | try: 194 | os.remove('gnina-spearmint-incremental/chooser.GPEIOptChooser.pkl') 195 | except: 196 | pass 197 | 198 | # run spearmint-light, set the seed to the number of unique configurations 199 | spearargs = ['python',args.spearmint, '--method=GPEIOptChooser', '--grid-size=20000', 200 | 'gnina-spearmint-incremental', '--n=%d'%args.num_configs, '--grid-seed=%d' % gseed] 201 | print(' '.join(spearargs)) 202 | 203 | subprocess.call(spearargs) 204 | #get the generated lines from the file 205 | lines = open('gnina-spearmint-incremental/results.dat').readlines() 206 | newlines = np.unique(lines[validrows:]) 207 | print(len(newlines),args.num_configs) 208 | assert(len(newlines) > 0) 209 | out = open('gnina-spearmint-incremental/newrows.dat','w') 210 | for line in newlines: 211 | vals = line.rstrip().split() 212 | pos = 2 213 | outrow = [vals[0],vals[1]] 214 | for (name,_) in options: 215 | if name in defaults: 216 | outrow.append(defaults[name]) 217 | else: #not defaults in opt order 218 | outrow.append(vals[pos]) 219 | pos += 1 220 | assert(pos == len(vals)) 221 | out.write(' '.join(outrow)) 222 | out.write('\n') 223 | print(' '.join(outrow)) 224 | out.close() 225 | #add to database as REQUESTED jobs 226 | 227 | addrows('gnina-spearmint-incremental/newrows.dat',args.host,args.db,args.password) 228 | 229 | -------------------------------------------------------------------------------- /affinity_search/makebesty.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys,re, collections 4 | '''convert the provided example file to only have a single positive affinity 5 | for the best rmsd example''' 6 | 7 | #first identify best rmsd 8 | bestval = dict() 9 | bestlig = dict() 10 | for line in open(sys.argv[1]): 11 | # 0 1a30/1a30_rec.gninatypes 1a30/1a30_ligand_0.gninatypes # 8.46937 -8.3175 12 | vals = line.rstrip().split() 13 | rmsd = float(vals[5]) 14 | rec = vals[2] 15 | if rec not in bestval or rmsd < bestval[rec]: 16 | bestval[rec] = rmsd 17 | bestlig[rec] = vals[3] 18 | 19 | for line in open(sys.argv[1]): 20 | vals = line.rstrip().split() 21 | rec = vals[2] 22 | if vals[3] == bestlig[rec] or float(vals[1]) < 0: 23 | print(line.rstrip()) 24 | else: 25 | print(vals[0],-float(vals[1]),' '.join(vals[2:])) 26 | 27 | -------------------------------------------------------------------------------- /affinity_search/makemodels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables: 6 | # RESOLUTION(0.5) 7 | # CONVKERNEL(3) 8 | # BALANCED(true) 9 | # POSEPREDICT 10 | # RECEPTOR(false) 11 | # AFFMIN(0) 12 | # AFFMAX(0) 13 | # AFFSTEP(0) 14 | # GAP(0) 15 | basemodel = '''layer { 16 | name: "data" 17 | type: "MolGridData" 18 | top: "data" 19 | top: "label" 20 | top: "affinity" 
21 | include { 22 | phase: TEST 23 | } 24 | molgrid_data_param { 25 | source: "TESTFILE" 26 | batch_size: 50 27 | dimension: 23.5 28 | resolution: RESOLUTION 29 | shuffle: false 30 | balanced: false 31 | has_affinity: true 32 | root_folder: "../../" 33 | } 34 | } 35 | layer { 36 | name: "data" 37 | type: "MolGridData" 38 | top: "data" 39 | top: "label" 40 | top: "affinity" 41 | include { 42 | phase: TRAIN 43 | } 44 | molgrid_data_param { 45 | source: "TRAINFILE" 46 | batch_size: 50 47 | dimension: 23.5 48 | resolution: RESOLUTION 49 | shuffle: true 50 | balanced: BALANCED 51 | stratify_receptor: RECEPTOR 52 | stratify_affinity_min: AFFMIN 53 | stratify_affinity_max: AFFMAX 54 | stratify_affinity_step: AFFSTEP 55 | has_affinity: true 56 | random_rotation: true 57 | random_translate: 2 58 | root_folder: "../.." 59 | } 60 | } 61 | 62 | layer { 63 | name: "unit1_pool" 64 | type: "Pooling" 65 | bottom: "data" 66 | top: "unit1_pool" 67 | pooling_param { 68 | pool: MAX 69 | kernel_size: 2 70 | stride: 2 71 | } 72 | } 73 | layer { 74 | name: "unit1_conv1" 75 | type: "Convolution" 76 | bottom: "unit1_pool" 77 | top: "unit1_conv1" 78 | convolution_param { 79 | num_output: 32 80 | pad: 1 81 | kernel_size: CONVKERNEL 82 | stride: 1 83 | weight_filler { 84 | type: "xavier" 85 | } 86 | } 87 | } 88 | layer { 89 | name: "unit1_relu1" 90 | type: "ReLU" 91 | bottom: "unit1_conv1" 92 | top: "unit1_conv1" 93 | } 94 | layer { 95 | name: "unit2_pool" 96 | type: "Pooling" 97 | bottom: "unit1_conv1" 98 | top: "unit2_pool" 99 | pooling_param { 100 | pool: MAX 101 | kernel_size: 2 102 | stride: 2 103 | } 104 | } 105 | layer { 106 | name: "unit2_conv1" 107 | type: "Convolution" 108 | bottom: "unit2_pool" 109 | top: "unit2_conv1" 110 | convolution_param { 111 | num_output: 64 112 | pad: 1 113 | kernel_size: CONVKERNEL 114 | stride: 1 115 | weight_filler { 116 | type: "xavier" 117 | } 118 | } 119 | } 120 | layer { 121 | name: "unit2_relu1" 122 | type: "ReLU" 123 | bottom: "unit2_conv1" 124 | top: "unit2_conv1" 125 | } 126 | layer { 127 | name: "unit3_pool" 128 | type: "Pooling" 129 | bottom: "unit2_conv1" 130 | top: "unit3_pool" 131 | pooling_param { 132 | pool: MAX 133 | kernel_size: 2 134 | stride: 2 135 | } 136 | } 137 | layer { 138 | name: "unit3_conv1" 139 | type: "Convolution" 140 | bottom: "unit3_pool" 141 | top: "unit3_conv1" 142 | convolution_param { 143 | num_output: 128 144 | pad: 1 145 | kernel_size: CONVKERNEL 146 | stride: 1 147 | weight_filler { 148 | type: "xavier" 149 | } 150 | } 151 | } 152 | layer { 153 | name: "unit3_relu1" 154 | type: "ReLU" 155 | bottom: "unit3_conv1" 156 | top: "unit3_conv1" 157 | } 158 | 159 | layer { 160 | name: "split" 161 | type: "Split" 162 | bottom: "unit3_conv1" 163 | top: "split" 164 | } 165 | 166 | POSEPREDICT 167 | 168 | layer { 169 | name: "output_fc_aff" 170 | type: "InnerProduct" 171 | bottom: "split" 172 | top: "output_fc_aff" 173 | inner_product_param { 174 | num_output: 1 175 | weight_filler { 176 | type: "xavier" 177 | } 178 | } 179 | } 180 | 181 | layer { 182 | name: "rmsd" 183 | type: "AffinityLoss" 184 | bottom: "output_fc_aff" 185 | bottom: "affinity" 186 | top: "rmsd" 187 | affinity_loss_param { 188 | scale: 0.1 189 | gap: GAP 190 | } 191 | } 192 | 193 | layer { 194 | name: "predaff" 195 | type: "Flatten" 196 | bottom: "output_fc_aff" 197 | top: "predaff" 198 | } 199 | 200 | layer { 201 | name: "affout" 202 | type: "Split" 203 | bottom: "affinity" 204 | top: "affout" 205 | include { 206 | phase: TEST 207 | } 208 | } 209 | ''' 210 | posepredict='''layer { 
211 | name: "output_fc" 212 | type: "InnerProduct" 213 | bottom: "split" 214 | top: "output_fc" 215 | inner_product_param { 216 | num_output: 2 217 | weight_filler { 218 | type: "xavier" 219 | } 220 | } 221 | } 222 | layer { 223 | name: "loss" 224 | type: "SoftmaxWithLoss" 225 | bottom: "output_fc" 226 | bottom: "label" 227 | top: "loss" 228 | } 229 | 230 | layer { 231 | name: "output" 232 | type: "Softmax" 233 | bottom: "output_fc" 234 | top: "output" 235 | } 236 | layer { 237 | name: "labelout" 238 | type: "Split" 239 | bottom: "label" 240 | top: "labelout" 241 | include { 242 | phase: TEST 243 | } 244 | } 245 | ''' 246 | 247 | def makemodel(**kwargs): 248 | m = basemodel 249 | for (k,v) in kwargs.items(): 250 | m = m.replace(k,str(v)) 251 | return m 252 | 253 | 254 | 255 | conv = 3 256 | resolution = 0.5 257 | 258 | models = [] 259 | for gap in [0,1,2]: 260 | for pose in ['', posepredict]: 261 | for balanced in ['true','false']: 262 | for receptor in ['true','false']: 263 | for affstrat in [True,False]: 264 | if affstrat: 265 | amin = 2 #first group will be < 3 266 | amax = 10 #last bin will be > 9 267 | astep = 1 268 | else: 269 | amin = amax = astep = 0 270 | 271 | model = makemodel(GAP=gap,POSEPREDICT=pose,RESOLUTION=resolution, CONVKERNEL=conv, RECEPTOR=receptor,AFFMIN=amin,AFFMAX=amax,AFFSTEP=astep,BALANCED=balanced) 272 | m = 'affinity_g%d_p%d_rec%d_astrat%d_b%d.model'%(gap,len(pose)>0,receptor=='true',affstrat,balanced=='true') 273 | models.append(m) 274 | out = open(m,'w') 275 | out.write(model) 276 | 277 | 278 | unbalanced = set(['bestonly','crystal','posonly']) 279 | single = set(['bestonly','crystal']) 280 | for i in ['all','besty','posonly','crystal','bestonly']: 281 | #some around valid for the model - assume we die quickly? 282 | for m in models: 283 | if i in unbalanced: 284 | if'_b1' in m: continue 285 | if '_p1' in m: continue 286 | else: #balanced 287 | if '_b0' in m: continue 288 | if i in single: 289 | if '_rec1' in m: continue #only one per receptor, not much point 290 | 291 | print("train.py -m %s -p %s_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o %s_%s"%(m,i,i,m.replace('.model',''))) 292 | -------------------------------------------------------------------------------- /affinity_search/makemodels1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables: 6 | # BALANCED(true) 7 | # POSEPREDICT 8 | # RECEPTOR(false) 9 | # AFFMIN(0) 10 | # AFFMAX(0) 11 | # AFFSTEP(0) 12 | basemodel = '''layer { 13 | name: "data" 14 | type: "MolGridData" 15 | top: "data" 16 | top: "label" 17 | top: "affinity" 18 | include { 19 | phase: TEST 20 | } 21 | molgrid_data_param { 22 | source: "TESTFILE" 23 | batch_size: 50 24 | dimension: 23.5 25 | resolution: 0.5 26 | shuffle: false 27 | balanced: false 28 | has_affinity: true 29 | root_folder: "../.." 30 | } 31 | } 32 | layer { 33 | name: "data" 34 | type: "MolGridData" 35 | top: "data" 36 | top: "label" 37 | top: "affinity" 38 | include { 39 | phase: TRAIN 40 | } 41 | molgrid_data_param { 42 | source: "TRAINFILE" 43 | batch_size: 50 44 | dimension: 23.5 45 | resolution: 0.5 46 | shuffle: true 47 | balanced: BALANCED 48 | stratify_receptor: RECEPTOR 49 | stratify_affinity_min: AFFMIN 50 | stratify_affinity_max: AFFMAX 51 | stratify_affinity_step: AFFSTEP 52 | has_affinity: true 53 | random_rotation: true 54 | random_translate: 2 55 | root_folder: "../.." 
56 | } 57 | } 58 | 59 | layer { 60 | name: "unit1_pool" 61 | type: "Pooling" 62 | bottom: "data" 63 | top: "unit1_pool" 64 | pooling_param { 65 | pool: MAX 66 | kernel_size: 2 67 | stride: 2 68 | } 69 | } 70 | layer { 71 | name: "unit1_conv1" 72 | type: "Convolution" 73 | bottom: "unit1_pool" 74 | top: "unit1_conv1" 75 | convolution_param { 76 | num_output: 32 77 | pad: 1 78 | kernel_size: 3 79 | stride: 1 80 | weight_filler { 81 | type: "xavier" 82 | } 83 | } 84 | } 85 | layer { 86 | name: "unit1_relu1" 87 | type: "ReLU" 88 | bottom: "unit1_conv1" 89 | top: "unit1_conv1" 90 | } 91 | layer { 92 | name: "unit2_pool" 93 | type: "Pooling" 94 | bottom: "unit1_conv1" 95 | top: "unit2_pool" 96 | pooling_param { 97 | pool: MAX 98 | kernel_size: 2 99 | stride: 2 100 | } 101 | } 102 | layer { 103 | name: "unit2_conv1" 104 | type: "Convolution" 105 | bottom: "unit2_pool" 106 | top: "unit2_conv1" 107 | convolution_param { 108 | num_output: 64 109 | pad: 1 110 | kernel_size: 3 111 | stride: 1 112 | weight_filler { 113 | type: "xavier" 114 | } 115 | } 116 | } 117 | layer { 118 | name: "unit2_relu1" 119 | type: "ReLU" 120 | bottom: "unit2_conv1" 121 | top: "unit2_conv1" 122 | } 123 | layer { 124 | name: "unit3_pool" 125 | type: "Pooling" 126 | bottom: "unit2_conv1" 127 | top: "unit3_pool" 128 | pooling_param { 129 | pool: MAX 130 | kernel_size: 2 131 | stride: 2 132 | } 133 | } 134 | layer { 135 | name: "unit3_conv1" 136 | type: "Convolution" 137 | bottom: "unit3_pool" 138 | top: "unit3_conv1" 139 | convolution_param { 140 | num_output: 128 141 | pad: 1 142 | kernel_size: 3 143 | stride: 1 144 | weight_filler { 145 | type: "xavier" 146 | } 147 | } 148 | } 149 | layer { 150 | name: "unit3_relu1" 151 | type: "ReLU" 152 | bottom: "unit3_conv1" 153 | top: "unit3_conv1" 154 | } 155 | 156 | layer { 157 | name: "split" 158 | type: "Split" 159 | bottom: "unit3_conv1" 160 | top: "split" 161 | } 162 | 163 | POSEPREDICT 164 | 165 | layer { 166 | name: "output_fc_aff" 167 | type: "InnerProduct" 168 | bottom: "split" 169 | top: "output_fc_aff" 170 | inner_product_param { 171 | num_output: 1 172 | weight_filler { 173 | type: "xavier" 174 | } 175 | } 176 | } 177 | 178 | layer { 179 | name: "rmsd" 180 | type: "AffinityLoss" 181 | bottom: "output_fc_aff" 182 | bottom: "affinity" 183 | top: "rmsd" 184 | affinity_loss_param { 185 | scale: 0.1 186 | gap: 0 187 | } 188 | } 189 | 190 | layer { 191 | name: "predaff" 192 | type: "Flatten" 193 | bottom: "output_fc_aff" 194 | top: "predaff" 195 | } 196 | 197 | layer { 198 | name: "affout" 199 | type: "Split" 200 | bottom: "affinity" 201 | top: "affout" 202 | include { 203 | phase: TEST 204 | } 205 | } 206 | ''' 207 | posepredict='''layer { 208 | name: "output_fc" 209 | type: "InnerProduct" 210 | bottom: "split" 211 | top: "output_fc" 212 | inner_product_param { 213 | num_output: 2 214 | weight_filler { 215 | type: "xavier" 216 | } 217 | } 218 | } 219 | layer { 220 | name: "loss" 221 | type: "SoftmaxWithLoss" 222 | bottom: "output_fc" 223 | bottom: "label" 224 | top: "loss" 225 | } 226 | 227 | layer { 228 | name: "output" 229 | type: "Softmax" 230 | bottom: "output_fc" 231 | top: "output" 232 | } 233 | layer { 234 | name: "labelout" 235 | type: "Split" 236 | bottom: "label" 237 | top: "labelout" 238 | include { 239 | phase: TEST 240 | } 241 | } 242 | ''' 243 | 244 | def makemodel(**kwargs): 245 | m = basemodel 246 | for (k,v) in kwargs.items(): 247 | m = m.replace(k,str(v)) 248 | return m 249 | 250 | 251 | 252 | conv = 3 253 | resolution = 0.5 254 | 255 | models = [] 256 | for 
pose in ['', posepredict]: 257 | for balanced in ['true','false']: 258 | for receptor in ['true','false']: 259 | for affstrat in [True,False]: 260 | if affstrat: 261 | amin = 2 #first group will be < 3 262 | amax = 10 #last bin will be > 9 263 | astep = 1 264 | else: 265 | amin = amax = astep = 0 266 | 267 | model = makemodel(POSEPREDICT=pose, RECEPTOR=receptor,AFFMIN=amin,AFFMAX=amax,AFFSTEP=astep,BALANCED=balanced) 268 | m = 'affinity_p%d_rec%d_astrat%d_b%d.model'%(len(pose)>0,receptor=='true',affstrat,balanced=='true') 269 | models.append(m) 270 | out = open(m,'w') 271 | out.write(model) 272 | 273 | 274 | unbalanced = set(['bestonly','crystal','posonly']) 275 | single = set(['bestonly','crystal']) 276 | for i in ['all','besty','posonly','crystal','bestonly']: 277 | #some around valid for the model - assume we die quickly? 278 | for m in models: 279 | if i in unbalanced: 280 | if'_b1' in m: continue 281 | if '_p1' in m: continue 282 | else: #balanced 283 | if '_b0' in m: continue 284 | if i in single: 285 | if '_rec1' in m: continue #only one per receptor, not much point 286 | 287 | print("train.py -m %s -p ../types/%s_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o %s_%s"%(m,i,i,m.replace('.model',''))) 288 | -------------------------------------------------------------------------------- /affinity_search/makemodels2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables: 6 | 7 | 8 | # GAP(0) 9 | # PENALTY (0) 10 | # HUBER (false) 11 | # DELTA (0) 12 | # RANKLOSS (0) 13 | basemodel = '''layer { 14 | name: "data" 15 | type: "MolGridData" 16 | top: "data" 17 | top: "label" 18 | top: "affinity" 19 | include { 20 | phase: TEST 21 | } 22 | molgrid_data_param { 23 | source: "TESTFILE" 24 | batch_size: 50 25 | dimension: 23.5 26 | resolution: 0.5 27 | shuffle: false 28 | balanced: false 29 | has_affinity: true 30 | root_folder: "../../" 31 | } 32 | } 33 | layer { 34 | name: "data" 35 | type: "MolGridData" 36 | top: "data" 37 | top: "label" 38 | top: "affinity" 39 | include { 40 | phase: TRAIN 41 | } 42 | molgrid_data_param { 43 | source: "TRAINFILE" 44 | batch_size: 50 45 | dimension: 23.5 46 | resolution: 0.5 47 | shuffle: true 48 | balanced: true 49 | stratify_receptor: true 50 | stratify_affinity_min: 0 51 | stratify_affinity_max: 0 52 | stratify_affinity_step: 0 53 | has_affinity: true 54 | random_rotation: true 55 | random_translate: 2 56 | root_folder: "../../" 57 | } 58 | } 59 | 60 | layer { 61 | name: "unit1_pool" 62 | type: "Pooling" 63 | bottom: "data" 64 | top: "unit1_pool" 65 | pooling_param { 66 | pool: MAX 67 | kernel_size: 2 68 | stride: 2 69 | } 70 | } 71 | layer { 72 | name: "unit1_conv1" 73 | type: "Convolution" 74 | bottom: "unit1_pool" 75 | top: "unit1_conv1" 76 | convolution_param { 77 | num_output: 32 78 | pad: 1 79 | kernel_size: 3 80 | stride: 1 81 | weight_filler { 82 | type: "xavier" 83 | } 84 | } 85 | } 86 | layer { 87 | name: "unit1_relu1" 88 | type: "ReLU" 89 | bottom: "unit1_conv1" 90 | top: "unit1_conv1" 91 | } 92 | layer { 93 | name: "unit2_pool" 94 | type: "Pooling" 95 | bottom: "unit1_conv1" 96 | top: "unit2_pool" 97 | pooling_param { 98 | pool: MAX 99 | kernel_size: 2 100 | stride: 2 101 | } 102 | } 103 | layer { 104 | name: "unit2_conv1" 105 | type: "Convolution" 106 | bottom: "unit2_pool" 107 | top: "unit2_conv1" 108 | convolution_param { 109 | num_output: 64 110 | pad: 1 111 | kernel_size: 3 112 | stride: 1 113 | 
weight_filler { 114 | type: "xavier" 115 | } 116 | } 117 | } 118 | layer { 119 | name: "unit2_relu1" 120 | type: "ReLU" 121 | bottom: "unit2_conv1" 122 | top: "unit2_conv1" 123 | } 124 | layer { 125 | name: "unit3_pool" 126 | type: "Pooling" 127 | bottom: "unit2_conv1" 128 | top: "unit3_pool" 129 | pooling_param { 130 | pool: MAX 131 | kernel_size: 2 132 | stride: 2 133 | } 134 | } 135 | layer { 136 | name: "unit3_conv1" 137 | type: "Convolution" 138 | bottom: "unit3_pool" 139 | top: "unit3_conv1" 140 | convolution_param { 141 | num_output: 128 142 | pad: 1 143 | kernel_size: 3 144 | stride: 1 145 | weight_filler { 146 | type: "xavier" 147 | } 148 | } 149 | } 150 | layer { 151 | name: "unit3_relu1" 152 | type: "ReLU" 153 | bottom: "unit3_conv1" 154 | top: "unit3_conv1" 155 | } 156 | 157 | layer { 158 | name: "split" 159 | type: "Split" 160 | bottom: "unit3_conv1" 161 | top: "split" 162 | } 163 | 164 | layer { 165 | name: "output_fc" 166 | type: "InnerProduct" 167 | bottom: "split" 168 | top: "output_fc" 169 | inner_product_param { 170 | num_output: 2 171 | weight_filler { 172 | type: "xavier" 173 | } 174 | } 175 | } 176 | layer { 177 | name: "loss" 178 | type: "SoftmaxWithLoss" 179 | bottom: "output_fc" 180 | bottom: "label" 181 | top: "loss" 182 | } 183 | 184 | layer { 185 | name: "output" 186 | type: "Softmax" 187 | bottom: "output_fc" 188 | top: "output" 189 | } 190 | layer { 191 | name: "labelout" 192 | type: "Split" 193 | bottom: "label" 194 | top: "labelout" 195 | include { 196 | phase: TEST 197 | } 198 | } 199 | 200 | layer { 201 | name: "output_fc_aff" 202 | type: "InnerProduct" 203 | bottom: "split" 204 | top: "output_fc_aff" 205 | inner_product_param { 206 | num_output: 1 207 | weight_filler { 208 | type: "xavier" 209 | } 210 | } 211 | } 212 | 213 | layer { 214 | name: "rmsd" 215 | type: "AffinityLoss" 216 | bottom: "output_fc_aff" 217 | bottom: "affinity" 218 | top: "rmsd" 219 | affinity_loss_param { 220 | scale: 0.1 221 | gap: GAP 222 | penalty: PENALTY 223 | pseudohuber: HUBER 224 | delta: DELTA 225 | } 226 | } 227 | 228 | layer { 229 | name: "predaff" 230 | type: "Flatten" 231 | bottom: "output_fc_aff" 232 | top: "predaff" 233 | } 234 | 235 | layer { 236 | name: "affout" 237 | type: "Split" 238 | bottom: "affinity" 239 | top: "affout" 240 | include { 241 | phase: TEST 242 | } 243 | } 244 | ''' 245 | 246 | 247 | def makemodel(**kwargs): 248 | m = basemodel 249 | for (k,v) in kwargs.items(): 250 | m = m.replace(k,str(v)) 251 | return m 252 | 253 | 254 | 255 | # GAP(0) 256 | # PENALTY (0) 257 | # HUBER (false) 258 | # DELTA (0) 259 | # RANKLOSS (0) 260 | 261 | models = [] 262 | for gap in [0,1,2]: 263 | for penalty in [0,1,2,4]: 264 | for delta in [0,1,2,4,6]: 265 | if delta == 0: huber = "false" 266 | else: huber = "true" 267 | model = makemodel(GAP=gap,PENALTY=penalty,HUBER=huber, DELTA=delta) 268 | m = 'affinity_g%d_p%d_h%d.model'%(gap,penalty, delta) 269 | models.append(m) 270 | out = open(m,'w') 271 | out.write(model) 272 | 273 | for m in models: 274 | print("train.py -m %s -p ../types/all_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o all_%s"%(m,m.replace('.model',''))) 275 | -------------------------------------------------------------------------------- /affinity_search/makemodels3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables: 6 | # non-linearity: ReLU, leaky, ELU, Sigmoid, TanH (PReLU not current ndim compat) 7 | # 
normalization: none, LRN, Batch 8 | # learning rate: 0.01, 0.001, 0.1 9 | 10 | modelstart = '''layer { 11 | name: "data" 12 | type: "MolGridData" 13 | top: "data" 14 | top: "label" 15 | top: "affinity" 16 | include { 17 | phase: TEST 18 | } 19 | molgrid_data_param { 20 | source: "TESTFILE" 21 | batch_size: 50 22 | dimension: 23.5 23 | resolution: 0.5 24 | shuffle: false 25 | balanced: false 26 | has_affinity: true 27 | root_folder: "../../" 28 | } 29 | } 30 | layer { 31 | name: "data" 32 | type: "MolGridData" 33 | top: "data" 34 | top: "label" 35 | top: "affinity" 36 | include { 37 | phase: TRAIN 38 | } 39 | molgrid_data_param { 40 | source: "TRAINFILE" 41 | batch_size: 50 42 | dimension: 23.5 43 | resolution: 0.5 44 | shuffle: true 45 | balanced: true 46 | stratify_receptor: true 47 | stratify_affinity_min: 0 48 | stratify_affinity_max: 0 49 | stratify_affinity_step: 0 50 | has_affinity: true 51 | random_rotation: true 52 | random_translate: 2 53 | root_folder: "../../" 54 | } 55 | } 56 | ''' 57 | 58 | endmodel = '''layer { 59 | name: "split" 60 | type: "Split" 61 | bottom: "LASTCONV" 62 | top: "split" 63 | } 64 | 65 | layer { 66 | name: "output_fc" 67 | type: "InnerProduct" 68 | bottom: "split" 69 | top: "output_fc" 70 | inner_product_param { 71 | num_output: 2 72 | weight_filler { 73 | type: "xavier" 74 | } 75 | } 76 | } 77 | layer { 78 | name: "loss" 79 | type: "SoftmaxWithLoss" 80 | bottom: "output_fc" 81 | bottom: "label" 82 | top: "loss" 83 | } 84 | 85 | layer { 86 | name: "output" 87 | type: "Softmax" 88 | bottom: "output_fc" 89 | top: "output" 90 | } 91 | layer { 92 | name: "labelout" 93 | type: "Split" 94 | bottom: "label" 95 | top: "labelout" 96 | include { 97 | phase: TEST 98 | } 99 | } 100 | 101 | layer { 102 | name: "output_fc_aff" 103 | type: "InnerProduct" 104 | bottom: "split" 105 | top: "output_fc_aff" 106 | inner_product_param { 107 | num_output: 1 108 | weight_filler { 109 | type: "xavier" 110 | } 111 | } 112 | } 113 | 114 | layer { 115 | name: "rmsd" 116 | type: "AffinityLoss" 117 | bottom: "output_fc_aff" 118 | bottom: "affinity" 119 | top: "rmsd" 120 | affinity_loss_param { 121 | scale: 0.1 122 | gap: 1 123 | penalty: 0 124 | pseudohuber: false 125 | delta: 0 126 | } 127 | } 128 | 129 | layer { 130 | name: "predaff" 131 | type: "Flatten" 132 | bottom: "output_fc_aff" 133 | top: "predaff" 134 | } 135 | 136 | layer { 137 | name: "affout" 138 | type: "Split" 139 | bottom: "affinity" 140 | top: "affout" 141 | include { 142 | phase: TEST 143 | } 144 | } 145 | 146 | ''' 147 | 148 | convunit = ''' 149 | layer { 150 | name: "unitNUMBER_pool" 151 | type: "Pooling" 152 | bottom: "INLAYER" 153 | top: "unitNUMBER_pool" 154 | pooling_param { 155 | pool: MAX 156 | kernel_size: 2 157 | stride: 2 158 | } 159 | } 160 | layer { 161 | name: "unitNUMBER_conv1" 162 | type: "Convolution" 163 | bottom: "unitNUMBER_pool" 164 | top: "unitNUMBER_conv1" 165 | convolution_param { 166 | num_output: 32 167 | pad: 1 168 | kernel_size: 3 169 | stride: 1 170 | weight_filler { 171 | type: "xavier" 172 | } 173 | } 174 | }''' 175 | 176 | 177 | norms = { 178 | 'none': '', 179 | 'batch': '''layer { 180 | name: "unitNUMBER_norm" 181 | type: "BatchNorm" 182 | bottom: "unitNUMBER_conv1" 183 | top: "unitNUMBER_conv1" 184 | } 185 | 186 | layer { 187 | name: "unitNUMBER_scale" 188 | type: "Scale" 189 | bottom: "unitNUMBER_conv1" 190 | top: "unitNUMBER_conv1" 191 | scale_param { 192 | bias_term: true 193 | } 194 | } 195 | ''', 196 | 'lrn': '''layer { 197 | name: "unitNUMBER_norm" 198 | type: "LRN" 199 | bottom: 
"unitNUMBER_conv1" 200 | top: "unitNUMBER_conv1" 201 | } 202 | 203 | layer { 204 | name: "unitNUMBER_scale" 205 | type: "Scale" 206 | bottom: "unitNUMBER_conv1" 207 | top: "unitNUMBER_conv1" 208 | scale_param { 209 | bias_term: true 210 | } 211 | } 212 | ''' 213 | } 214 | 215 | relus = { 216 | 'relu': '''layer { 217 | name: "unitNUMBER_func" 218 | type: "ReLU" 219 | bottom: "unitNUMBER_conv1" 220 | top: "unitNUMBER_conv1" 221 | }''', 222 | 'leaky': '''layer { 223 | name: "unitNUMBER_func" 224 | type: "ReLU" 225 | bottom: "unitNUMBER_conv1" 226 | top: "unitNUMBER_conv1" 227 | relu_param{ 228 | negative_slope: 0.01 229 | } 230 | }''', 231 | 'elu':'''layer { 232 | name: "unitNUMBER_func" 233 | type: "ELU" 234 | bottom: "unitNUMBER_conv1" 235 | top: "unitNUMBER_conv1" 236 | }''', 237 | 'sigmoid':'''layer { 238 | name: "unitNUMBER_func" 239 | type: "Sigmoid" 240 | bottom: "unitNUMBER_conv1" 241 | top: "unitNUMBER_conv1" 242 | }''', 243 | 'tanh':'''layer { 244 | name: "unitNUMBER_func" 245 | type: "TanH" 246 | bottom: "unitNUMBER_conv1" 247 | top: "unitNUMBER_conv1" 248 | }''' 249 | } 250 | 251 | # normalization: none, LRN (across and within), Batch 252 | # learning rat 253 | def create_unit(num, norm, func): 254 | 255 | ret = convunit.replace('NUMBER', str(num)) 256 | if num == 1: 257 | ret = ret.replace('INLAYER','data') 258 | else: 259 | ret = ret.replace('INLAYER', 'unit%d_conv1'%(num-1)) 260 | ret += norms[norm].replace('NUMBER', str(num)) 261 | ret += relus[func].replace('NUMBER', str(num)) 262 | return ret 263 | 264 | 265 | def makemodel(norm, func): 266 | m = modelstart 267 | for i in [1,2,3]: 268 | m += create_unit(i, norm, func) 269 | m += endmodel.replace('LASTCONV','unit3_conv1') 270 | 271 | return m 272 | 273 | 274 | models = [] 275 | for norm in sorted(norms.keys()): 276 | for func in sorted(relus.keys()): 277 | model = makemodel(norm, func) 278 | m = 'affinity_%s_%s.model'%(norm,func) 279 | models.append(m) 280 | out = open(m,'w') 281 | out.write(model) 282 | 283 | for m in models: 284 | for lr in [0.001, 0.01, 0.1]: 285 | print("train.py -m %s -p ../types/all_0.5_0_ --base_lr %f --keep_best -t 1000 -i 100000 --reduced -o all_%s_lr%.3f"%(m,lr,m.replace('.model',''),lr)) 286 | -------------------------------------------------------------------------------- /affinity_search/makemodels4.5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # initialization: gaussian, positive_unitball, uniform, xavier, msra, radial 6 | # kernel size: 3, 5 7 | 8 | modelstart = '''layer { 9 | name: "data" 10 | type: "MolGridData" 11 | top: "data" 12 | top: "label" 13 | top: "affinity" 14 | include { 15 | phase: TEST 16 | } 17 | molgrid_data_param { 18 | source: "TESTFILE" 19 | batch_size: 10 20 | dimension: 23.5 21 | resolution: 0.5 22 | shuffle: false 23 | balanced: false 24 | has_affinity: true 25 | root_folder: "../../" 26 | } 27 | } 28 | layer { 29 | name: "data" 30 | type: "MolGridData" 31 | top: "data" 32 | top: "label" 33 | top: "affinity" 34 | include { 35 | phase: TRAIN 36 | } 37 | molgrid_data_param { 38 | source: "TRAINFILE" 39 | batch_size: 50 40 | dimension: 23.5 41 | resolution: 0.5 42 | shuffle: true 43 | balanced: true 44 | stratify_receptor: true 45 | stratify_affinity_min: 0 46 | stratify_affinity_max: 0 47 | stratify_affinity_step: 0 48 | has_affinity: true 49 | random_rotation: true 50 | random_translate: 2 51 | root_folder: "../../" 52 | } 53 | } 54 | ''' 55 | 
56 | endmodel = '''layer { 57 | name: "split" 58 | type: "Split" 59 | bottom: "LASTCONV" 60 | top: "split" 61 | } 62 | 63 | layer { 64 | name: "output_fc" 65 | type: "InnerProduct" 66 | bottom: "split" 67 | top: "output_fc" 68 | inner_product_param { 69 | num_output: 2 70 | weight_filler { 71 | type: "xavier" 72 | } 73 | } 74 | } 75 | layer { 76 | name: "loss" 77 | type: "SoftmaxWithLoss" 78 | bottom: "output_fc" 79 | bottom: "label" 80 | top: "loss" 81 | } 82 | 83 | layer { 84 | name: "output" 85 | type: "Softmax" 86 | bottom: "output_fc" 87 | top: "output" 88 | } 89 | layer { 90 | name: "labelout" 91 | type: "Split" 92 | bottom: "label" 93 | top: "labelout" 94 | include { 95 | phase: TEST 96 | } 97 | } 98 | 99 | layer { 100 | name: "output_fc_aff" 101 | type: "InnerProduct" 102 | bottom: "split" 103 | top: "output_fc_aff" 104 | inner_product_param { 105 | num_output: 1 106 | weight_filler { 107 | type: "xavier" 108 | } 109 | } 110 | } 111 | 112 | layer { 113 | name: "rmsd" 114 | type: "AffinityLoss" 115 | bottom: "output_fc_aff" 116 | bottom: "affinity" 117 | top: "rmsd" 118 | affinity_loss_param { 119 | scale: 0.1 120 | gap: 1 121 | penalty: 0 122 | pseudohuber: false 123 | delta: 0 124 | } 125 | } 126 | 127 | layer { 128 | name: "predaff" 129 | type: "Flatten" 130 | bottom: "output_fc_aff" 131 | top: "predaff" 132 | } 133 | 134 | layer { 135 | name: "affout" 136 | type: "Split" 137 | bottom: "affinity" 138 | top: "affout" 139 | include { 140 | phase: TEST 141 | } 142 | } 143 | 144 | ''' 145 | 146 | convunit = ''' 147 | layer { 148 | name: "unitNUMBER_pool" 149 | type: "Pooling" 150 | bottom: "INLAYER" 151 | top: "unitNUMBER_pool" 152 | pooling_param { 153 | pool: MAX 154 | kernel_size: 2 155 | stride: 2 156 | } 157 | } 158 | layer { 159 | name: "unitNUMBER_conv1" 160 | type: "Convolution" 161 | bottom: "unitNUMBER_pool" 162 | top: "unitNUMBER_conv1" 163 | convolution_param { 164 | num_output: OUTPUT 165 | pad: PAD 166 | kernel_size: KSIZE 167 | stride: 1 168 | weight_filler { 169 | type: "FILLER" 170 | symmetric_fraction: FRACTION 171 | } 172 | } 173 | }''' 174 | 175 | 176 | 177 | finishunit = ''' 178 | layer { 179 | name: "unitNUMBER_norm" 180 | type: "LRN" 181 | bottom: "unitNUMBER_conv1" 182 | top: "unitNUMBER_conv1" 183 | } 184 | 185 | layer { 186 | name: "unitNUMBER_scale" 187 | type: "Scale" 188 | bottom: "unitNUMBER_conv1" 189 | top: "unitNUMBER_conv1" 190 | scale_param { 191 | bias_term: true 192 | } 193 | } 194 | layer { 195 | name: "unitNUMBER_func" 196 | type: "ELU" 197 | bottom: "unitNUMBER_conv1" 198 | top: "unitNUMBER_conv1" 199 | } 200 | '''; 201 | 202 | # normalization: none, LRN (across and within), Batch 203 | # learning rat 204 | # depth 3, width 32 (doubled) 205 | def create_unit(num, ksize, filler, fraction): 206 | width = 32 207 | double = True 208 | ret = convunit.replace('NUMBER', str(num)) 209 | if num == 1: 210 | ret = ret.replace('INLAYER','data') 211 | else: 212 | ret = ret.replace('INLAYER', 'unit%d_conv1'%(num-1)) 213 | 214 | if num == 4: 215 | ksize = 3 #only 3x3 at this point 216 | 217 | pad = int(ksize/2) 218 | ret = ret.replace('PAD',str(pad)) 219 | ret = ret.replace('KSIZE', str(ksize)) 220 | outsize = width 221 | if double: 222 | outsize *= 2**(num-1) 223 | ret = ret.replace('OUTPUT', str(outsize)) 224 | ret = ret.replace('FILLER', filler) 225 | ret = ret.replace('FRACTION', str(fraction)) 226 | 227 | ret += finishunit.replace('NUMBER', str(num)) 228 | return ret 229 | 230 | 231 | def makemodel(ksize, filler, fraction): 232 | m = modelstart 233 | 
depth = 3 234 | for i in range(1,depth+1): 235 | m += create_unit(i, ksize, filler, fraction) 236 | m += endmodel.replace('LASTCONV','unit%d_conv1'%depth) 237 | 238 | return m 239 | 240 | 241 | models = [] 242 | 243 | for filler in ['radial','gaussian', 'positive_unitball', 'uniform', 'xavier', 'msra','radial.5']: 244 | fraction = 1.0 245 | if filler == 'radial.5': 246 | filler = 'radial' 247 | fraction = 0.5 248 | for ksize in [5,3]: 249 | model = makemodel(ksize, filler, fraction) 250 | m = 'affinity_%s_%.1f_%d.model'%(filler,fraction,ksize) 251 | models.append(m) 252 | out = open(m,'w') 253 | out.write(model) 254 | 255 | 256 | for m in models: 257 | print("train.py -m %s -p ../types/all_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o all_%s"%(m,m.replace('.model',''))) 258 | -------------------------------------------------------------------------------- /affinity_search/makemodels4.6.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # [SGD|Adam] * [regular|rankloss|ranklosswneg] [xaviar|radial] [0.01|0.001] 6 | 7 | modelstart = '''layer { 8 | name: "data" 9 | type: "MolGridData" 10 | top: "data" 11 | top: "label" 12 | top: "affinity" 13 | include { 14 | phase: TEST 15 | } 16 | molgrid_data_param { 17 | source: "TESTFILE" 18 | batch_size: 10 19 | dimension: 23.5 20 | resolution: 0.5 21 | shuffle: false 22 | balanced: false 23 | has_affinity: true 24 | root_folder: "../../" 25 | } 26 | } 27 | layer { 28 | name: "data" 29 | type: "MolGridData" 30 | top: "data" 31 | top: "label" 32 | top: "affinity" 33 | include { 34 | phase: TRAIN 35 | } 36 | molgrid_data_param { 37 | source: "TRAINFILE" 38 | batch_size: 50 39 | dimension: 23.5 40 | resolution: 0.5 41 | shuffle: true 42 | balanced: true 43 | stratify_receptor: true 44 | stratify_affinity_min: 0 45 | stratify_affinity_max: 0 46 | stratify_affinity_step: 0 47 | has_affinity: true 48 | random_rotation: true 49 | random_translate: 2 50 | root_folder: "../../" 51 | } 52 | } 53 | ''' 54 | 55 | endmodel = '''layer { 56 | name: "split" 57 | type: "Split" 58 | bottom: "LASTCONV" 59 | top: "split" 60 | } 61 | 62 | layer { 63 | name: "output_fc" 64 | type: "InnerProduct" 65 | bottom: "split" 66 | top: "output_fc" 67 | inner_product_param { 68 | num_output: 2 69 | weight_filler { 70 | type: "xavier" 71 | } 72 | } 73 | } 74 | layer { 75 | name: "loss" 76 | type: "SoftmaxWithLoss" 77 | bottom: "output_fc" 78 | bottom: "label" 79 | top: "loss" 80 | } 81 | 82 | layer { 83 | name: "output" 84 | type: "Softmax" 85 | bottom: "output_fc" 86 | top: "output" 87 | } 88 | layer { 89 | name: "labelout" 90 | type: "Split" 91 | bottom: "label" 92 | top: "labelout" 93 | include { 94 | phase: TEST 95 | } 96 | } 97 | 98 | layer { 99 | name: "output_fc_aff" 100 | type: "InnerProduct" 101 | bottom: "split" 102 | top: "output_fc_aff" 103 | inner_product_param { 104 | num_output: 1 105 | weight_filler { 106 | type: "xavier" 107 | } 108 | } 109 | } 110 | 111 | layer { 112 | name: "rmsd" 113 | type: "AffinityLoss" 114 | bottom: "output_fc_aff" 115 | bottom: "affinity" 116 | top: "rmsd" 117 | affinity_loss_param { 118 | scale: 0.1 119 | gap: 1 120 | penalty: 0 121 | pseudohuber: false 122 | delta: 0 123 | ranklossmult: RANKLOSS 124 | ranklossneg: RANKNEG 125 | } 126 | } 127 | 128 | layer { 129 | name: "predaff" 130 | type: "Flatten" 131 | bottom: "output_fc_aff" 132 | top: "predaff" 133 | } 134 | 135 | layer { 136 | name: "affout" 137 | type: "Split" 
138 | bottom: "affinity" 139 | top: "affout" 140 | include { 141 | phase: TEST 142 | } 143 | } 144 | 145 | ''' 146 | 147 | convunit = ''' 148 | layer { 149 | name: "unitNUMBER_pool" 150 | type: "Pooling" 151 | bottom: "INLAYER" 152 | top: "unitNUMBER_pool" 153 | pooling_param { 154 | pool: MAX 155 | kernel_size: 2 156 | stride: 2 157 | } 158 | } 159 | layer { 160 | name: "unitNUMBER_conv1" 161 | type: "Convolution" 162 | bottom: "unitNUMBER_pool" 163 | top: "unitNUMBER_conv1" 164 | convolution_param { 165 | num_output: OUTPUT 166 | pad: 1 167 | kernel_size: 3 168 | stride: 1 169 | weight_filler { 170 | type: "FILLER" 171 | symmetric_fraction: FRACTION 172 | } 173 | } 174 | }''' 175 | 176 | 177 | 178 | finishunit = ''' 179 | layer { 180 | name: "unitNUMBER_norm" 181 | type: "LRN" 182 | bottom: "unitNUMBER_conv1" 183 | top: "unitNUMBER_conv1" 184 | } 185 | 186 | layer { 187 | name: "unitNUMBER_scale" 188 | type: "Scale" 189 | bottom: "unitNUMBER_conv1" 190 | top: "unitNUMBER_conv1" 191 | scale_param { 192 | bias_term: true 193 | } 194 | } 195 | layer { 196 | name: "unitNUMBER_func" 197 | type: "ELU" 198 | bottom: "unitNUMBER_conv1" 199 | top: "unitNUMBER_conv1" 200 | } 201 | '''; 202 | 203 | # normalization: none, LRN (across and within), Batch 204 | # learning rat 205 | # depth 3, width 32 (doubled) 206 | def create_unit(num, filler, fraction): 207 | width = 32 208 | double = True 209 | ret = convunit.replace('NUMBER', str(num)) 210 | if num == 1: 211 | ret = ret.replace('INLAYER','data') 212 | else: 213 | ret = ret.replace('INLAYER', 'unit%d_conv1'%(num-1)) 214 | 215 | outsize = width 216 | if double: 217 | outsize *= 2**(num-1) 218 | ret = ret.replace('OUTPUT', str(outsize)) 219 | ret = ret.replace('FILLER', filler) 220 | ret = ret.replace('FRACTION', str(fraction)) 221 | 222 | ret += finishunit.replace('NUMBER', str(num)) 223 | return ret 224 | 225 | 226 | def makemodel(filler, fraction, ranklossm, rankneg): 227 | m = modelstart 228 | depth = 3 229 | for i in range(1,depth+1): 230 | m += create_unit(i, filler, fraction) 231 | m += endmodel.replace('LASTCONV','unit%d_conv1'%depth).replace('RANKLOSS',str(ranklossm)).replace('RANKNEG',str(rankneg)) 232 | 233 | return m 234 | 235 | 236 | models = [] 237 | 238 | # [SGD|Adam] * [regular|rankloss|ranklosswneg] [xaviar|radial|radial.5] [0.01|0.001][ 239 | 240 | for ranklossm in [0, 0.01,0.1,1]: 241 | for rankneg in [0,1]: 242 | if ranklossm == 0 and rankneg == 1: 243 | continue 244 | for filler in ['xavier']: 245 | fraction = 1.0 246 | if filler == 'radial.5': 247 | filler = 'radial' 248 | fraction = 0.5 249 | model = makemodel(filler, fraction,ranklossm, rankneg) 250 | m = 'affinity_%.3f_%d.model'%(ranklossm,rankneg) 251 | models.append(m) 252 | out = open(m,'w') 253 | out.write(model) 254 | 255 | 256 | for m in models: 257 | for baselr in [0.01, 0.001]: 258 | for solver in ['SGD','Adam']: 259 | print("train.py -m %s -p ../types/all_0.5_0_ --keep_best -t 1000 -i 100000 --solver %s --base_lr %f --reduced -o all_%s_%s_%.3f"%(m,solver, baselr, m.replace('.model',''),solver,baselr)) 260 | -------------------------------------------------------------------------------- /affinity_search/makemodels4.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables: 6 | # kernel size: 3, 5, 7 7 | # depth: 2, 3, 4 8 | # width: 16, 32, 64, 128 9 | # doubling of width: true/false (e.g, 128->256->512) 10 | 11 | modelstart = '''layer { 12 
| name: "data" 13 | type: "MolGridData" 14 | top: "data" 15 | top: "label" 16 | top: "affinity" 17 | include { 18 | phase: TEST 19 | } 20 | molgrid_data_param { 21 | source: "TESTFILE" 22 | batch_size: 10 23 | dimension: 23.5 24 | resolution: 0.5 25 | shuffle: false 26 | balanced: false 27 | has_affinity: true 28 | root_folder: "../../" 29 | } 30 | } 31 | layer { 32 | name: "data" 33 | type: "MolGridData" 34 | top: "data" 35 | top: "label" 36 | top: "affinity" 37 | include { 38 | phase: TRAIN 39 | } 40 | molgrid_data_param { 41 | source: "TRAINFILE" 42 | batch_size: 50 43 | dimension: 23.5 44 | resolution: 0.5 45 | shuffle: true 46 | balanced: true 47 | stratify_receptor: true 48 | stratify_affinity_min: 0 49 | stratify_affinity_max: 0 50 | stratify_affinity_step: 0 51 | has_affinity: true 52 | random_rotation: true 53 | random_translate: 2 54 | root_folder: "../../" 55 | } 56 | } 57 | ''' 58 | 59 | endmodel = '''layer { 60 | name: "split" 61 | type: "Split" 62 | bottom: "LASTCONV" 63 | top: "split" 64 | } 65 | 66 | layer { 67 | name: "output_fc" 68 | type: "InnerProduct" 69 | bottom: "split" 70 | top: "output_fc" 71 | inner_product_param { 72 | num_output: 2 73 | weight_filler { 74 | type: "xavier" 75 | } 76 | } 77 | } 78 | layer { 79 | name: "loss" 80 | type: "SoftmaxWithLoss" 81 | bottom: "output_fc" 82 | bottom: "label" 83 | top: "loss" 84 | } 85 | 86 | layer { 87 | name: "output" 88 | type: "Softmax" 89 | bottom: "output_fc" 90 | top: "output" 91 | } 92 | layer { 93 | name: "labelout" 94 | type: "Split" 95 | bottom: "label" 96 | top: "labelout" 97 | include { 98 | phase: TEST 99 | } 100 | } 101 | 102 | layer { 103 | name: "output_fc_aff" 104 | type: "InnerProduct" 105 | bottom: "split" 106 | top: "output_fc_aff" 107 | inner_product_param { 108 | num_output: 1 109 | weight_filler { 110 | type: "xavier" 111 | } 112 | } 113 | } 114 | 115 | layer { 116 | name: "rmsd" 117 | type: "AffinityLoss" 118 | bottom: "output_fc_aff" 119 | bottom: "affinity" 120 | top: "rmsd" 121 | affinity_loss_param { 122 | scale: 0.1 123 | gap: 1 124 | penalty: 0 125 | pseudohuber: false 126 | delta: 0 127 | } 128 | } 129 | 130 | layer { 131 | name: "predaff" 132 | type: "Flatten" 133 | bottom: "output_fc_aff" 134 | top: "predaff" 135 | } 136 | 137 | layer { 138 | name: "affout" 139 | type: "Split" 140 | bottom: "affinity" 141 | top: "affout" 142 | include { 143 | phase: TEST 144 | } 145 | } 146 | 147 | ''' 148 | 149 | convunit = ''' 150 | layer { 151 | name: "unitNUMBER_pool" 152 | type: "Pooling" 153 | bottom: "INLAYER" 154 | top: "unitNUMBER_pool" 155 | pooling_param { 156 | pool: MAX 157 | kernel_size: 2 158 | stride: 2 159 | } 160 | } 161 | layer { 162 | name: "unitNUMBER_conv1" 163 | type: "Convolution" 164 | bottom: "unitNUMBER_pool" 165 | top: "unitNUMBER_conv1" 166 | convolution_param { 167 | num_output: OUTPUT 168 | pad: PAD 169 | kernel_size: KSIZE 170 | stride: 1 171 | weight_filler { 172 | type: "xavier" 173 | } 174 | } 175 | }''' 176 | 177 | 178 | 179 | finishunit = ''' 180 | layer { 181 | name: "unitNUMBER_norm" 182 | type: "LRN" 183 | bottom: "unitNUMBER_conv1" 184 | top: "unitNUMBER_conv1" 185 | } 186 | 187 | layer { 188 | name: "unitNUMBER_scale" 189 | type: "Scale" 190 | bottom: "unitNUMBER_conv1" 191 | top: "unitNUMBER_conv1" 192 | scale_param { 193 | bias_term: true 194 | } 195 | } 196 | layer { 197 | name: "unitNUMBER_func" 198 | type: "ELU" 199 | bottom: "unitNUMBER_conv1" 200 | top: "unitNUMBER_conv1" 201 | } 202 | '''; 203 | 204 | # normalization: none, LRN (across and within), Batch 205 
| # learning rat 206 | def create_unit(num, ksize, width, double): 207 | 208 | ret = convunit.replace('NUMBER', str(num)) 209 | if num == 1: 210 | ret = ret.replace('INLAYER','data') 211 | else: 212 | ret = ret.replace('INLAYER', 'unit%d_conv1'%(num-1)) 213 | 214 | if num == 4: 215 | ksize = 3 #only 3x3 at this point 216 | 217 | pad = int(ksize/2) 218 | ret = ret.replace('PAD',str(pad)) 219 | ret = ret.replace('KSIZE', str(ksize)) 220 | outsize = width 221 | if double: 222 | outsize *= 2**(num-1) 223 | ret = ret.replace('OUTPUT', str(outsize)) 224 | 225 | ret += finishunit.replace('NUMBER', str(num)) 226 | return ret 227 | 228 | 229 | def makemodel(depth, width, double, ksize): 230 | m = modelstart 231 | for i in range(1,depth+1): 232 | m += create_unit(i, ksize, width, double) 233 | m += endmodel.replace('LASTCONV','unit%d_conv1'%depth) 234 | 235 | return m 236 | 237 | 238 | models = [] 239 | for depth in [4,3,2]: 240 | for width in [128, 64, 32, 16]: 241 | for double in [True, False]: 242 | for ksize in [7,5,3]: 243 | model = makemodel(depth,width, double, ksize) 244 | m = 'affinity_%d_%d_%d_%d.model'%(depth,width,int(double),ksize) 245 | models.append(m) 246 | out = open(m,'w') 247 | out.write(model) 248 | 249 | for m in models: 250 | print("train.py -m %s -p ../types/all_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o all_%s"%(m,m.replace('.model',''))) 251 | -------------------------------------------------------------------------------- /affinity_search/makemodels5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables: 6 | # kernel size: 3 or 7 7 | # width: [32,32,32] [64,32,32] [64,32,16] [32,16,16] 8 | # pool or not 9 | # stride 1,2,3 (>1 if no pool) 10 | 11 | modelstart = '''layer { 12 | name: "data" 13 | type: "MolGridData" 14 | top: "data" 15 | top: "label" 16 | top: "affinity" 17 | include { 18 | phase: TEST 19 | } 20 | molgrid_data_param { 21 | source: "TESTFILE" 22 | batch_size: 10 23 | dimension: 23.5 24 | resolution: 0.5 25 | shuffle: false 26 | balanced: false 27 | has_affinity: true 28 | root_folder: "../../" 29 | } 30 | } 31 | layer { 32 | name: "data" 33 | type: "MolGridData" 34 | top: "data" 35 | top: "label" 36 | top: "affinity" 37 | include { 38 | phase: TRAIN 39 | } 40 | molgrid_data_param { 41 | source: "TRAINFILE" 42 | batch_size: 50 43 | dimension: 23.5 44 | resolution: 0.5 45 | shuffle: true 46 | balanced: true 47 | stratify_receptor: true 48 | stratify_affinity_min: 0 49 | stratify_affinity_max: 0 50 | stratify_affinity_step: 0 51 | has_affinity: true 52 | random_rotation: true 53 | random_translate: 2 54 | root_folder: "../../" 55 | } 56 | } 57 | ''' 58 | 59 | endmodel = '''layer { 60 | name: "split" 61 | type: "Split" 62 | bottom: "LASTCONV" 63 | top: "split" 64 | } 65 | 66 | layer { 67 | name: "output_fc" 68 | type: "InnerProduct" 69 | bottom: "split" 70 | top: "output_fc" 71 | inner_product_param { 72 | num_output: 2 73 | weight_filler { 74 | type: "xavier" 75 | } 76 | } 77 | } 78 | layer { 79 | name: "loss" 80 | type: "SoftmaxWithLoss" 81 | bottom: "output_fc" 82 | bottom: "label" 83 | top: "loss" 84 | } 85 | 86 | layer { 87 | name: "output" 88 | type: "Softmax" 89 | bottom: "output_fc" 90 | top: "output" 91 | } 92 | layer { 93 | name: "labelout" 94 | type: "Split" 95 | bottom: "label" 96 | top: "labelout" 97 | include { 98 | phase: TEST 99 | } 100 | } 101 | 102 | layer { 103 | name: "output_fc_aff" 104 | type: 
"InnerProduct" 105 | bottom: "split" 106 | top: "output_fc_aff" 107 | inner_product_param { 108 | num_output: 1 109 | weight_filler { 110 | type: "xavier" 111 | } 112 | } 113 | } 114 | 115 | layer { 116 | name: "rmsd" 117 | type: "AffinityLoss" 118 | bottom: "output_fc_aff" 119 | bottom: "affinity" 120 | top: "rmsd" 121 | affinity_loss_param { 122 | scale: 0.1 123 | gap: 1 124 | penalty: 0 125 | pseudohuber: false 126 | delta: 0 127 | } 128 | } 129 | 130 | layer { 131 | name: "predaff" 132 | type: "Flatten" 133 | bottom: "output_fc_aff" 134 | top: "predaff" 135 | } 136 | 137 | layer { 138 | name: "affout" 139 | type: "Split" 140 | bottom: "affinity" 141 | top: "affout" 142 | include { 143 | phase: TEST 144 | } 145 | } 146 | 147 | ''' 148 | 149 | poollayer = ''' 150 | layer { 151 | name: "unitNUMBER_pool" 152 | type: "Pooling" 153 | bottom: "INLAYER" 154 | top: "unitNUMBER_pool" 155 | pooling_param { 156 | pool: MAX 157 | kernel_size: 2 158 | stride: 2 159 | } 160 | } 161 | ''' 162 | 163 | fakepool = ''' 164 | layer { 165 | name: "unitNUMBER_pool" 166 | type: "Split" 167 | bottom: "INLAYER" 168 | top: "unitNUMBER_pool" 169 | } 170 | ''' 171 | 172 | convunit = ''' 173 | POOLLAYER 174 | 175 | layer { 176 | name: "unitNUMBER_conv1" 177 | type: "Convolution" 178 | bottom: "unitNUMBER_pool" 179 | top: "unitNUMBER_conv1" 180 | convolution_param { 181 | num_output: OUTPUT 182 | pad: PAD 183 | kernel_size: KSIZE 184 | stride: STRIDE 185 | weight_filler { 186 | type: "xavier" 187 | } 188 | } 189 | }''' 190 | 191 | 192 | 193 | finishunit = ''' 194 | layer { 195 | name: "unitNUMBER_norm" 196 | type: "LRN" 197 | bottom: "unitNUMBER_conv1" 198 | top: "unitNUMBER_conv1" 199 | } 200 | 201 | layer { 202 | name: "unitNUMBER_scale" 203 | type: "Scale" 204 | bottom: "unitNUMBER_conv1" 205 | top: "unitNUMBER_conv1" 206 | scale_param { 207 | bias_term: true 208 | } 209 | } 210 | layer { 211 | name: "unitNUMBER_func" 212 | type: "ELU" 213 | bottom: "unitNUMBER_conv1" 214 | top: "unitNUMBER_conv1" 215 | } 216 | '''; 217 | 218 | # normalization: none, LRN (across and within), Batch 219 | # learning rat 220 | def create_unit(num, ksize, width, pool,stride): 221 | 222 | if pool: 223 | ret = convunit.replace('POOLLAYER',poollayer) 224 | else: 225 | ret = convunit.replace('POOLLAYER',fakepool) 226 | ret = ret.replace('NUMBER', str(num)) 227 | if num == 1: 228 | ret = ret.replace('INLAYER','data') 229 | else: 230 | ret = ret.replace('INLAYER', 'unit%d_conv1'%(num-1)) 231 | 232 | 233 | pad = int(ksize/2) 234 | ret = ret.replace('PAD',str(pad)) 235 | ret = ret.replace('KSIZE', str(ksize)) 236 | ret = ret.replace('STRIDE',str(stride)) 237 | ret = ret.replace('OUTPUT', str(width)) 238 | 239 | ret += finishunit.replace('NUMBER', str(num)) 240 | return ret 241 | 242 | 243 | def makemodel(widths, ksize, pool, stride): 244 | m = modelstart 245 | for (i,w) in enumerate(widths): 246 | m += create_unit(i+1, ksize, w, pool, stride) 247 | m += endmodel.replace('LASTCONV','unit%d_conv1'%len(widths)) 248 | 249 | return m 250 | 251 | 252 | models = [] 253 | for widths in [[32,32,32], [64,32,32], [64,32,16], [32,16,16]]: 254 | for ksize in [7,3]: 255 | for pool in [True,False]: 256 | for stride in [1,2,3]: 257 | if stride > 1 and pool: 258 | continue 259 | model = makemodel(widths, ksize, pool, stride) 260 | m = 'affinity_%s_%d_%d_%d.model'%('-'.join(map(str,widths)),ksize,int(pool),stride) 261 | models.append(m) 262 | out = open(m,'w') 263 | out.write(model) 264 | 265 | for m in models: 266 | print("train.py -m %s -p 
../types/all_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o all_%s"%(m,m.replace('.model',''))) 267 | -------------------------------------------------------------------------------- /affinity_search/makemodels6.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Generate models for affinity predictions''' 4 | 5 | # variables (for first layer only): 6 | # grid resolution: 0.25, 0.5, 1.0 # no 0.125 for now - blobs require small batch size 7 | # convolution: none, 2x2 (stride 2), 4x4 (stride 4), 8x8 (stride 8) 8 | # convolution can have activation function after it or not, be grouped or not 9 | # max pooling: none, 2x2, 4x4, 8x8 10 | # above are combined to generate 1A grid for next layer 11 | 12 | modelstart = '''layer { 13 | name: "data" 14 | type: "MolGridData" 15 | top: "data" 16 | top: "label" 17 | top: "affinity" 18 | include { 19 | phase: TEST 20 | } 21 | molgrid_data_param { 22 | source: "TESTFILE" 23 | batch_size: 1 24 | dimension: DIMENSION 25 | resolution: RESOLUTION 26 | shuffle: false 27 | balanced: false 28 | has_affinity: true 29 | root_folder: "../../" 30 | } 31 | } 32 | layer { 33 | name: "data" 34 | type: "MolGridData" 35 | top: "data" 36 | top: "label" 37 | top: "affinity" 38 | include { 39 | phase: TRAIN 40 | } 41 | molgrid_data_param { 42 | source: "TRAINFILE" 43 | batch_size: 20 44 | dimension: DIMENSION 45 | resolution: RESOLUTION 46 | shuffle: true 47 | balanced: true 48 | stratify_receptor: true 49 | stratify_affinity_min: 0 50 | stratify_affinity_max: 0 51 | stratify_affinity_step: 0 52 | has_affinity: true 53 | random_rotation: true 54 | random_translate: 2 55 | root_folder: "../../" 56 | } 57 | } 58 | 59 | 60 | FIRSTLAYER 61 | 62 | layer { 63 | name: "unit1_conv1" 64 | type: "Convolution" 65 | bottom: "INITIALNAME" 66 | top: "unit1_conv1" 67 | convolution_param { 68 | num_output: 32 69 | pad: 1 70 | kernel_size: 3 71 | stride: 1 72 | weight_filler { 73 | type: "xavier" 74 | } 75 | } 76 | } 77 | layer { 78 | name: "unit1_norm" 79 | type: "LRN" 80 | bottom: "unit1_conv1" 81 | top: "unit1_conv1" 82 | } 83 | 84 | layer { 85 | name: "unit1_scale" 86 | type: "Scale" 87 | bottom: "unit1_conv1" 88 | top: "unit1_conv1" 89 | scale_param { 90 | bias_term: true 91 | } 92 | } 93 | layer { 94 | name: "unit1_func" 95 | type: "ELU" 96 | bottom: "unit1_conv1" 97 | top: "unit1_conv1" 98 | } 99 | 100 | 101 | layer { 102 | name: "unit2_pool" 103 | type: "Pooling" 104 | bottom: "unit1_conv1" 105 | top: "unit2_pool" 106 | pooling_param { 107 | pool: MAX 108 | kernel_size: 2 109 | stride: 2 110 | } 111 | } 112 | 113 | 114 | layer { 115 | name: "unit2_conv1" 116 | type: "Convolution" 117 | bottom: "unit2_pool" 118 | top: "unit2_conv1" 119 | convolution_param { 120 | num_output: 32 121 | pad: 1 122 | kernel_size: 3 123 | stride: 1 124 | weight_filler { 125 | type: "xavier" 126 | } 127 | } 128 | } 129 | layer { 130 | name: "unit2_norm" 131 | type: "LRN" 132 | bottom: "unit2_conv1" 133 | top: "unit2_conv1" 134 | } 135 | 136 | layer { 137 | name: "unit2_scale" 138 | type: "Scale" 139 | bottom: "unit2_conv1" 140 | top: "unit2_conv1" 141 | scale_param { 142 | bias_term: true 143 | } 144 | } 145 | layer { 146 | name: "unit2_func" 147 | type: "ELU" 148 | bottom: "unit2_conv1" 149 | top: "unit2_conv1" 150 | } 151 | 152 | 153 | layer { 154 | name: "unit3_pool" 155 | type: "Pooling" 156 | bottom: "unit2_conv1" 157 | top: "unit3_pool" 158 | pooling_param { 159 | pool: MAX 160 | kernel_size: 2 161 | stride: 2 162 | } 163 | } 164 
| 165 | 166 | layer { 167 | name: "unit3_conv1" 168 | type: "Convolution" 169 | bottom: "unit3_pool" 170 | top: "unit3_conv1" 171 | convolution_param { 172 | num_output: 32 173 | pad: 1 174 | kernel_size: 3 175 | stride: 1 176 | weight_filler { 177 | type: "xavier" 178 | } 179 | } 180 | } 181 | layer { 182 | name: "unit3_norm" 183 | type: "LRN" 184 | bottom: "unit3_conv1" 185 | top: "unit3_conv1" 186 | } 187 | 188 | layer { 189 | name: "unit3_scale" 190 | type: "Scale" 191 | bottom: "unit3_conv1" 192 | top: "unit3_conv1" 193 | scale_param { 194 | bias_term: true 195 | } 196 | } 197 | layer { 198 | name: "unit3_func" 199 | type: "ELU" 200 | bottom: "unit3_conv1" 201 | top: "unit3_conv1" 202 | } 203 | layer { 204 | name: "split" 205 | type: "Split" 206 | bottom: "unit3_conv1" 207 | top: "split" 208 | } 209 | 210 | layer { 211 | name: "output_fc" 212 | type: "InnerProduct" 213 | bottom: "split" 214 | top: "output_fc" 215 | inner_product_param { 216 | num_output: 2 217 | weight_filler { 218 | type: "xavier" 219 | } 220 | } 221 | } 222 | layer { 223 | name: "loss" 224 | type: "SoftmaxWithLoss" 225 | bottom: "output_fc" 226 | bottom: "label" 227 | top: "loss" 228 | } 229 | 230 | layer { 231 | name: "output" 232 | type: "Softmax" 233 | bottom: "output_fc" 234 | top: "output" 235 | } 236 | layer { 237 | name: "labelout" 238 | type: "Split" 239 | bottom: "label" 240 | top: "labelout" 241 | include { 242 | phase: TEST 243 | } 244 | } 245 | 246 | layer { 247 | name: "output_fc_aff" 248 | type: "InnerProduct" 249 | bottom: "split" 250 | top: "output_fc_aff" 251 | inner_product_param { 252 | num_output: 1 253 | weight_filler { 254 | type: "xavier" 255 | } 256 | } 257 | } 258 | 259 | layer { 260 | name: "rmsd" 261 | type: "AffinityLoss" 262 | bottom: "output_fc_aff" 263 | bottom: "affinity" 264 | top: "rmsd" 265 | affinity_loss_param { 266 | scale: 0.1 267 | gap: 1 268 | penalty: 0 269 | pseudohuber: false 270 | delta: 0 271 | } 272 | } 273 | 274 | layer { 275 | name: "predaff" 276 | type: "Flatten" 277 | bottom: "output_fc_aff" 278 | top: "predaff" 279 | } 280 | 281 | layer { 282 | name: "affout" 283 | type: "Split" 284 | bottom: "affinity" 285 | top: "affout" 286 | include { 287 | phase: TEST 288 | } 289 | } 290 | ''' 291 | 292 | poollayer = ''' 293 | layer { 294 | name: "initial_pool" 295 | type: "Pooling" 296 | bottom: "POOLINPUT" 297 | top: "initial_pool" 298 | pooling_param { 299 | pool: MAX 300 | kernel_size: SIZESTRIDE 301 | stride: SIZESTRIDE 302 | } 303 | } 304 | ''' 305 | 306 | convlayers = ['''layer { 307 | name: "initial_conv" 308 | type: "Convolution" 309 | bottom: "CONVINPUT" 310 | top: "initial_conv" 311 | convolution_param { 312 | num_output: 32 313 | pad: 0 314 | kernel_size: SIZESTRIDE 315 | stride: SIZESTRIDE 316 | weight_filler { 317 | type: "xavier" 318 | } 319 | } 320 | }''', '''layer { 321 | name: "initial_conv" 322 | type: "Convolution" 323 | bottom: "CONVINPUT" 324 | top: "initial_conv" 325 | convolution_param { 326 | num_output: 35 327 | group: 35 328 | pad: 0 329 | kernel_size: SIZESTRIDE 330 | stride: SIZESTRIDE 331 | weight_filler { 332 | type: "xavier" 333 | } 334 | } 335 | }'''] 336 | 337 | convafter = ''' 338 | layer { 339 | name: "initial_norm" 340 | type: "LRN" 341 | bottom: "initial_conv" 342 | top: "initial_norm" 343 | } 344 | 345 | layer { 346 | name: "initial_scale" 347 | type: "Scale" 348 | bottom: "initial_norm" 349 | top: "initial_scale" 350 | scale_param { 351 | bias_term: true 352 | } 353 | } 354 | layer { 355 | name: "initial_func" 356 | type: "ELU" 357 | 
bottom: "initial_scale" 358 | top: "initial_func" 359 | } 360 | ''' 361 | 362 | def makemodel(resolution, conv, pool, grouped, func, swapped): 363 | m = modelstart.replace('RESOLUTION','%.3f'%resolution) 364 | dim = 24-resolution 365 | m = m.replace('DIMENSION','%.3f'%dim) 366 | 367 | clayer = '' 368 | player = '' 369 | convname = 'initial_conv' 370 | if conv > 1: # have a layer 371 | clayer = convlayers[grouped] 372 | clayer = clayer.replace('SIZESTRIDE',str(conv)) 373 | if func: 374 | clayer += convafter 375 | convname = 'initial_func' 376 | if pool > 1: 377 | player = poollayer.replace('SIZESTRIDE',str(pool)) 378 | 379 | initial = '' 380 | 381 | if conv == 1 and pool == 1: 382 | m = m.replace('INITIALNAME','data') 383 | elif swapped: 384 | player = player.replace('POOLINPUT','data') 385 | if pool > 1: 386 | ipool = 'initial_pool' 387 | else: 388 | ipool = 'data' 389 | clayer = clayer.replace('CONVINPUT',ipool) 390 | 391 | initial = player+clayer 392 | if conv > 1: 393 | iconv = convname 394 | else: 395 | iconv = 'initial_pool' 396 | m = m.replace('INITIALNAME',iconv) 397 | else: 398 | clayer = clayer.replace('CONVINPUT','data') 399 | if conv > 1: 400 | iconv = convname 401 | else: 402 | iconv = 'data' 403 | player = player.replace('POOLINPUT',iconv) 404 | initial = clayer+player 405 | if pool > 1: 406 | ipool = 'initial_pool' 407 | else: 408 | ipool = convname 409 | m = m.replace('INITIALNAME',ipool) 410 | 411 | m = m.replace('FIRSTLAYER',initial) 412 | return m 413 | 414 | 415 | models = [] 416 | for resolution in [0.25, 0.5, 1.0]: 417 | for conv in [1,2,4,8]: 418 | for pool in [1,2,4,8]: 419 | if resolution*conv*pool == 1.0: 420 | #valid combination 421 | for grouped in [0,1]: 422 | for func in [0,1]: 423 | for swapped in [0,1]: 424 | if (conv == 1 or pool == 1) and swapped: # nothing to swap 425 | continue 426 | if conv == 1 and (grouped or func): 427 | continue #no conv layer 428 | 429 | model = makemodel(resolution, conv, pool, grouped, func, swapped) 430 | m = 'affinity_%.3f_conv%d_pool%d_grouped%d_func%d_swap%d.model'%(resolution,conv,pool,grouped,func,swapped) 431 | models.append(m) 432 | out = open(m,'w') 433 | out.write(model) 434 | 435 | 436 | for m in models: 437 | print("train.py -m %s -p ../types/all_0.5_0_ --keep_best -t 1000 -i 100000 --reduced -o all_%s"%(m,m.replace('.model',''))) 438 | -------------------------------------------------------------------------------- /affinity_search/makereduced.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys,re, collections 4 | '''reduce the provided example file to only have a single positive and single negative example per receptor (at most)''' 5 | 6 | #first identify best rmsd 7 | bestval = dict() 8 | bestlig = dict() 9 | for line in open(sys.argv[1]): 10 | vals = line.rstrip().split() 11 | if len(vals) < 6: 12 | continue 13 | rmsd = float(vals[5]) 14 | rec = vals[2] 15 | if rec not in bestval or rmsd < bestval[rec]: 16 | bestval[rec] = rmsd 17 | bestlig[rec] = vals[3] 18 | 19 | if len(bestlig) == 0: 20 | for line in open(sys.argv[1]): 21 | print(line.rstrip()) 22 | else: 23 | diddecoy = set() 24 | for line in open(sys.argv[1]): 25 | vals = line.rstrip().split() 26 | rec = vals[2] 27 | if rec in bestlig and vals[3] == bestlig[rec]: 28 | print(line.rstrip()) 29 | elif int(vals[0]) == 0 and rec not in diddecoy: 30 | diddecoy.add(rec) 31 | print(line.rstrip()) 32 | 33 | -------------------------------------------------------------------------------- 
/affinity_search/outputjson.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | Output the parameters that makemodel supports with their ranges 5 | ''' 6 | 7 | import makemodel 8 | import json, sys 9 | from collections import OrderedDict 10 | 11 | #extract from arguments to makemodel 12 | 13 | def makejson(): 14 | '''return dictionary of config options from makemodel''' 15 | opts = makemodel.getoptions() 16 | 17 | d=OrderedDict() 18 | for (name,vals) in sorted(opts.items()): 19 | paramsize=1 20 | if type(vals) == tuple: 21 | options=list(map(str,vals)) 22 | paramtype="enum" 23 | data=OrderedDict([("name",name), ("type", paramtype), ("size", paramsize),("options",options)]) 24 | elif isinstance(vals, makemodel.Range): 25 | parammin = vals.min 26 | parammax = vals.max 27 | paramtype="float" 28 | data=OrderedDict([("name",name), ("type", paramtype), ("min", parammin), ("max", parammax), ("size", paramsize)]) 29 | else: 30 | print("Unknown type") 31 | sys.exit(-1) 32 | d[name]=data 33 | return d 34 | 35 | if __name__ == '__main__': 36 | sys.stdout.write(json.dumps(makejson(), indent=4)+'\n') 37 | 38 | -------------------------------------------------------------------------------- /affinity_search/outputparams.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | Output the parameters that makemodel supports with their ranges 5 | ''' 6 | 7 | import makemodel, argparse 8 | 9 | #extract from arguments to makemodel 10 | opts = makemodel.getoptions() 11 | for (name,vals) in sorted(opts.items()): 12 | print(name,vals) 13 | -------------------------------------------------------------------------------- /affinity_search/outputsql.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | Output the parameters that makemodel supports with their ranges 5 | ''' 6 | 7 | import makemodel 8 | import json, sys 9 | from collections import OrderedDict 10 | 11 | #extract from arguments to makemodel 12 | opts = makemodel.getoptions() 13 | 14 | create = 'CREATE TABLE params (rmse DOUBLE, top DOUBLE, R DOUBLE, auc DOUBLE' 15 | 16 | #everything else make a string 17 | for (name,vals) in sorted(opts.items()): 18 | create += ', %s VARCHAR(32)' % name 19 | 20 | create += ');' 21 | 22 | print(create) 23 | -------------------------------------------------------------------------------- /affinity_search/populatedefaults.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Given a results.csv file and password for db, put the contents into google sql''' 4 | 5 | import sys, re, MySQLdb 6 | import pandas as pd 7 | import makemodel 8 | import numpy as np 9 | 10 | 11 | conn = MySQLdb.connect (host = "35.196.158.205",user = "opter",passwd=sys.argv[1],db="opt2") 12 | cursor = conn.cursor() 13 | 14 | opts = makemodel.getoptions() 15 | params = makemodel.getdefaults() 16 | 17 | 18 | params['id'] = 'REQUESTED' 19 | 20 | #do 5 variations 21 | for split in range(5): 22 | params['split'] = split 23 | params['seed'] = np.random.randint(0,100000) 24 | data = pd.DataFrame([params]) 25 | row = data.iloc[0] 26 | insert = 'INSERT INTO params (%s) VALUES (%s)' % (','.join(row.index),','.join(['%s']*len(row))) 27 | cursor.execute(insert,row) 28 | 29 | conn.commit() 30 | -------------------------------------------------------------------------------- 
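For clarity, this is how populatedefaults.py (and populatesql.py, below) assemble their INSERT statements: only the column names are interpolated into the SQL text, while the values are passed separately to cursor.execute() and bound by MySQLdb. A minimal sketch with made-up columns and values (the real ones come from makemodel.getdefaults() plus the id, split, and seed fields):

    row_index  = ['id', 'split', 'seed', 'base_lr']    # hypothetical column names
    row_values = ['REQUESTED', 0, 41172, 0.01]         # hypothetical values
    insert = 'INSERT INTO params (%s) VALUES (%s)' % (','.join(row_index),
                                                      ','.join(['%s'] * len(row_values)))
    # insert == 'INSERT INTO params (id,split,seed,base_lr) VALUES (%s,%s,%s,%s)'
    # cursor.execute(insert, row_values)  # values are bound by the driver, not string-formatted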
/affinity_search/populaterequests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Given a results.dat file and password for db, put the contents into google sql as 4 | configurations that are being requested. Specify five of each.''' 5 | 6 | import sys, re, MySQLdb 7 | import pandas as pd 8 | import makemodel 9 | import numpy as np 10 | 11 | opts = makemodel.getoptions() 12 | 13 | def addrows(fname,host,db,password,start=0): 14 | '''add rows from fname into database, starting at row start''' 15 | data = pd.read_csv(fname,delim_whitespace=True,header=None) 16 | colnames = ['P1','P2'] 17 | for (name,val) in sorted(opts.items()): 18 | colnames.append(name) 19 | 20 | data.columns = colnames 21 | data = data.drop(['P1','P2'],axis=1) 22 | 23 | conn = MySQLdb.connect (host = host,user = "opter",passwd=password,db=db) 24 | cursor = conn.cursor() 25 | 26 | 27 | for (i,row) in data[start:].iterrows(): 28 | names = ','.join(row.index) 29 | values = ','.join(['%s']*len(row)) 30 | names += ',id' 31 | values += ',"REQUESTED"' 32 | #do five variations 33 | for split in range(5): 34 | seed = np.random.randint(0,100000) 35 | n = names + ',split,seed' 36 | v = values + ',%d,%d' % (split,seed) 37 | insert = 'INSERT INTO params (%s) VALUES (%s)' % (n,v) 38 | cursor.execute(insert,row) 39 | conn.commit() 40 | 41 | if __name__ == '__main__': 42 | addrows(sys.argv[1],"35.196.158.205","opt2",sys.argv[2]) 43 | -------------------------------------------------------------------------------- /affinity_search/populatesql.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Given a results.csv file and password for db, put the contents into google sql''' 4 | 5 | import sys, re, MySQLdb 6 | import pandas as pd 7 | 8 | conn = MySQLdb.connect (host = "35.196.158.205",user = "opter",passwd=sys.argv[2],db="opt2") 9 | cursor = conn.cursor() 10 | 11 | data = pd.read_csv(sys.argv[1]) 12 | 13 | for (i,row) in data.iterrows(): 14 | insert = 'INSERT INTO params (%s) VALUES (%s)' % (','.join(row.index),','.join(['%s']*len(row))) 15 | cursor.execute(insert,row) 16 | conn.commit() 17 | -------------------------------------------------------------------------------- /affinity_search/reval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Reevaluates *.caffemodel models in specified directory. Return same result as runline. 
Assume directory was created with do1request/runline.''' 4 | 5 | import sys,os 6 | def get_script_path(): 7 | return os.path.dirname(os.path.realpath(sys.argv[0])) 8 | 9 | sys.path.append(get_script_path()+'/..') #train 10 | import re,argparse, tempfile, os,glob 11 | import makemodel 12 | import socket 13 | import train 14 | import numpy as np 15 | import sklearn.metrics 16 | import scipy.stats 17 | import calctop, predict 18 | 19 | class Bunch(object): 20 | def __init__(self, adict): 21 | self.__dict__.update(adict) 22 | 23 | 24 | parser = argparse.ArgumentParser(description='Evaluate models in directory and report results.') 25 | parser.add_argument('--data_root',type=str,help='Location of gninatypes directory',default='') 26 | parser.add_argument('--prefix',type=str,help='Prefix, not including split',default='../data/refined/all_0.5_') 27 | parser.add_argument('--split',type=int,help='Which predefined split to use',required=True) 28 | parser.add_argument('--dir',type=str,help='Directory to use',required=True) 29 | args = parser.parse_args() 30 | 31 | 32 | os.chdir(args.dir) 33 | prefix = '%s%d_'% (args.prefix,args.split) 34 | 35 | #collect only the latest caffemodel file for each fold 36 | models = {} 37 | for caffefile in glob.glob('*.caffemodel'): 38 | m = re.search(r'\S+(\d+)_iter_(\d+).caffemodel',caffefile) 39 | if m: 40 | fold = int(m.group(1)) 41 | iter = int(m.group(2)) 42 | if fold not in models or models[fold][1] < iter: 43 | models[fold] = (caffefile,iter) 44 | 45 | 46 | #for each fold, collect the predictions 47 | predictions = [] #a list of tuples 48 | topresults = [] 49 | for fold in models: 50 | (caffefile, iter) = models[fold] 51 | testfile = prefix+'test%d.types'%fold 52 | pargs = predict.parse_args(['-m','model.model','-w',caffefile,'-d',args.data_root,'-i',testfile]) 53 | predictions += predict.predict(pargs)[0] 54 | topresults += calctop.evaluate_fold(testfile,caffefile,'model.model',args.data_root) 55 | 56 | #parse prediction lines 57 | expaffs = [] 58 | predaffs = [] 59 | scores = [] 60 | labels = [] 61 | for p in predictions: 62 | score = p[0] 63 | predaff = p[1] 64 | vals = p[2].split() 65 | label = float(vals[0]) 66 | aff = float(vals[1]) 67 | scores.append(score) 68 | labels.append(label) 69 | if aff > 0: 70 | expaffs.append(aff) 71 | predaffs.append(predaff) 72 | 73 | #don't consider bad poses for affinity 74 | R = scipy.stats.pearsonr(expaffs, predaffs)[0] 75 | rmse = np.sqrt(sklearn.metrics.mean_squared_error(expaffs,predaffs)) 76 | auc = sklearn.metrics.roc_auc_score(labels, scores) 77 | top = 0 #calctop.find_top_ligand(topresults,1)/100.0 78 | 79 | print(args.dir, R, rmse, auc, top) 80 | 81 | -------------------------------------------------------------------------------- /affinity_search/runline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Read a line formated like a spearmint results.dat line, 4 | construct the corresponding model, run the model with cross validation, 5 | and print the results; dies with error if parameters are invalid''' 6 | 7 | import sys,os 8 | def get_script_path(): 9 | return os.path.dirname(os.path.realpath(sys.argv[0])) 10 | 11 | sys.path.append(get_script_path()+'/..') #train 12 | import re,argparse, tempfile, os,glob 13 | import makemodel 14 | import socket 15 | import train 16 | import numpy as np 17 | import sklearn.metrics 18 | import scipy.stats 19 | import calctop 20 | from evaluate import evaluate_fold, analyze_results 21 | from train import 
Namespace 22 | 23 | class Bunch(object): 24 | def __init__(self, adict): 25 | self.__dict__.update(adict) 26 | 27 | 28 | parser = argparse.ArgumentParser(description='Run single model line and report results.') 29 | parser.add_argument('--line',type=str,help='Complete line',required=True) 30 | parser.add_argument('--seed',type=int,help='Random seed',default=0) 31 | parser.add_argument('--split',type=int,help='Which predefined split to use',default=0) 32 | parser.add_argument('--data_root',type=str,help='Location of gninatypes directory',default='') 33 | parser.add_argument('--prefix',type=str,help='Prefix, not including split',default='../data/refined/all_0.5_') 34 | parser.add_argument('--dir',type=str,help='Directory to use') 35 | parser.add_argument('--ligmap',type=str,help="Ligand atom typing map to use",default='') 36 | parser.add_argument('--recmap',type=str,help="Receptor atom typing map to use",default='') 37 | args = parser.parse_args() 38 | 39 | linevals = args.line.split()[2:] 40 | 41 | opts = makemodel.getoptions() 42 | 43 | if len(linevals) != len(opts): 44 | print("Wrong number of options in line (%d) compared to options (%d)" %(len(linevals),len(opts))) 45 | 46 | params = dict() 47 | for (i,(name,vals)) in enumerate(sorted(opts.items())): 48 | v = linevals[i] 49 | if v == 'False': 50 | v = 0 51 | if v == 'True': 52 | v = 1 53 | if type(vals) == tuple: 54 | if type(vals[0]) == int: 55 | v = int(v) 56 | elif type(vals[0]) == float: 57 | v = float(v) 58 | elif isinstance(vals, makemodel.Range): 59 | v = float(v) 60 | params[name] = v 61 | 62 | 63 | if(args.ligmap): params['ligmap'] = args.ligmap 64 | if(args.recmap): params['recmap'] = args.recmap 65 | 66 | params = Bunch(params) 67 | 68 | model = makemodel.create_model(params) 69 | 70 | host = socket.gethostname() 71 | 72 | if args.dir: 73 | d = args.dir 74 | try: 75 | os.makedirs(d) 76 | except OSError: 77 | pass 78 | else: 79 | d = tempfile.mkdtemp(prefix=host+'-',dir='.') 80 | 81 | os.chdir(d) 82 | mfile = open('model.model','w') 83 | mfile.write(model) 84 | mfile.close() 85 | 86 | #get hyperparamters 87 | base_lr = 10**params.base_lr_exp 88 | momentum=params.momentum 89 | weight_decay = 10**params.weight_decay_exp 90 | solver = params.solver 91 | 92 | #setup training 93 | prefix = '%s%d_'% (args.prefix,args.split) 94 | trainargs = train.parse_args(['--seed',str(args.seed),'--prefix',prefix,'--data_root', 95 | args.data_root,'-t','1000','-i','250000','-m','model.model','--checkpoint', 96 | '--reduced','-o',d,'--momentum',str(momentum),'--weight_decay',str(weight_decay), 97 | '--base_lr',str(base_lr),'--solver',solver,'--dynamic','--lr_policy','fixed'])[0] 98 | 99 | train_test_files = train.get_train_test_files(prefix=prefix, foldnums=None, allfolds=False, reduced=True, prefix2=None) 100 | if len(train_test_files) == 0: 101 | print("error: missing train/test files",prefix) 102 | sys.exit(1) 103 | 104 | 105 | outprefix = d 106 | #train 107 | numfolds = 0 108 | for i in train_test_files: 109 | 110 | outname = '%s.%s' % (outprefix, i) 111 | results = train.train_and_test_model(trainargs, train_test_files[i], outname) 112 | test, trainres = results 113 | 114 | if not np.isfinite(np.sum(trainres.y_score)): 115 | print("Non-finite trainres score") 116 | sys.exit(-1) 117 | if not np.isfinite(np.sum(test.y_score)): 118 | print("Non-finite test score") 119 | sys.exit(-1) 120 | if not np.isfinite(np.sum(trainres.y_predaff)): 121 | print("Non-finite trainres aff") 122 | sys.exit(-1) 123 | if not np.isfinite(np.sum(test.y_predaff)): 124 
| print("Non-finite test aff") 125 | sys.exit(-1) 126 | 127 | #once all folds are trained, test and evaluate them 128 | testresults = [] 129 | for i in train_test_files: 130 | 131 | #get latest model file for this fold 132 | lasti = -1 133 | caffemodel = '' 134 | for model in glob.glob('%s.%d_iter_*.caffemodel'%(outprefix,i)): 135 | m = re.search(r'_iter_(\d+).caffemodel', model) 136 | inum = int(m.group(1)) 137 | if inum > lasti: 138 | lasti = inum 139 | caffemodel = model 140 | if lasti == -1: 141 | print("Couldn't find valid caffemodel file %s.%d_iter_*.caffemodel"%(outprefix,i)) 142 | sys.exit(-1) 143 | 144 | testresults += evaluate_fold(train_test_files[i]['test'], caffemodel, 'model.model',trainargs.data_root) 145 | 146 | 147 | (rmse, R, S, aucpose, aucaff, top) = analyze_results(testresults,'%s.summary'%outprefix,'pose') 148 | 149 | print(d, R, rmse, aucpose, top) 150 | 151 | -------------------------------------------------------------------------------- /affinity_search/single_axis_grid_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | '''Given desired set of parameters, generates all configurations 4 | obtained by enumerate each parameter individually (continuous are discretized). 5 | 6 | ''' 7 | 8 | import sys, re, MySQLdb, argparse, os, json, subprocess 9 | import pandas as pd 10 | import makemodel 11 | import numpy as np 12 | from MySQLdb.cursors import DictCursor 13 | from outputjson import makejson 14 | from populaterequests import addrows 15 | 16 | 17 | parser = argparse.ArgumentParser(description='Exhaustive grid search along single axes of variation') 18 | parser.add_argument('--host',type=str,help='Database host') 19 | parser.add_argument('-p','--password',type=str,help='Database password') 20 | parser.add_argument('--db',type=str,help='Database name',default='database') 21 | parser.add_argument('-o','--output',type=str,help="Output file",default="rows.txt") 22 | parser.add_argument('--parameters',type=file,help='parameters to enumerate',required=True) 23 | args = parser.parse_args() 24 | 25 | #get options 26 | defaults = makemodel.getdefaults() 27 | options = makemodel.getoptions() 28 | opts = sorted(options.items()) 29 | 30 | #read in list of parameters 31 | params = args.parameters.read().rstrip().split() 32 | 33 | outrows = set() #uniq configurations only (e.g., avoid replicating the default over and over again) 34 | for param in params: 35 | if param in options: 36 | choices = options[param] 37 | if isinstance(choices, makemodel.Range): 38 | choices = np.linspace(choices.min,choices.max, 9) 39 | #for each parameter value, create a row 40 | for val in choices: 41 | row = ['P','P'] #spearmint 42 | for (name,_) in opts: 43 | if name == param: 44 | row.append(val) 45 | else: 46 | row.append(defaults[name]) 47 | outrows.add(tuple(row)) 48 | 49 | out = open(args.output,'w') 50 | for row in outrows: 51 | out.write(' '.join(map(str,row))+'\n') 52 | 53 | out.close() 54 | 55 | if args.host: 56 | addrows(args.output,args.host,args.db,args.password) 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /bootstrap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import predict 4 | import sklearn.metrics 5 | import argparse, sys 6 | import os 7 | import numpy as np 8 | import glob 9 | import re 10 | import matplotlib.pyplot as plt 11 | 12 | def calc_auc(predictions): 13 | y_true =[] 14 | 
y_score=[] 15 | for line in predictions: 16 | values= line.split(" ") 17 | y_true.append(float(values[1])) 18 | y_score.append(float(values[0])) 19 | auc = sklearn.metrics.roc_auc_score(y_true,y_score) 20 | return auc 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='bootstrap(sampling with replacement) test') 24 | parser.add_argument('-m','--model',type=str,required=True,help="Model template. Must use TESTFILE with unshuffled, unbalanced input") 25 | parser.add_argument('-w','--weights',type=str,required=True,help="Model weights (.caffemodel)") 26 | parser.add_argument('-i','--input',type=str,required=True,help="Input .types file to predict") 27 | parser.add_argument('-g','--gpu',type=int,help='Specify GPU to run on',default=-1) 28 | parser.add_argument('-o','--output',type=str,default='',help='Output file name,default= predict_[model]_[input]') 29 | parser.add_argument('--iterations',type=int,default=1000,help="number of times to bootstrap") 30 | parser.add_argument('-k','--keep',action='store_true',default=False,help="Don't delete prototxt files") 31 | parser.add_argument('-n', '--number',action='store_true',default=False,help="if true uses caffemodel/input as is. if false uses all folds") 32 | parser.add_argument('--max_score',action='store_true',default=False,help="take max score per ligand as its score") 33 | parser.add_argument('--notcalc_predictions', type=str, default='',help='file of predictions') 34 | args = parser.parse_args() 35 | if args.output == '': 36 | output = 'bootstrap_%s_%s'%(args.model, args.input) 37 | else: 38 | output = args.output 39 | outname=output 40 | predictions=[] 41 | if args.notcalc_predictions=='': 42 | cm = args.weights 43 | ts = args.input 44 | if not args.number: 45 | foldnum = re.search('.[0-9]_iter',cm).group() 46 | cm=cm.replace(foldnum, '.[0-9]_iter') 47 | foldnum = re.search('[0-9].types',ts).group() 48 | ts=ts.replace(foldnum, '[NUMBER].types') 49 | 50 | for caffemodel in glob.glob(cm): 51 | testset = ts 52 | if not args.number: 53 | num = re.search('.[0-9]_iter',caffemodel).group() 54 | num=re.search(r'\d+', num).group() 55 | testset = ts.replace('[NUMBER]',num) 56 | args.input = testset 57 | args.weights = caffemodel 58 | predictions.extend(predict.predict_lines(args)) 59 | elif args.notcalc_predictions != '': 60 | for line in open(args.notcalc_predictions).readlines(): 61 | predictions.append(line) 62 | 63 | all_aucs=[] 64 | for _ in range(args.iterations): 65 | sample = np.random.choice(predictions,len(predictions), replace=True) 66 | all_aucs.append(calc_auc(sample)) 67 | mean=np.mean(all_aucs) 68 | std_dev = np.std(all_aucs) 69 | txt = 'mean: %.2f standard deviation: %.2f'%(mean,std_dev) 70 | print(txt) 71 | output = open(output, 'w') 72 | output.writelines('%.2f\n' %auc for auc in all_aucs) 73 | output.write(txt) 74 | output.close() 75 | 76 | plt.figure() 77 | plt.boxplot(all_aucs,0,'rs',0) 78 | plt.title('%s AUCs'%args.output, fontsize=22) 79 | plt.xlabel('AUC(%s)'%txt, fontsize=18) 80 | plt.savefig('%s_plot.pdf'%outname,bbox_inches='tight') 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /calccenters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Glob through files in current directory looking for */*_ligand.sdf and */*.gninatypes (assuming PDBbind layout). 4 | Calculate the distance between centers. 
If types files are passed, create versions with this information, 5 | optionally filtering. 6 | ''' 7 | 8 | import sys,glob,argparse,os 9 | import numpy as np 10 | import pybel 11 | import struct 12 | import openbabel 13 | 14 | openbabel.obErrorLog.StopLogging() 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument('typefiles',metavar='file',type=str, nargs='+',help='Types files to process') 19 | parser.add_argument('--filter',type=float,default=100.0,help='Filter out examples greater than the specified value') 20 | parser.add_argument('--suffix',type=str,default='_wc',help='Suffix for new types files') 21 | args = parser.parse_args() 22 | 23 | centerinfo = dict() 24 | #first process all gninatypes files in current directory tree 25 | for ligfile in glob.glob('*/*_ligand.sdf'): 26 | mol = next(pybel.readfile('sdf',ligfile)) 27 | #calc center 28 | center = np.mean([a.coords for a in mol.atoms],axis=0) 29 | dir = ligfile.split('/')[0] 30 | for gtypes in glob.glob('%s/*.gninatypes'%dir): 31 | buf = open(gtypes,'rb').read() 32 | n = len(buf)//4 #integer division - gninatypes files are packed 4-byte floats 33 | vals = np.array(struct.unpack('f'*n,buf)).reshape(n//4,4) 34 | lcenter = np.mean(vals,axis=0)[0:3] 35 | dist = np.linalg.norm(center-lcenter) 36 | centerinfo[gtypes] = dist 37 | 38 | for tfile in args.typefiles: 39 | fname,ext = os.path.splitext(tfile) 40 | outname = fname+args.suffix+ext 41 | out = open(outname,'w') 42 | for line in open(tfile): 43 | lfile = line.split('#')[0].split()[-1] 44 | if lfile not in centerinfo: 45 | print("Missing",lfile,tfile) 46 | sys.exit(0) 47 | else: 48 | d = centerinfo[lfile] 49 | if d < args.filter: 50 | out.write(line.rstrip()+" %f\n"%d) 51 | -------------------------------------------------------------------------------- /calctop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import os, sys 5 | import os.path 6 | sys.path.append("/home/dkoes/git/gninascripts/") 7 | sys.path.append("/net/pulsar/home/koes/dkoes/git/gninascripts/") 8 | 9 | import train, predict 10 | import matplotlib, caffe 11 | import matplotlib.pyplot as plt 12 | import glob, re, sklearn, collections, argparse, sys 13 | import sklearn.metrics 14 | import scipy.stats 15 | 16 | def evaluate_fold(testfile, caffemodel, modelname,root_folder): 17 | '''Evaluate the passed model and the specified test set. 18 | Assumes the .model file is named a certain way.
19 | Returns tuple: 20 | (correct, prediction, receptor, ligand, label (optional), posescore (optional)) 21 | label and posescore are only provided is trained on pose data 22 | ''' 23 | caffe.set_mode_gpu() 24 | test_model = ('predict.%d.prototxt' % os.getpid()) 25 | print(("test_model:" + test_model)) 26 | train.write_model_file(test_model, modelname, testfile, testfile, root_folder) 27 | test_net = caffe.Net(test_model, caffemodel, caffe.TEST) 28 | lines = open(testfile).readlines() 29 | 30 | res = None 31 | i = 0 #index in batch 32 | correct = 0 33 | prediction = 0 34 | receptor = '' 35 | ligand = '' 36 | label = 0 37 | posescore = -1 38 | ret = [] 39 | 40 | for line in lines: 41 | #check if we need a new batch of results 42 | if not res or i >= batch_size: 43 | res = test_net.forward() 44 | if 'output' in res: 45 | batch_size = res['output'].shape[0] 46 | else: 47 | batch_size = res['affout'].shape[0] 48 | i = 0 49 | 50 | if 'labelout' in res: 51 | label = float(res['labelout'][i]) 52 | if 'output' in res: 53 | posescore = float(res['output'][i][1]) 54 | if 'affout' in res: 55 | correct = float(res['affout'][i]) 56 | if 'predaff' in res: 57 | prediction = float(res['predaff'][i]) 58 | if not np.isfinite(prediction).all(): 59 | os.remove(test_model) 60 | return [] #gracefully handle nan? 61 | 62 | #extract ligand/receptor for input file 63 | tokens = line.split() 64 | linelabel = int(tokens[0]) 65 | for t in range(len(tokens)): 66 | if tokens[t].endswith('gninatypes'): 67 | receptor = tokens[t] 68 | ligand = tokens[t+1] 69 | break 70 | 71 | #(correct, prediction, receptor, ligand, label (optional), posescore (optional)) 72 | if posescore < 0: 73 | ret.append((correct, prediction, receptor, ligand)) 74 | else: 75 | ret.append((correct, prediction, receptor, ligand, label, posescore)) 76 | 77 | if int(label) != linelabel: #sanity check 78 | print("Mismatched labels in calctop:",(label,linelabel,correct, prediction, receptor, ligand)) 79 | sys.exit(-1) 80 | i += 1 #batch index 81 | 82 | os.remove(test_model) 83 | return ret 84 | 85 | def find_top_ligand(results, topnum): 86 | targets={} 87 | correct_poses=0 88 | ligands=[] 89 | 90 | for r in results: 91 | rec = r[2] 92 | if rec in targets: 93 | #negate the label so that ties are always broken unfavorably 94 | targets[rec].append((r[5], -r[4])) #posescore and label 95 | if r[5] == None: 96 | print(("Error: Posescore does not exist for "+r[2])) 97 | exit() 98 | else: 99 | targets[rec] = [(r[5], -r[4])] 100 | num_targets=len(targets) 101 | 102 | for t in targets: 103 | targets[t].sort() 104 | top_tuples = targets[t][-topnum:] 105 | for i in top_tuples: 106 | if i[1]: 107 | correct_poses += 1 108 | break 109 | 110 | percent = float(correct_poses)/float(num_targets)*100.0 111 | return percent 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('-m','--model',type=str,required=True,help='Model filename') 116 | parser.add_argument('-p','--prefix',type=str,required=True,help='Prefix for test files') 117 | parser.add_argument('-c','--caffemodel',type=str,required=True,help='Prefix for caffemodel file') 118 | parser.add_argument('-o','--output',type=str,required=True,help='Output filename') 119 | parser.add_argument('-f','--folds',type=int,default=3,help='Number of folds') 120 | parser.add_argument('-i','--iterations',type=int,default=0,help='Iterations in caffemodel filename') 121 | parser.add_argument('-t','--top',type=int,default=10,help='Number of top ligands to look at') 122 | 
parser.add_argument('-d','--data_root',type=str,required=False,help="Root folder for relative paths in train/test files",default='') 123 | 124 | args = parser.parse_args() 125 | 126 | iterations=args.iterations 127 | if iterations == 0: 128 | highest_iter=0 129 | for name in glob.glob('*.caffemodel'): 130 | nums=(re.findall('\d+', name )) 131 | new_iter=int(nums[-1]) 132 | if new_iter>highest_iter: 133 | highest_iter=new_iter 134 | iterations=highest_iter 135 | 136 | modelname = (args.model) 137 | output = (args.output) 138 | 139 | results=[] 140 | for f in range(args.folds): 141 | 142 | iterations = args.iterations 143 | if not iterations: 144 | #find highest _for this fold_ 145 | highest = 0 146 | for name in glob.glob('%s.%d_iter*.caffemodel'%(args.caffemodel,f)): 147 | inum = int(re.findall(r'\d+', name)[-1]) 148 | if inum > highest: 149 | highest = inum 150 | iterations = highest 151 | 152 | caffemodel='%s.%d_iter_%d.caffemodel' % (args.caffemodel, f, iterations) 153 | if (os.path.isfile(caffemodel) == False): 154 | print(('Error: Caffemodel %s does not exist. Check --caffemodel, --iterations, and --folds arguments.'%caffemodel)) 155 | testfile = (args.prefix + "train" + str(f) + ".types") 156 | results += evaluate_fold(testfile, caffemodel, modelname, args.data_root) 157 | 158 | file=open(output, "w") 159 | for i in range(1, args.top+1): 160 | top = find_top_ligand(results,i) 161 | file.write("Percent of targets that contain the correct pose in the top %d: %f\n"%(i,top)) 162 | file.close() 163 | 164 | -------------------------------------------------------------------------------- /cgo_arrow.py: -------------------------------------------------------------------------------- 1 | ''' 2 | http://pymolwiki.org/index.php/cgo_arrow 3 | 4 | (c) 2013 Thomas Holder, Schrodinger Inc. 5 | 6 | License: BSD-2-Clause 7 | ''' 8 | 9 | from pymol import cmd, cgo, CmdException 10 | 11 | 12 | def cgo_arrow(atom1='pk1', atom2='pk2', radius=0.5, gap=0.0, hlength=-1, hradius=-1, 13 | color='blue red', name='', state=0): 14 | ''' 15 | DESCRIPTION 16 | 17 | Create a CGO arrow between two picked atoms. 
18 | 19 | ARGUMENTS 20 | 21 | atom1 = string: single atom selection or list of 3 floats {default: pk1} 22 | 23 | atom2 = string: single atom selection or list of 3 floats {default: pk2} 24 | 25 | radius = float: arrow radius {default: 0.5} 26 | 27 | gap = float: gap between arrow tips and the two atoms {default: 0.0} 28 | 29 | hlength = float: length of head 30 | 31 | hradius = float: radius of head 32 | 33 | color = string: one or two color names {default: blue red} 34 | 35 | name = string: name of CGO object 36 | 37 | state = int: arrow state index 38 | ''' 39 | from chempy import cpv 40 | 41 | radius, gap = float(radius), float(gap) 42 | hlength, hradius = float(hlength), float(hradius) 43 | state = int(state) 44 | 45 | try: 46 | color1, color2 = color.split() 47 | except: 48 | color1 = color2 = color 49 | color1 = list(cmd.get_color_tuple(color1)) 50 | color2 = list(cmd.get_color_tuple(color2)) 51 | 52 | def get_coord(v): 53 | if not isinstance(v, str): 54 | return v 55 | if v.startswith('['): 56 | return cmd.safe_list_eval(v) 57 | return cmd.get_atom_coords(v) 58 | 59 | xyz1 = get_coord(atom1) 60 | xyz2 = get_coord(atom2) 61 | normal = cpv.normalize(cpv.sub(xyz1, xyz2)) 62 | 63 | if hlength < 0: 64 | hlength = radius * 3.0 65 | if hradius < 0: 66 | hradius = hlength * 0.6 67 | 68 | if gap: 69 | diff = cpv.scale(normal, gap) 70 | xyz1 = cpv.sub(xyz1, diff) 71 | xyz2 = cpv.add(xyz2, diff) 72 | 73 | xyz3 = cpv.add(cpv.scale(normal, hlength), xyz2) 74 | 75 | obj = [cgo.CONE] + xyz3 + xyz2 + [hradius, 0.0] + color2 + color2 + [1.0, 0.0] 76 | 77 | if cpv.distance(xyz1, xyz2) > hlength: # draw cylinder 78 | obj += [cgo.CYLINDER] + xyz1 + xyz3 + [radius] + color1 + color2 79 | 80 | if not name: 81 | name = cmd.get_unused_name('arrow') 82 | 83 | cmd.load_cgo(obj, name, state) 84 | 85 | cmd.extend('cgo_arrow', cgo_arrow) 86 | 87 | -------------------------------------------------------------------------------- /clean_kept_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Go through the files in the current directory and remove and caffemodel/solverstate 4 | files where there is a higher iteration available''' 5 | 6 | import glob,re,sys,collections,os 7 | 8 | prefixes = sys.argv[1:] 9 | if not prefixes: 10 | prefixes = ['.'] 11 | 12 | for dirname in prefixes: 13 | for suffix in ['caffemodel','solverstate','checkpoint','gen_model_state','gen_solver_state']: 14 | files = collections.defaultdict(list) 15 | for fname in glob.glob('%s/*.%s'%(dirname,suffix)): 16 | m = re.search('(.*)_iter_(\d+)\.%s'%suffix,fname) 17 | if m: 18 | prefix = m.group(1) 19 | i = int(m.group(2)) 20 | files[prefix].append((i,fname)) 21 | for (k,files) in list(files.items()): 22 | toremove = sorted(files,reverse=True)[1:] 23 | for (i,fname) in toremove: 24 | print (fname) 25 | os.remove(fname) 26 | -------------------------------------------------------------------------------- /combine_rows.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Combine the output of compute_rows.py into a pickle file for clustering.py''' 4 | 5 | import pickle, sys, collections 6 | import numpy as np 7 | 8 | target_names = [] 9 | targets = dict() # name to index 10 | values = collections.defaultdict(dict) # indexed by row name, col name 11 | 12 | for fname in sys.argv[1:]: 13 | for line in open(fname): 14 | (t1,t2,dist,lsim) = line.split() 15 | dist = float(dist) 16 | if t2 not in targets: 17 | 
targets[t2] = len(target_names) 18 | target_names.append(t2) 19 | values[t1][t2] = (dist,lsim) 20 | 21 | 22 | #must have fully filled out matrix 23 | l = len(target_names) 24 | m = np.empty((l,l)) 25 | lm = np.empty((l,l)) 26 | m[:] = np.NAN 27 | lm[:] = np.NAN 28 | 29 | for t1 in values.keys(): 30 | for t2 in values[t1].keys(): 31 | i = targets[t1] 32 | j = targets[t2] 33 | m[i][j] = values[t1][t2][0] 34 | lm[i][j] = values[t1][t2][1] 35 | 36 | #report any pair of targets that is still missing a distance 37 | # or ligand similarity (NaN entries in the matrices) 38 | for i in range(l): 39 | for j in range(l): 40 | if not np.isfinite(m[i][j]): 41 | print("Missing distance for",target_names[i],target_names[j]) 42 | 43 | if not np.isfinite(lm[i][j]): 44 | print("Missing ligand_sim for",target_names[i],target_names[j]) 45 | 46 | 47 | pickle.dump((m, target_names, lm), open('matrix.pickle','wb'),-1) 48 | -------------------------------------------------------------------------------- /combine_rows_lowmem.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse as ap 4 | import numpy as np 5 | import pandas as pd 6 | import pickle 7 | 8 | from tqdm import tqdm 9 | 10 | # Parse arguments 11 | parser = ap.ArgumentParser() 12 | parser.add_argument("files", nargs="+", type=str) 13 | parser.add_argument("-out", "--output", type=str) 14 | args = parser.parse_args() 15 | 16 | # Open first file to get targets 17 | print("Reading targets...", end="", flush=True) 18 | targets = np.loadtxt(args.files[0], usecols=1, dtype="U4") 19 | n_targets = len(targets) 20 | print("done") 21 | 22 | # Build DataFrames 23 | # The DataFrames are used to ensure that distances and ligand similarities are inserted at the correct place 24 | # Initializing the DF with a numpy array is essential for speed at assignment 25 | print("Allocating DataFrame memory...", end="", flush=True) 26 | df_dist = pd.DataFrame( 27 | index=targets, columns=targets, data=-1 * np.ones((n_targets, n_targets)) 28 | ) 29 | df_lsim = pd.DataFrame( 30 | index=targets, columns=targets, data=-1 * np.ones((n_targets, n_targets)) 31 | ) 32 | print("done") 33 | 34 | print("Merging data...", flush=True) 35 | for fname in tqdm(args.files): 36 | target = np.loadtxt(fname, usecols=0, dtype="U4")[0] 37 | ctargets = np.loadtxt(fname, usecols=1, dtype="U4") 38 | dist = np.loadtxt(fname, usecols=2) 39 | lsim = np.loadtxt(fname, usecols=3) 40 | 41 | # Populate distance matrix 42 | if len(dist) == n_targets: 43 | df_dist.loc[target, ctargets] = dist 44 | else: 45 | print(f" Invalid number of distances for {target}") 46 | 47 | # Populate ligand similarity matrix 48 | if len(lsim) == n_targets: 49 | df_lsim.loc[target, ctargets] = lsim 50 | else: 51 | print(f" Invalid number of ligand similarities for {target}") 52 | 53 | dist = df_dist.values 54 | lsim = df_lsim.values 55 | 56 | # Check properties 57 | print("Checking matrix properties...", flush=True) 58 | ddist, dlsim = np.diagonal(dist), np.diagonal(lsim) 59 | assert int(round(np.sum(ddist))) == int(round(np.sum(ddist[ddist < 0]))) 60 | assert int(round(np.sum(dlsim[dlsim >= 0]))) - int(round(np.sum(dlsim[dlsim < 0]))) == n_targets 61 | print("done") 62 | 63 | # Set NaNs for compatibility with original implementation 64 | dist[dist < 0] = np.nan 65 | lsim[lsim < 0] = np.nan 66 | 67 | print("Checking data...", flush=True) 68 | rows, cols = np.where(np.isnan(dist)) # Invalid distances 69 | for t1, t2 in zip(df_dist.index.values[rows], df_dist.columns.values[cols]): 70 | print(f" 
Missing distance for {t1} {t2}") 71 | rows, cols = np.where(np.isnan(lsim)) # Invalid ligand similarities 72 | for t1, t2 in zip(df_dist.index.values[rows], df_dist.columns.values[cols]): 73 | print(f" Missing ligand similarity for {t1} {t2}") 74 | 75 | print(f"Dumping pickle object {args.output}...", end="", flush=True) 76 | pickle.dump((dist, targets, lsim), open(f"{args.output}", "wb"), -1) 77 | print("done") 78 | -------------------------------------------------------------------------------- /compute_row.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Compute a single row of a distance matrix from a pdbinfo file. 4 | This allows for distributed processing''' 5 | 6 | import clustering,argparse,sys 7 | from rdkit.Chem import AllChem as Chem 8 | from rdkit.Chem import AllChem 9 | from rdkit.DataStructs import FingerprintSimilarity as fs 10 | from rdkit.Chem.Fingerprints import FingerprintMols 11 | 12 | def compute_ligand_similarity(smiles, pair): 13 | ''' 14 | Input a list of smiles, and a pair to compute the similarity. 15 | Returns the indices of the pair and the similarity 16 | ''' 17 | 18 | (a,b) = pair 19 | smi_a = smiles[a] 20 | mol_a = AllChem.MolFromSmiles(smi_a) 21 | if mol_a == None: 22 | mol_a = AllChem.MolFromSmiles(smi_a, sanitize=False) 23 | fp_a = FingerprintMols.FingerprintMol(mol_a) 24 | 25 | smi_b = smiles[b] 26 | mol_b = AllChem.MolFromSmiles(smi_b) 27 | if mol_b == None: 28 | mol_b = AllChem.MolFromSmiles(smi_b, sanitize=False) 29 | fp_b = FingerprintMols.FingerprintMol(mol_b) 30 | 31 | sim=fs(fp_a, fp_b) 32 | 33 | return a, b, sim 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser(description='Compute a single row of a distance matrix and ligand similarity matrix from a pdbinfo file.') 38 | parser.add_argument('--pdbseqs',type=str,required=True,help="file with target names, ligand smile, and sequences (chains separated by space)") 39 | parser.add_argument('-r','--row',type=int,required=True,help="row to compute") 40 | parser.add_argument('--out',help='output file (default stdout)',type=argparse.FileType('w'),default=sys.stdout) 41 | 42 | 43 | args = parser.parse_args() 44 | 45 | target_names = [] 46 | targets = [] 47 | smiles = [] 48 | for line in open(args.pdbseqs): 49 | toks = line.rstrip().split() 50 | target_names.append(toks[0]) 51 | smiles.append(toks[1]) 52 | targets.append(toks[2:]) 53 | 54 | r = args.row 55 | if r < len(target_names): 56 | name = target_names[r] 57 | row = [] 58 | for i in range(len(target_names)): 59 | print(target_names[i]) 60 | (a, b, mindist) = clustering.cUTDM2(targets, (r,i)) 61 | (la, lb, lig_sim) = compute_ligand_similarity(smiles, (r,i)) 62 | #sanity checks 63 | assert a == la 64 | assert b == lb 65 | row.append((target_names[i], mindist, lig_sim)) 66 | #output somewhat verbosely 67 | for (n, dist, lsim) in row: 68 | args.out.write('%s %s %f %f\n'%(name, n, dist, lsim)) 69 | else: 70 | print("Invalid row",r,"with only",len(target_names),"targets") 71 | -------------------------------------------------------------------------------- /compute_seqs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Given a pdbinfo file output sequence information for each chain''' 4 | 5 | import clustering,argparse,sys 6 | 7 | def get_smiles(target_names, input_file): 8 | ''' 9 | Returns a list of each of the smiles (3rd col) of input_file, indexed by target_names. 
10 | ''' 11 | 12 | smi_dic={} 13 | smi_list=[] 14 | with open(input_file) as filein: 15 | for line in filein: 16 | name=line.split()[0] 17 | smi_file=line.split()[2].rstrip() 18 | smi=open(smi_file).readline().split()[0] 19 | smi_dic[name]=smi 20 | 21 | for tname in target_names: 22 | smi_list.append(smi_dic[tname]) 23 | 24 | return smi_list 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser(description='Output the needed input for compute_row. This takes the format of " " separated by spaces') 29 | parser.add_argument('--pdbfiles',type=str,required=True,help="file with target names, paths to pbdfiles of targets, and path to smiles file of ligand (separated by space)") 30 | parser.add_argument('--out',help='output file (default stdout)',type=argparse.FileType('w'),default=sys.stdout) 31 | 32 | 33 | args = parser.parse_args() 34 | 35 | (target_names,targets) = clustering.readPDBfiles(args.pdbfiles) 36 | target_smiles = get_smiles(target_names, args.pdbfiles) 37 | 38 | for (name, target, smi) in zip(target_names, targets, target_smiles): 39 | args.out.write('%s %s %s\n'%(name, smi, ' '.join(target))) 40 | -------------------------------------------------------------------------------- /counterexample_generation_jobs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | ''' 4 | This is a script which will generate a file of commands for gnina to use cnn_minimze to generate iterative training poses. 5 | 6 | ASSUMPTIONS 7 | i) assumes all receptors are PDB files IE end in .pdb 8 | ii) Assumes all docked poses or outputs from gnina will be SDF files. 9 | iii) The crystal ligand filenames are formatted PDBid_LignameLIGSUFFIX 10 | iv) assumes file format is ROOT/POCKET/FILES 11 | v) Will generate a line for every identified crystal ligand with every identified receptor in POCKET -- i.e. crossdocking. 12 | vi) Assumes ligands will have the name of their corresponding crystal ligand file present in their filename. (This is especially important is using docked poses.) 13 | vii) Will generate REC_LIG_lig_it#_docked.sdf files as output. (If using docked poses as well, they will have their name will have extra _it#_ parts in it, the current it# will be the leftmost one) 14 | ''' 15 | 16 | 17 | import os, argparse, glob, re 18 | 19 | def get_receptors(root,rec_id): 20 | all_pdbs=glob.glob(root+'*.pdb') 21 | identifier=re.compile(rec_id) 22 | recs=[x for x in all_pdbs if re.match(identifier,x.split('/')[-1])] 23 | return recs 24 | 25 | def get_ligands(root,lig_suffix): 26 | all_ligs=glob.glob(root+'*'+lig_suffix) 27 | return all_ligs 28 | 29 | def generate_line(receptor,ligand,outname,crystal_ligand,seed,num_modes,builtin_cnn,supplied_cnn=None,supplied_weights=None): 30 | if bool(supplied_cnn) and bool(supplied_weights): 31 | return(f'gnina -r {receptor} -l {ligand} -o {outname} --autobox_ligand {crystal_ligand} --seed {seed} --gpu --minimize --cnn_scoring refinement --num_modes {num_modes} --cnn_model {supplied_cnn} --cnn_weights {supplied_weights}\n') 32 | else: 33 | return(f'gnina -r {receptor} -l {ligand} -o {outname} --autobox_ligand {crystal_ligand} --seed {seed} --gpu --minimize --cnn_scoring refinement --num_modes {num_modes} --cnn {builtin_cnn}\n') 34 | 35 | #grabbing the arguments 36 | parser=argparse.ArgumentParser(description='Create cnn_minimize jobs for a dataset. 
Assumes dataset file structure is //') 37 | parser.add_argument('-o','--outfile',type=str,required=True,help='Name for gnina job commands output file.') 38 | parser.add_argument('-r','--root',default='./',help='ROOT for data directory structure. Defaults to current working directory.') 39 | parser.add_argument('-ri','--rec_id',default='...._._rec.pdb',help='Regular expression to identify the receptor PDB. Defaults to ...._._rec.pdb') 40 | parser.add_argument('-cs','--crystal_suffix',default='_lig.pdb',help='Expresssion to glob the crystal ligand PDB. Defaults to _lig.pdb. Assumes filename is PDBid_LignameLIGSUFFIX') 41 | parser.add_argument('-ds','--docked_suffix',default='_tt_docked.sdf',help='Expression to glob docked poses. These contain the poses that need to be minimized. Default is "_tt_docked.sdf"') 42 | parser.add_argument('-i','--iteration',type=int,required=True,help='Sets what iteration number we are doing. Adds _it#_docked.sdf to the output file for the gnina job line.') 43 | parser.add_argument('--num_modes',type=int,default=20,help='Sets the --num_modes argument for the gnina command. Defaults to 20.') 44 | parser.add_argument('--cnn',type=str, default='dense',help='Sets the --cnn command for the gnina command. Defaults to dense. Must be dense, general_default2018, or crossdock_default2018.') 45 | parser.add_argument('--cnn_model',type=str,default=None,help='Override --cnn with a user provided caffe model file. If used, requires the user to pass in a weights file as well.') 46 | parser.add_argument('--cnn_weights',type=str,default=None,help='The weights file to use with the supplied caffemodel file.') 47 | parser.add_argument('--seed',default=42,type=int,help='Seed for the gnina commands. Defaults to 42') 48 | parser.add_argument('--dirs',type=str,default=None,help='Supplied file containing a subset of the dataset (one pocket per line). Default behavior is to do every directory.') 49 | args=parser.parse_args() 50 | 51 | #double checking that the arguments are compatible 52 | if args.cnn_model: 53 | assert bool(args.cnn_weights),"Didn't set cnn_weights to go with cnn_model" 54 | else: 55 | assert args.cnn in set(['dense','general_default2018','crossdock_default2018']),"Must have built-in cnn be dense, general_default2018, or crossdock_default2018" 56 | assert args.num_modes>1,"Need to set num_modes to a positive integer." 57 | assert args.seed>0,"Need a positive seed." 58 | assert args.iteration>0,"Need an iteration number >=1." 59 | 60 | 61 | #now we begin. 62 | #Step 1 -- assemble all of the directories that we will be using. 
63 | dataroot=os.path.join(args.root,'') 64 | todo=glob.glob(dataroot+'*/') 65 | 66 | if args.dirs: 67 | subdirs=open(args.dirs).readlines() 68 | subdirs=[x.rstrip() for x in subdirs] 69 | subdirs=set(subdirs) 70 | todo=[x for x in todo if x.split('/')[-2] in subdirs] 71 | 72 | #Step 2 -- main loop of the script 73 | #set the iteration plugin variable 74 | itname='_it'+str(args.iteration) 75 | 76 | # We loop over the pockets 77 | #TODO -- change to only do the docked poses 78 | with open(args.outfile,'w') as outfile: 79 | for pocket_root in todo: 80 | #grab the receptors 81 | recs=get_receptors(pocket_root,args.rec_id) 82 | 83 | #grab all of the crystal ligands 84 | cr_ligs=get_ligands(pocket_root,args.crystal_suffix) 85 | 86 | #Grab all of the docked poses 87 | ligs=get_ligands(pocket_root,args.docked_suffix) 88 | for r in recs: 89 | for cl in cr_ligs: 90 | #determine which ligands will work -- IE which ligands have the crystal ligand identifier in their name, and which ligands have the receptor in their name. 91 | lig_todo=[l for l in ligs if cl.split('/')[-1].split(args.crystal_suffix)[0] in l] 92 | lig_todo=[l for l in lig_todo if r.split('/')[-1].split('.pdb')[0] in l] 93 | for ligname in lig_todo: 94 | #generate the output filename 95 | #if args.docked_suffix and args.docked_suffix in ligname: 96 | outname=ligname.replace(args.docked_suffix,itname+args.docked_suffix) 97 | #else: 98 | # rec_part=r.split('.pdb')[0]+'_' 99 | # lig_part=ligname.split('/')[-1].split(args.crystal_suffix)[0] 100 | # outname=rec_part+lig_part+'_lig_'+itname+'docked.sdf' 101 | 102 | outfile.write(generate_line(receptor=r,ligand=ligname,outname=outname,crystal_ligand=cl,seed=args.seed,num_modes=args.num_modes,builtin_cnn=args.cnn,supplied_cnn=args.cnn_model,supplied_weights=args.cnn_weights)) 103 | 104 | -------------------------------------------------------------------------------- /create_caches.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Takes a bunch of types training files. First argument is what index the receptor starts on 4 | (ligand is assumed to be right after).
Reads in the gninatypes files specified in these types 5 | files and writes out two monolithic receptor and ligand cache files for use with recmolcache 6 | and ligmolcache molgrid options''' 7 | 8 | import os, sys 9 | import struct, argparse 10 | 11 | def writemol(root, mol, out): 12 | '''mol is gninatypes file, write it in the appropriate binary format to out''' 13 | fname = root+'/'+mol 14 | try: 15 | with open(fname,'rb') as gninatype: 16 | if len(fname) > 255: 17 | print("Skipping",mol,"since filename is too long") 18 | return 19 | s = bytes(mol, encoding='UTF-8') 20 | out.write(struct.pack('b',len(s))) 21 | out.write(s) 22 | data = gninatype.read() 23 | assert(len(data) % 16 == 0) 24 | natoms = len(data)//16 25 | out.write(struct.pack('i',natoms)) 26 | out.write(data) 27 | except Exception as e: 28 | print(mol) 29 | print(e) 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('-c', '--col', required=True,type=int,help='Column receptor starts on') 33 | parser.add_argument('--recmolcache', default='rec.molcache',type=str,help='Filename of receptor cache') 34 | parser.add_argument('--ligmolcache', default='lig.molcache',type=str,help='Filename of ligand cache') 35 | parser.add_argument('-d','--data_root',type=str,required=False,help="Root folder for relative paths in train/test files",default='') 36 | parser.add_argument('fnames',nargs='+',type=str,help='types files to process') 37 | 38 | args = parser.parse_args() 39 | 40 | recout = open(args.recmolcache,'wb') 41 | ligout = open(args.ligmolcache,'wb') 42 | 43 | seenlig = set() 44 | seenrec = set() 45 | for fname in args.fnames: 46 | for line in open(fname): 47 | vals = line.split() 48 | rec = vals[args.col] 49 | ligs = vals[args.col+1:] 50 | 51 | if rec not in seenrec: 52 | seenrec.add(rec) 53 | writemol(args.data_root, rec, recout) 54 | 55 | for lig in ligs: 56 | if lig == '#': 57 | break 58 | if lig not in seenlig: 59 | seenlig.add(lig) 60 | writemol(args.data_root, lig, ligout) 61 | 62 | -------------------------------------------------------------------------------- /create_caches2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''Takes a bunch of types training files. First argument is what index the receptor starts on 4 | (ligands are assumed to be right after). Reads in the gninatypes files specified in these types 5 | files and writes out two monolithic receptor and ligand cache files in version 2 format. 6 | 7 | Version 2 is optimized for memory mapped storage of caches. keys (file names) are stored 8 | first followed by dense storage of values (coordinates and types).
9 | ''' 10 | 11 | import os, sys 12 | import struct, argparse, traceback 13 | import multiprocessing 14 | 15 | mols_to_read = multiprocessing.Queue() 16 | mols_to_write = multiprocessing.Queue() 17 | N = multiprocessing.cpu_count()*2 18 | 19 | def read_data(data_root): 20 | '''read a types file and put it in mols_to_write''' 21 | while True: 22 | sys.stdout.flush() 23 | mol = mols_to_read.get() 24 | if mol == None: 25 | break 26 | fname = mol 27 | if len(data_root): 28 | fname = data_root+'/'+mol 29 | try: 30 | with open(fname,'rb') as gninatype: 31 | data = gninatype.read() 32 | assert(len(data) % 16 == 0) 33 | if len(data) == 0: 34 | print(fname,"EMPTY") 35 | else: 36 | mols_to_write.put((mol,data)) 37 | except Exception as e: 38 | print(fname) 39 | print(e) 40 | mols_to_write.put(None) 41 | 42 | def fill_queue(molfiles): 43 | 'thread for filling mols_to_read' 44 | for mol in molfiles: 45 | mols_to_read.put(mol) 46 | for _ in range(N): 47 | mols_to_read.put(None) 48 | 49 | def create_cache2(molfiles, data_root, outfile): 50 | '''Create an outfile molcache2 file from the list molfiles stored at data_root.''' 51 | out = open(outfile,'wb') 52 | #first byte is for versioning 53 | out.write(struct.pack('i',-1)) 54 | out.write(struct.pack('L',0)) #placeholder for offset to keys 55 | 56 | filler = multiprocessing.Process(target=fill_queue,args=(molfiles,)) 57 | filler.start() 58 | 59 | 60 | readers = multiprocessing.Pool(N) 61 | for _ in range(N): 62 | readers.apply_async(read_data,(data_root,)) 63 | 64 | offsets = dict() #indxed by mol, location of data 65 | #start writing molecular data 66 | endcnt = 0 67 | while True: 68 | moldata = mols_to_write.get() 69 | if moldata == None: 70 | endcnt += 1 71 | if endcnt == N: 72 | break 73 | else: 74 | continue 75 | (mol,data) = moldata 76 | offsets[mol] = out.tell() 77 | natoms = len(data)//16 78 | out.write(struct.pack('i',natoms)) 79 | out.write(data) 80 | 81 | start = out.tell() #where the names start 82 | for mol in molfiles: 83 | if len(mol) > 255: 84 | print("Skipping",mol,"since filename is too long") 85 | continue 86 | if mol not in offsets: 87 | print("SKIPPING",mol,"since failed to read it in") 88 | continue 89 | s = bytes(mol, encoding='UTF-8') 90 | out.write(struct.pack('B',len(s))) 91 | out.write(s) 92 | out.write(struct.pack('L',offsets[mol])) 93 | 94 | #now set start 95 | out.seek(4) 96 | out.write(struct.pack('L',start)) 97 | out.seek(0,os.SEEK_END) 98 | out.close() 99 | 100 | 101 | 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('-c', '--col', required=True,type=int,help='Column receptor starts on') 104 | parser.add_argument('--recmolcache', default='rec.molcache2',type=str,help='Filename of receptor cache') 105 | parser.add_argument('--ligmolcache', default='lig.molcache2',type=str,help='Filename of ligand cache') 106 | parser.add_argument('-d','--data_root',type=str,required=False,help="Root folder for relative paths in train/test files",default='') 107 | parser.add_argument('fnames',nargs='+',type=str,help='types files to process') 108 | 109 | args = parser.parse_args() 110 | 111 | #load all file names into memory 112 | seenlig = set() 113 | seenrec = set() 114 | for fname in args.fnames: 115 | for line in open(fname): 116 | vals = line.split() 117 | rec = vals[args.col] 118 | ligs = vals[args.col+1:] 119 | 120 | if rec not in seenrec: 121 | seenrec.add(rec) 122 | 123 | for lig in ligs: 124 | if lig == '#' or lig.startswith('#'): 125 | break 126 | if lig not in seenlig: 127 | seenlig.add(lig) 128 | 129 | 
create_cache2(sorted(list(seenrec)), args.data_root, args.recmolcache) 130 | create_cache2(sorted(list(seenlig)), args.data_root, args.ligmolcache) 131 | 132 | 133 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | -------------------------------------------------------------------------------- /docs/Using_the_Cluster.md: -------------------------------------------------------------------------------- 1 | # Using the Cluster 2 | 3 | Requires account on Dvorak. Assumes knowledge of navigating system directories and the following terminal commands: cd, ls, pwd, mkdir, mv. 4 | 5 | #### GitHub Gnina Repository: https://github.com/gnina 6 | _____ 7 | **1) Logging onto cluster:** 8 | 9 | `ssh username@gpu.csb.pitt.edu` 10 | To exit cluster/server, type `exit` 11 | 12 | If not in lab, SSH into dvorak (username@dvorak.csb.pitt.edu) first, and from there to the gpu cluster. VPN tool required (e.g. Pulse Secure). 13 | 14 | **2) Read through the README file located on the cluster.** 15 | 16 | **3) Place all necessary files (model, data, python scripts, pbs script) into a single directory on local machine.** 17 | 18 | * Model: https://github.com/gnina/models 19 | * File normally ends in .model or .prototxt 20 | * Scripts: https://github.com/gnina/scripts 21 | * Ensure that all files ending in .py are in your directory. 22 | * Data: Refer to Dr. Koes 23 | https://github.com/gnina/models/tree/master/data 24 | 25 | **4) In the model file, in all layers of type “MolGridData” change the root folder to:** 26 | "/net/pulsar/home/koes/dkoes/PDBbind/refined-set/" 27 | 28 | **5) Install python dependencies.** 29 | ``` 30 | pip install --user -I numpy scipy sklearn scikit-image google protobuf psutil Pillow 31 | ``` 32 | 33 | **6) Include the following exports in your PBS script (before the python command):** 34 | Refer to provided pbs script for a complete template. 35 | 36 | 37 | ``` 38 | export PATH=/usr/bin:$PATH 39 | export LD_LIBRARY_PATH=/net/pulsar/home/koes/dkoes/local/lib:/usr/lib64:/usr/lib/x86_64-linux-gnu:/usr/local/cuda-9.0/lib64 40 | export PYTHONPATH=/net/pulsar/home/koes/dkoes/local/python:$PYTHONPATH 41 | ``` 42 | 43 | **7) Copy working directory onto server/back to local machine (scp command):** 44 | ``` 45 | scp -r ~/Desktop/test_folder username@gpu.csb.pitt.edu:~ 46 | scp -r test_folder username@perigee/apogee.csb.pitt.edu:~/Desktop 47 | ``` 48 | 49 | **8) Test on cluster nodes, not head node:** 50 | Launch job with `qsub script.pbs` from directory with required files. Use `qstat -au username` to check job status. 51 | 52 | #### Do NOT run python directly in terminal after ssh. 53 | 54 | _____ 55 | 56 | ## Troubleshooting 57 | 58 | Use `cat` to read output file (located in folder `qsub` was run in) and `pip` to manually install missing python packages (e.g. numpy): 59 | ``` 60 | pip install -I --user [package] 61 | ``` 62 | 63 | Launch an interactive `qsub` session to get a commandline on a cluster node: 64 | ``` 65 | qsub -I -l nodes=1:ppn=1:gpus=1 -q dept_gpu 66 | ``` 67 | 68 | _____ 69 | 70 | ## Quick Tips 71 | **Viewing files in terminal:** 72 | ``` 73 | cat /path/to/file 74 | ``` 75 | **Editing files in terminal:** 76 | ``` 77 | vi /path/to/file 78 | ``` 79 | 80 | #### Vi Basics 81 | Default is command mode. 
82 | * `x` to delete character under cursor 83 | * `v` to start selection (for copy/cut operation) 84 | * Move cursor to select, then `y` to copy or `x` to cut 85 | * Position cursor, then `p` to paste 86 | * Save & exit `:wq` (MUST BE IN COMMAND MODE) 87 | * Exit `:q!` 88 | 89 | To enter insert mode press **i** (to type normally), and **Esc** to go back to command mode. 90 | 91 | -------------------------------------------------------------------------------- /generate_unique_lig_poses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | This script exists to generate the unique ligand sdf files in a given pocket, by grabbing all of the existing docked poses 4 | and then calculating which of them are less than a certain RMSD threshold between one another. 5 | 6 | Assumptions: 7 | i) The crystal ligand files are named _ 8 | ii) the directory structure of the data is // 9 | iii) you have obrms installed and accessible from the commandline 10 | 11 | Input: 12 | i) the Pocket directory you are working on 13 | ii) the root of the pocket directories 14 | iii) the suffix for the docked poses 15 | iv) the suffix for the crystal poses 16 | v) the desired suffix for the "unique pose sdf" 17 | vi) the threshold RMSD to determine unique poses 18 | 19 | Output: 20 | i) a file for each ligand in the pocket containing the unique poses for that ligand. 21 | ''' 22 | 23 | import argparse, subprocess, glob, re, os 24 | from rdkit.Chem import AllChem as Chem 25 | import pandas as pd 26 | 27 | def check_exists(filename): 28 | if os.path.isfile(filename) and os.path.getsize(filename)>0: 29 | return True 30 | else: 31 | return False 32 | 33 | def run_obrms_cross(filename): 34 | ''' 35 | This function returns a pandas dataframe of the RMSD between every pose and every other pose, which is generated using obrms -x 36 | ''' 37 | 38 | csv=subprocess.check_output('obrms -x '+filename,shell=True) 39 | csv=str(csv,'utf-8').rstrip().split('\n') 40 | data=pd.DataFrame([x.split(',')[1:] for x in csv],dtype=float) 41 | return data 42 | 43 | parser=argparse.ArgumentParser(description='Create ligname files for use with generate_counterexample_typeslines.py.') 44 | parser.add_argument('-p','--pocket',type=str,required=True,help='Name of the pocket that you will be generating the file for.') 45 | parser.add_argument('-r','--root',type=str,required=True,help='PATH to the ROOT of the pockets.') 46 | parser.add_argument('-ds','--docked_suffix',default='_tt_docked.sdf', help='Expression to glob docked poses. These contain the poses that need to be uniqified. Default is "_tt_docked.sdf"') 47 | parser.add_argument('-cs','--crystal_suffix',default='_lig.pdb', help='Expression to glob the crystal ligands. Default is "_lig.pdb"') 48 | parser.add_argument('-os','--out_suffix',required=True,help='End of the filename for LIGNAME. This will be the --old_unique_suffix for generate_counterexample_typeslines.py.') 49 | parser.add_argument('--unique_threshold',default=0.25,help='RMSD threshold for unique poses. IE poses with RMSD > thresh are considered unique. 
Defaults to 0.25.') 50 | args=parser.parse_args() 51 | 52 | assert args.unique_threshold >0, "Unique RMSD threshold needs to be positive" 53 | 54 | #setting the myroot variable 55 | myroot=os.path.join(args.root,args.pocket,'') 56 | 57 | 58 | #1) gather the crystal files & pull out the crystal names present in the pocket 59 | crystal_files=glob.glob(myroot+'*'+args.crystal_suffix) 60 | crystal_names=set([x.split('/')[-1].split(args.crystal_suffix)[0].split('_')[1] for x in crystal_files]) 61 | 62 | #2) main loop 63 | for cr_name in crystal_names: 64 | if cr_name!='iqz': 65 | continue 66 | print(cr_name) 67 | #i) grab all of the docked files 68 | docked_files=glob.glob(myroot+'*_'+cr_name+'_*'+args.docked_suffix) 69 | print(docked_files) 70 | 71 | #ii) make sure that the "working sdf file" does not exist 72 | sdf_name=myroot+'___.sdf' 73 | if check_exists(sdf_name): 74 | os.remove(sdf_name) 75 | 76 | #iii) write all of the previously docked poses into the 'working sdf file' 77 | w=Chem.SDWriter(sdf_name) 78 | for file in docked_files: 79 | supply=Chem.SDMolSupplier(file,sanitize=False) 80 | for mol in supply: 81 | w.write(mol) 82 | w=None 83 | 84 | #iv) run obrms cross to calculate the RMSD between every pair of poses 85 | unique_data=run_obrms_cross(sdf_name) 86 | 87 | #v) determine the "unique poses" 88 | assignments={} 89 | for (r,row) in unique_data.iterrows(): 90 | if r not in assignments: 91 | for simi in row[row0: 27 | return True 28 | else: 29 | return False 30 | 31 | 32 | def get_atoms(filename): 33 | ''' 34 | Function that reads the atom types from filename & returns them as a list 35 | ''' 36 | listo=[] 37 | with open(filename) as infile: 38 | for line in infile: 39 | item=line.rstrip() 40 | listo.append(item) 41 | 42 | return listo 43 | 44 | def make_points(atom,val_range,root,mapping): 45 | ''' 46 | Function that makes the points needed for the types file. 
47 | ''' 48 | 49 | if not os.path.isdir(root+atom): 50 | os.mkdir(root+atom) 51 | 52 | counter=0 53 | for x in val_range: 54 | for y in val_range: 55 | for z in val_range: 56 | pos=[x,y,z] 57 | pos=struct.pack('f'*len(pos),*pos) 58 | identity=[mapping] 59 | identity=struct.pack('i'*len(identity),*identity) 60 | with open(root+atom+'/'+atom+'_'+str(counter)+'.gninatypes','wb') as f: 61 | f.write(pos) 62 | f.write(identity) 63 | counter+=1 64 | 65 | def make_types(atom, root, receptor): 66 | ''' 67 | Function that writes a types file for all the points created from make_points in root 68 | 69 | Returns the name of the file 70 | ''' 71 | def atoi(text): 72 | return int(text) if text.isdigit() else text 73 | 74 | def natural_keys(text): 75 | return [ atoi(c) for c in re.split(r'(\d+)', text) ] 76 | 77 | gninatypes=glob.glob(root+atom+'/'+atom+'*.gninatypes') 78 | gninatypes.sort(key=natural_keys) 79 | filename=root+receptor.split('_0.gnina')[0]+'_'+atom+'.types' 80 | with open(filename,'w') as out: 81 | for g in gninatypes: 82 | out.write('1 3.0 0.00 '+receptor+' '+g+'\n') 83 | 84 | return filename 85 | 86 | def make_dx(filename, num_on_axis, min_point, val_delta): 87 | ''' 88 | Function that takes the filename IE output of jobs, and makes a dx file from the results for visualization 89 | ''' 90 | 91 | with open(filename) as fin: 92 | data=fin.readlines() 93 | 94 | if len(data) == 0: 95 | return None,None 96 | l=filename.split('_predictscores')[0] 97 | 98 | pattern=re.compile("^[0-9]") 99 | data=[float(x.split()[0]) for x in data if pattern.match(x)] 100 | scores=np.array(data) 101 | dxdata=scores.reshape(num_on_axis,num_on_axis,num_on_axis) 102 | test=dxdata.round(4) 103 | g=gridData.Grid(dxdata,origin=min_point, delta=val_delta) 104 | g.export(l+"grid","DX") 105 | return dxdata,test 106 | 107 | def gninatyper(pdbfilename): 108 | ''' 109 | Function that takes in a pdbfile and converts it to a gninatypes file via gninatyper 110 | 111 | Returns 1 on failed gninatyper 112 | Returns newfilename on success. 113 | ''' 114 | 115 | newname=pdbfilename.split('.')[0] 116 | 117 | try: 118 | subprocess.call('gninatyper '+pdbfilename+' '+newname,shell=True) 119 | except: 120 | return 1 121 | 122 | return newname+'_0.gninatypes' 123 | 124 | if __name__=='__main__': 125 | args=parse_args() 126 | 127 | #perform arguments check to terminate early? 128 | 129 | #sanitize inputs 130 | if not os.path.isdir(args.typesroot): 131 | os.mkdir(args.typesroot) 132 | 133 | if not os.path.isdir(args.dataroot): 134 | os.mkdir(args.dataroot) 135 | 136 | if not os.path.isdir(args.dataroot) and args.make_dx: 137 | print('Error! Specified plotting, but the dataroot does not exist!') 138 | print('Could not find the directory: '+args.dataroot) 139 | sys.exit() 140 | 141 | if not path_checker(args.recatoms) or not path_checker(args.ligatoms): 142 | print('Error!') 143 | print('Could not locate either: '+args.recatoms+' or '+args.ligatoms) 144 | sys.exit() 145 | 146 | if not path_checker(args.model): 147 | print('Error!') 148 | print('Could not locate: '+args.model) 149 | sys.exit() 150 | 151 | if not path_checker(args.weights): 152 | print('Error!') 153 | print('Could not locate: '+args.weights) 154 | sys.exit() 155 | 156 | if not path_checker(args.test_pdb): 157 | print('Error!') 158 | print('Could not locate: '+args.test_pdb) 159 | sys.exit() 160 | 161 | #Now we are ready to start the program! 162 | 163 | #making atom mapping -- BLAH hardcoded. 
Not sure if this is changing, but is critical to functionality 164 | inv_map = { 165 | 'Hydrogen':0, 166 | 'PolarHydrogen':1, 167 | 'AliphaticCarbonXSHydrophobe':2 , 168 | 'AliphaticCarbonXSNonHydrophobe':3 , 169 | 'AromaticCarbonXSHydrophobe':4 , 170 | 'AromaticCarbonXSNonHydrophobe':5 , 171 | 'Nitrogen':6, 172 | 'NitrogenXSDonor':7, 173 | 'NitrogenXSDonorAcceptor':8, 174 | 'NitrogenXSAcceptor':9, 175 | 'Oxygen':10, 176 | 'OxygenXSDonor':11, 177 | 'OxygenXSDonorAcceptor':12, 178 | 'OxygenXSAcceptor':13, 179 | 'Sulfur':14, 180 | 'SulfurAcceptor':15, 181 | 'Phosphorus':16, 182 | 'Fluorine':17, 183 | 'Chlorine':18, 184 | 'Bromine':19, 185 | 'Iodine':20, 186 | 'Magnesium':21, 187 | 'Manganese':22, 188 | 'Zinc':23, 189 | 'Calcium':24, 190 | 'Iron':25, 191 | 'GenericMetal':26, 192 | 'Boron':27, 193 | } 194 | 195 | #now we need to figure out which atom types we are working with 196 | lig_atoms=get_atoms(args.ligatoms) 197 | rec_atoms=get_atoms(args.recatoms) 198 | todo=list(set(lig_atoms+rec_atoms)) 199 | 200 | #making sure that the roots are formatted appropriately 201 | types_root=args.typesroot 202 | if types_root[-1]!='/': 203 | types_root+='/' 204 | dataroot=args.dataroot 205 | if dataroot[-1]!='/': 206 | dataroot+='/' 207 | 208 | prefix=args.test_pdb.split('/')[-1].split('.pdb')[0] 209 | mprefix=args.model.split('/')[-1].split('.model')[0] 210 | 211 | #figure out the dimensions that we are working with 212 | rad=args.cube_length/2.0 213 | testpos=np.linspace(0,rad,args.num_points) 214 | testneg=np.linspace(-1*rad,0,args.num_points) 215 | val_range=list(testneg[:-1])+list(testpos) 216 | num_on_axis=len(val_range) 217 | minimum_point=(-1*rad, -1*rad, -1*rad) 218 | val_delta=val_range[1]-val_range[0] 219 | 220 | #The bulk of the script 221 | if args.make_dx: 222 | for atom in todo: 223 | print('Working on '+atom) 224 | #make the dx file 225 | data_name=dataroot+prefix+'_rec_'+atom+'_lig_'+mprefix+'_predictscores' 226 | _,_ = make_dx(data_name, num_on_axis, minimum_point, val_delta) 227 | print('Made dx file in: '+dataroot) 228 | else: 229 | with open(args.outname,'w') as outfile: 230 | for atom in todo: 231 | print('Working on '+atom) 232 | 233 | #make the points 234 | make_points(atom, val_range, types_root, inv_map[atom]) 235 | print('Made points in: '+types_root+atom) 236 | 237 | #make the gninatypes file 238 | gninatypes_filename=gninatyper(args.test_pdb) 239 | if gninatypes_filename==1: 240 | print('Error with gninatyper!') 241 | sys.exit() 242 | 243 | if not path_checker(gninatypes_filename): 244 | print('Error!') 245 | print(gninatypes_filename+' is an empty file!') 246 | sys.exit() 247 | 248 | #then make the files 249 | working_name=make_types(atom, types_root, gninatypes_filename) 250 | print('Made typesfile in: '+types_root) 251 | 252 | #and write the newline 253 | outfile.write('$GNINASCRIPTSDIR/predict.py -m '+args.model+' -w '+args.weights+' -i '+working_name+' --rotation 100 > '+dataroot+prefix+'_rec_'+atom+'_lig_'+mprefix+'_predictscores\n') 254 | -------------------------------------------------------------------------------- /models/lenet1.template: -------------------------------------------------------------------------------- 1 | layer { 2 | name: "data" 3 | type: "NDimData" 4 | top: "data" 5 | top: "label" 6 | include { 7 | phase: TEST 8 | } 9 | ndim_data_param { 10 | source: "TESTFILE" 11 | batch_size: 1 12 | shape { 13 | dim: 34 14 | dim: 49 15 | dim: 49 16 | dim: 49 17 | } 18 | shuffle: false 19 | balanced: false 20 | } 21 | } 22 | layer { 23 | name: "data" 24 | type: 
"NDimData" 25 | top: "data" 26 | top: "label" 27 | include { 28 | phase: TRAIN 29 | } 30 | ndim_data_param { 31 | source: "TRAINFILE" 32 | batch_size: 10 33 | shape { 34 | dim: 34 35 | dim: 49 36 | dim: 49 37 | dim: 49 38 | } 39 | shuffle: true 40 | balanced: true 41 | rotate: 24 42 | } 43 | } 44 | 45 | layer { 46 | name: "convdownsample" 47 | type: "Convolution" 48 | bottom: "data" 49 | top: "convdownsample" 50 | convolution_param { 51 | num_output: 33 52 | kernel_size: 5 53 | stride: 3 54 | weight_filler { 55 | type: "xavier" 56 | } 57 | } 58 | } 59 | layer { 60 | name: "conv1" 61 | type: "Convolution" 62 | bottom: "convdownsample" 63 | top: "conv1" 64 | convolution_param { 65 | num_output: 20 66 | kernel_size: 5 67 | weight_filler { 68 | type: "xavier" 69 | } 70 | } 71 | } 72 | layer { 73 | name: "conv2" 74 | type: "Convolution" 75 | bottom: "conv1" 76 | top: "conv2" 77 | convolution_param { 78 | num_output: 50 79 | kernel_size: 5 80 | weight_filler { 81 | type: "xavier" 82 | } 83 | } 84 | } 85 | layer { 86 | name: "ip1" 87 | type: "InnerProduct" 88 | bottom: "conv2" 89 | top: "ip1" 90 | inner_product_param { 91 | num_output: 500 92 | weight_filler { 93 | type: "xavier" 94 | } 95 | } 96 | } 97 | layer { 98 | name: "drop" 99 | type: "Dropout" 100 | bottom: "ip1" 101 | top: "drop" 102 | dropout_param { 103 | dropout_ratio: 0.5 104 | } 105 | } 106 | layer { 107 | name: "sig1" 108 | type: "Sigmoid" 109 | bottom: "drop" 110 | top: "sig1" 111 | } 112 | layer { 113 | name: "ip2" 114 | type: "InnerProduct" 115 | bottom: "sig1" 116 | top: "ip2" 117 | inner_product_param { 118 | num_output: 2 119 | weight_filler { 120 | type: "xavier" 121 | } 122 | } 123 | } 124 | 125 | layer { 126 | name: "loss" 127 | type: "SoftmaxWithLoss" 128 | bottom: "ip2" 129 | bottom: "label" 130 | top: "loss" 131 | } 132 | layer { 133 | name: "output" 134 | type: "Softmax" 135 | bottom: "ip2" 136 | top: "output" 137 | } 138 | layer { 139 | name: "accuracy" 140 | type: "Accuracy" 141 | bottom: "ip2" 142 | bottom: "label" 143 | top: "accuracy" 144 | include { 145 | phase: TEST 146 | } 147 | } 148 | layer { 149 | name: "labelout" 150 | type: "Reshape" 151 | bottom: "label" 152 | top: "labelout" 153 | include { 154 | phase: TEST 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import glob, re, sklearn, collections, argparse, sys, os 8 | import sklearn.metrics 9 | import scipy 10 | import caffe 11 | from caffe.proto.caffe_pb2 import NetParameter 12 | import google.protobuf.text_format as prototxt 13 | from train import evaluate_test_net 14 | 15 | 16 | def write_model_file(model_file, template_file, test_file, root_folder): 17 | param = NetParameter() 18 | with open(template_file, 'r') as f: 19 | prototxt.Merge(f.read(), param) 20 | for layer in param.layer: 21 | if layer.molgrid_data_param.source == 'TESTFILE': 22 | layer.molgrid_data_param.source = test_file 23 | if layer.molgrid_data_param.root_folder == 'DATA_ROOT': 24 | layer.molgrid_data_param.root_folder = root_folder 25 | with open(model_file, 'w') as f: 26 | f.write(str(param)) 27 | 28 | 29 | def predict(args): 30 | '''Return yscore and/or y_predaff with rest of input line for each example''' 31 | if args.gpu >= 0: 32 | caffe.set_device(args.gpu) 33 | caffe.set_mode_gpu() 34 | 
test_model = 'predict.%d.prototxt' % os.getpid() 35 | write_model_file(test_model, args.model, args.input, args.data_root) 36 | test_net = caffe.Net(test_model, args.weights, caffe.TEST) 37 | with open(args.input, 'r') as f: 38 | lines = f.readlines() 39 | result = evaluate_test_net(test_net, len(lines), args.rotations) 40 | auc = result.auc 41 | y_true = result.y_true 42 | y_score = result.y_score 43 | loss = result.loss 44 | rmsd = result.rmsd 45 | pearsonr = None 46 | y_affinity = result.y_aff 47 | y_predaff = result.y_predaff 48 | 49 | # auc, y_true, y_score, loss, rmsd, y_affinity, y_predaff = result 50 | 51 | if 'labelout' in test_net.outputs: 52 | assert np.all(y_true == [float(l.split(' ')[0]) for l in lines]) #check alignment 53 | if 'affout' in test_net.outputs: 54 | for (l,a) in zip(lines,y_affinity): 55 | lval = float(l.split()[1]) 56 | if abs(lval-a) > 0.001: 57 | print("Mismatching values",a,l) 58 | sys.exit(-1) 59 | 60 | if rmsd != None and auc != None: 61 | output_lines = [t for t in zip(y_score, y_predaff, lines)] 62 | elif rmsd != None: 63 | output_lines = [t for t in zip(y_predaff, lines)] 64 | elif auc != None: 65 | output_lines = [t for t in zip(y_score, lines)] 66 | 67 | 68 | #this is all awkward and should be rewritten with a smarter approach than munging strings 69 | if args.max_score or args.max_affinity: 70 | output_lines = maxLigandScore(output_lines, args.max_affinity) 71 | #have to recalculate RMSD and AUC 72 | 73 | if auc != None: 74 | y_true = [float(line[-1].split()[0]) for line in output_lines] 75 | y_score = [line[0] for line in output_lines] 76 | auc = sklearn.metrics.roc_auc_score(y_true, y_score) 77 | if rmsd != None: 78 | y_affinity = [float(line[-1].split()[1]) for line in output_lines] 79 | y_predaff = [line[1] for line in output_lines] 80 | rmsd = np.sqrt(sklearn.metrics.mean_squared_error(np.abs(y_affinity),y_predaff)) 81 | pearsonr = scipy.stats.pearsonr(np.abs(y_affinity),y_predaff)[0] 82 | 83 | if not args.keep: 84 | os.remove(test_model) 85 | return output_lines,auc,rmsd,pearsonr 86 | 87 | def predict_lines(args): 88 | '''Return previous format of a list of strings corresponding to the output lines of a prediction file''' 89 | predictions = predict(args) 90 | lines = [] 91 | for line in predictions[0]: 92 | l = '' 93 | for val in line[:-1]: 94 | l += '%f '%val 95 | l += '| %s' % line[-1] 96 | lines.append(l) 97 | if predictions[1] != None: 98 | lines.append('# AUC %f\n'%predictions[1]) 99 | if predictions[2] != None: 100 | lines.append('# rmsd %f\n'%predictions[2]) 101 | if predictions[3] != None: 102 | lines.append('# pearsonr %f\n'%predictions[3]) 103 | return lines 104 | 105 | def get_ligand_key(rec_path, pose_path): 106 | # no good naming convention, so just use the receptor name 107 | # and each numeric part of the ligand/pose name except for 108 | # the last, which is the pose number of the ligand 109 | rec_dir = os.path.dirname(rec_path) 110 | rec_name = rec_dir.rsplit('/', 1)[-1] 111 | pose_name = os.path.splitext(os.path.basename(pose_path))[0] 112 | pose_name_nums = [] 113 | for i, part in enumerate(pose_name.split('_')): 114 | try: 115 | pose_name_nums.append(int(part)) 116 | except ValueError: 117 | continue 118 | return tuple([rec_name] + pose_name_nums[:-1]) 119 | 120 | 121 | def maxLigandScore(lines, useaff): 122 | #output format: score label [affinity] rec_path pose_path 123 | ligands = {} 124 | for line in lines: 125 | data = line[2].split('#')[0].split() 126 | data = list(line[:2])+data 127 | if len(data) == 4: #only score 
present 128 | score = float(data[0]) 129 | rec_path = data[2].strip() 130 | pose_path = data[3].strip() 131 | elif len(data) == 5: #only affinity present 132 | score = float(data[0]) 133 | rec_path = data[3].strip() 134 | pose_path = data[4].strip() 135 | elif len(data) == 6: 136 | if useaff: 137 | score = float(data[1]) 138 | else: 139 | score = float(data[0]) 140 | rec_path = data[4].strip() 141 | pose_path = data[5].strip() 142 | else: 143 | print(line) 144 | 145 | key = get_ligand_key(rec_path, pose_path) 146 | if key not in ligands or score > ligands[key][0]: 147 | ligands[key] = (score, line) 148 | return [ligands[key][1] for key in ligands] 149 | 150 | 151 | def parse_args(argv=None): 152 | parser = argparse.ArgumentParser(description='Test neural net on gninatypes data.') 153 | parser.add_argument('-m','--model',type=str,required=True,help="Model template. Must use TESTFILE with unshuffled, unbalanced input. EX: file.model ") 154 | parser.add_argument('-w','--weights',type=str,required=True,help="Model weights (.caffemodel)") 155 | parser.add_argument('-d','--data_root',type=str,required=False,help="Root folder for paths in .types files",default='') 156 | parser.add_argument('-i','--input',type=str,required=True,help="Input .types file to predict") 157 | parser.add_argument('-g','--gpu',type=int,help='Specify GPU to run on',default=-1) 158 | parser.add_argument('-o','--output',type=str,help='Output file name',default=None) 159 | parser.add_argument('-s','--seed',type=int,help='Random seed',default=None) 160 | parser.add_argument('-k','--keep',action='store_true',default=False,help="Don't delete prototxt files") 161 | parser.add_argument('--rotations',type=int,help='Number of rotations; rotatation must be enabled in test net!',default=1) 162 | parser.add_argument('--max_score',action='store_true',default=False,help="take max score per ligand as its score") 163 | parser.add_argument('--max_affinity',action='store_true',default=False,help="take max affinity per ligand as its score") 164 | parser.add_argument('--notcalc_predictions', type=str, default='',help='use file of predictions instead of calculating') 165 | return parser.parse_args(argv) 166 | 167 | 168 | if __name__ == '__main__': 169 | args = parse_args() 170 | if not args.output: 171 | out = sys.stdout 172 | else: 173 | out = open(args.output, 'w') 174 | if args.seed != None: 175 | caffe.set_random_seed(args.seed) 176 | if not args.notcalc_predictions: 177 | predictions = predict_lines(args) 178 | else: 179 | with open(args.notcalc_predictions, 'r') as f: 180 | predictions = f.readlines() 181 | if args.max_score or args.max_affinity: 182 | predictions = maxLigandScore(predictions, args.max_affinity) 183 | 184 | out.writelines(predictions) 185 | 186 | -------------------------------------------------------------------------------- /pymol_arrows.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import sys 5 | import os 6 | import argparse 7 | 8 | 9 | def write_pymol_arrows(base, structs, scale, color, radius, hradius, hlength, threshold): 10 | pymol_file = base + '_arrows.pymol' 11 | lines = [] 12 | arrow_objs = set() 13 | t2 = threshold**2 14 | s2 = scale**2 15 | for i, struct in enumerate(structs): 16 | for j, atom in enumerate(struct): 17 | arrow_obj = base + '_arrow_' + str(j) 18 | arrow_objs.add(arrow_obj) 19 | elem, xi, yi, zi, dx, dy, dz = atom 20 | xf = xi + scale*dx 21 | yf = yi + scale*dy 22 | zf = zi + scale*dz 23 | line = 'cgo_arrow [{}, {}, 
{}], [{}, {}, {}]'.format(xi, yi, zi, xf, yf, zf) 24 | if len(structs) > 1: 25 | line += ', state={}'.format(i+1) 26 | if radius: 27 | line += ', radius={}'.format(radius) 28 | if hradius > 0: 29 | line += ', hradius={}'.format(hradius) 30 | if hlength > 0: 31 | line += ', hlength={}'.format(hlength) 32 | if color: 33 | line += ', color={}'.format(color) 34 | line += ', name={}'.format(arrow_obj) 35 | if (dx**2 + dy**2 + dz**2)*s2 > t2: 36 | lines.append(line) 37 | arrow_group = base + '_arrows' 38 | line = 'group {}, {}'.format(arrow_group, ' '.join(arrow_objs)) 39 | lines.append(line) 40 | with open(pymol_file, 'w') as f: 41 | f.write('\n'.join(lines)) 42 | 43 | 44 | def xyz_line_to_atom(xyz_line): 45 | fields = xyz_line.split() 46 | elem = fields[0] 47 | x = float(fields[1]) 48 | y = float(fields[2]) 49 | z = float(fields[3]) 50 | dx = float(fields[4]) 51 | dy = float(fields[5]) 52 | dz = float(fields[6]) 53 | return elem, x, y, z, dx, dy, dz 54 | 55 | 56 | def atom_to_pdb_line(atom, idx, dosum): 57 | if not isinstance(idx, int) or idx < 0 or idx > 99999: 58 | raise TypeError('idx must be an integer from 0 to 99999 ({})'.format(idx)) 59 | elem, x, y, z, dx, dy, dz = atom 60 | if len(elem) not in {1, 2}: 61 | raise IndexError('atom elem must be a string of length 1 or 2 ({})'.format(elem)) 62 | if dosum: 63 | d = dx+dy+dz 64 | else: 65 | d = (dx**2 + dy**2 + dz**2)**0.5 66 | return '{:6}{:5} {:4}{:1}{:3} {:1}{:4}{:1} {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6f} {:2}{:2}' \ 67 | .format('ATOM', idx, '', '', '', '', '', '', x, y, z, 1.0, d, elem.rjust(2), '') 68 | 69 | 70 | def read_xyz_file(xyz_file, header_len=2): 71 | 72 | with open(xyz_file, 'r') as f: 73 | lines = f.readlines() 74 | 75 | structs = [] 76 | struct_start = 0 77 | for i, line in enumerate(lines): 78 | try: 79 | # line index relative to struct start 80 | j = i - struct_start 81 | 82 | if j == 0 or j >= header_len + n_atoms: 83 | struct_start = i 84 | structs.append([]) 85 | n_atoms = int(lines[i]) 86 | 87 | elif j < header_len: 88 | continue 89 | 90 | else: 91 | atom = xyz_line_to_atom(lines[i]) 92 | structs[-1].append(atom) 93 | except: 94 | print('{}:{} {}'.format(xyz_file, i, repr(line)), file=sys.stderr) 95 | raise 96 | 97 | return structs 98 | 99 | 100 | def write_pdb_file(pdb_file, atoms, dosum): 101 | lines = [] 102 | for i, atom in enumerate(atoms): 103 | line = atom_to_pdb_line(atom, i, dosum) 104 | lines.append(line) 105 | if pdb_file: 106 | with open(pdb_file, 'w') as f: 107 | f.write('\n'.join(lines)) 108 | else: 109 | print('\n'.join(lines)) 110 | 111 | 112 | def parse_args(): 113 | parser = argparse.ArgumentParser(description='Output a pymol script that creates \ 114 | arrows from an .xyz file containing atom coordinates and gradient components, \ 115 | can also create a .pdb file where the b-factor is the gradient magnitude') 116 | parser.add_argument('xyz_file') 117 | parser.add_argument('-s', '--scale', type=float, default=1.0, 118 | help='Arrow length scaling factor') 119 | parser.add_argument('-c', '--color', type=str, default='', 120 | help='Arrow color or pair of colors, e.g. 
"white black"') 121 | parser.add_argument('-r', '--radius', type=float, default=0.2, 122 | help='Radius of arrow body') 123 | parser.add_argument('-hr', '--hradius', type=float, default=-1, 124 | help='Radius of arrow head') 125 | parser.add_argument('-hl', '--hlength', type=float, default=-1, 126 | help='Length of arrow head') 127 | parser.add_argument('-p', '--pdb_file', action='store_true', default=False, 128 | help='Output a .pdb file where the b-factor is gradient magnitude') 129 | parser.add_argument('--sum', action='store_true', default=False, 130 | help='Sum gradient components instead of taking magnitude') 131 | parser.add_argument('-t', '--threshold', type=float, default=0, 132 | help="Gradient threshold for drawing arrows (using scale factor)") 133 | return parser.parse_args() 134 | 135 | 136 | if __name__ == '__main__': 137 | args = parse_args() 138 | structs = read_xyz_file(args.xyz_file) 139 | base_name = args.xyz_file.replace('.xyz', '') 140 | write_pymol_arrows(base_name, structs, args.scale, args.color, args.radius, args.hradius, args.hlength, args.threshold) 141 | if args.pdb_file: 142 | pdb_file = base_name + '.pdb' 143 | write_pdb_file(pdb_file, atoms, args.sum) 144 | 145 | -------------------------------------------------------------------------------- /reduce_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import random 5 | 6 | 7 | def crossval_files(prefix, numfolds): 8 | cvfiles = [] 9 | for i in range(numfolds): 10 | trainfile = '{}train{}.types'.format(prefix, i) 11 | testfile = '{}test{}.types'.format(prefix, i) 12 | cvfiles.append((trainfile, testfile)) 13 | return cvfiles 14 | 15 | 16 | def reduced_file(file): 17 | match = re.match('(.*?)(((train|test)[0-9]+)?.types)', file) 18 | return match.group(1) + '_reduced' + match.group(2) 19 | 20 | 21 | def read_lines(file): 22 | with open(file, 'r') as f: 23 | lines = f.readlines() 24 | return lines 25 | 26 | 27 | def write_reduced_lines(file, lines, factor): 28 | random.shuffle(lines) 29 | reduced = lines[:int(len(lines)/factor)] 30 | with open(file, 'w') as f: 31 | f.write(''.join(reduced)) 32 | 33 | 34 | def parse_args(argv=None): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('-p', '--prefix', required=True) 37 | parser.add_argument('-n', '--numfolds', type=int, default=3) 38 | parser.add_argument('-a', '--allfolds', default=False, action='store_true') 39 | parser.add_argument('-f', '--factor', required=True, type=float) 40 | parser.add_argument('-s', '--random_seed', type=int, default=0) 41 | return parser.parse_args(argv) 42 | 43 | 44 | if __name__ == '__main__': 45 | args = parse_args() 46 | random.seed(args.random_seed) 47 | cvfiles = crossval_files(args.prefix, args.numfolds) 48 | for i, (trainfile, testfile) in enumerate(cvfiles): 49 | train = read_lines(trainfile) 50 | reduced_trainfile = reduced_file(trainfile) 51 | write_reduced_lines(reduced_trainfile, train, args.factor) 52 | print(reduced_trainfile) 53 | test = read_lines(testfile) 54 | reduced_testfile = reduced_file(testfile) 55 | write_reduced_lines(reduced_testfile, test, args.factor) 56 | print(reduced_testfile) 57 | if args.allfolds: 58 | allfile = '{}.types'.format(args.prefix) 59 | all = read_lines(allfile) 60 | reduced_allfile = reduced_file(allfile) 61 | write_reduced_lines(reduced_allfile, all, args.factor) 62 | print(reduced_allfile) 63 | 64 | -------------------------------------------------------------------------------- 
/show_xyz_arrows.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | '''Adds show_arrows command to pymol, which takes an xyz file''' 3 | 4 | import sys 5 | import os 6 | 7 | from pymol import cmd, cgo, CmdException 8 | from chempy import cpv 9 | 10 | 11 | def draw_arrow(xyz1,xyz2, radius=0.5, gap=0.0, hlength=-1, hradius=-1, 12 | color='blue red', name=''): 13 | ''' 14 | Draw an arrow; borrows heavily from cgi arrows. 15 | ''' 16 | radius, gap = float(radius), float(gap) 17 | hlength, hradius = float(hlength), float(hradius) 18 | xyz1 = list(xyz1) 19 | xyz2 = list(xyz2) 20 | try: 21 | color1, color2 = color.split() 22 | except: 23 | color1 = color2 = color 24 | color1 = list(cmd.get_color_tuple(color1)) 25 | color2 = list(cmd.get_color_tuple(color2)) 26 | 27 | normal = cpv.normalize(cpv.sub(xyz1, xyz2)) 28 | 29 | if hlength < 0: 30 | hlength = radius * 3.0 31 | if hradius < 0: 32 | hradius = hlength * 0.6 33 | 34 | if gap: 35 | diff = cpv.scale(normal, gap) 36 | xyz1 = cpv.sub(xyz1, diff) 37 | xyz2 = cpv.add(xyz2, diff) 38 | 39 | xyz3 = cpv.add(cpv.scale(normal, hlength), xyz2) 40 | 41 | obj = [cgo.CYLINDER] + xyz1 + xyz3 + [radius] + color1 + color2 + \ 42 | [cgo.CONE] + xyz3 + xyz2 + [hradius, 0.0] + color2 + color2 + \ 43 | [1.0, 0.0] 44 | 45 | if not name: 46 | name = cmd.get_unused_name('arrow') 47 | 48 | cmd.load_cgo(obj, name) 49 | 50 | 51 | 52 | def make_pymol_arrows(base, atoms, scale, color, radius): 53 | 54 | arrow_objs = [] 55 | arrow_group = base + '_arrows' 56 | cmd.delete(arrow_group) #remove any pre-existing group 57 | for i, atom in enumerate(atoms): 58 | arrow_obj = base + '_arrow_' + str(i) 59 | arrow_objs.append(arrow_obj) 60 | elem, xi, yi, zi, dx, dy, dz = atom 61 | c = 1.725*radius 62 | xf = xi + -scale*dx + c 63 | yf = yi + -scale*dy + c 64 | zf = zi + -scale*dz + c 65 | draw_arrow((xi,yi,zi),(xf,yf,zf),radius=radius,color=color,name=arrow_obj) 66 | 67 | cmd.group(arrow_group,' '.join(arrow_objs)) 68 | 69 | 70 | def xyz_line_to_atom(xyz_line): 71 | fields = xyz_line.split() 72 | elem = fields[0] 73 | x = float(fields[1]) 74 | y = float(fields[2]) 75 | z = float(fields[3]) 76 | dx = float(fields[4]) 77 | dy = float(fields[5]) 78 | dz = float(fields[6]) 79 | return elem, x, y, z, dx, dy, dz 80 | 81 | 82 | 83 | def read_xyz_file(xyz_file): 84 | with open(xyz_file, 'r') as f: 85 | lines = f.readlines() 86 | n_atoms = int(lines[0]) 87 | atoms = [] 88 | for i in range(n_atoms): 89 | atom = xyz_line_to_atom(lines[2+i]) 90 | atoms.append(atom) 91 | return atoms 92 | 93 | 94 | 95 | def show_xyz_arrows(xyzfile, scale=2.0, color="white purple",radius=0.2): 96 | atoms = read_xyz_file(xyzfile) 97 | base_name = xyzfile.replace('.xyz', '') 98 | make_pymol_arrows(base_name, atoms, float(scale), color, float(radius)) 99 | 100 | 101 | 102 | cmd.extend('show_xyz_arrows', show_xyz_arrows) 103 | 104 | -------------------------------------------------------------------------------- /timemodel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''quick script for generating a real model and caffe time command''' 4 | import argparse,sys 5 | 6 | parser = argparse.ArgumentParser(description='Train neural net on .types data.') 7 | parser.add_argument('-m','--model',type=str,required=True,help="Model template. 
Must use TRAINFILE and TESTFILE") 8 | parser.add_argument('-p','--prefix',type=str,help="Prefix for training/test files: [train|test][num].types",default="/home/dkoes/scorebench/PDBbind/refined-set/affinity_search/types/small") 9 | parser.add_argument('-o','--output',type=str,help="Output model (default timeit.model)",default="timeit.model") 10 | parser.add_argument('--data_root',type=str,help="Root directory for gninatypes files",default="/home/dkoes/scorebench/PDBbind/refined-set/") 11 | 12 | args = parser.parse_args() 13 | 14 | model = open(args.model).read() 15 | model = model.replace('TRAINFILE','%strain0.types'%args.prefix) 16 | model = model.replace('TESTFILE','%stest0.types'%args.prefix) 17 | model = model.replace('DATA_ROOT',args.data_root) 18 | 19 | out = open(args.output,'w') 20 | out.write(model) 21 | print("caffe time -gpu 0 -model %s"%args.output) 22 | 23 | -------------------------------------------------------------------------------- /types2xyz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | '''convert gninatypes file to xyz file''' 4 | import struct, sys, argparse 5 | from functools import partial 6 | import molgrid 7 | 8 | names = molgrid.GninaIndexTyper().get_type_names() 9 | 10 | def elem(t): 11 | '''convert type index into element string''' 12 | name = names[t] 13 | if 'Hydrogen' in name: 14 | return 'H' 15 | elif 'Carbon' in name: 16 | return 'C' 17 | elif 'Nitrogen' in name: 18 | return 'N' 19 | elif 'Oxygen' in name: 20 | return 'O' 21 | elif 'Sulfur' in name: 22 | return 'S' 23 | elif 'Phosphorus' == name: 24 | return 'P' 25 | elif 'Fluorine' == name: 26 | return 'F' 27 | elif 'Chlorine' == name: 28 | return 'Cl' 29 | elif 'Bromine' == name: 30 | return 'Br' 31 | elif 'Iodine' == name: 32 | return 'I' 33 | elif 'Magnesium' == name: 34 | return 'Mg' 35 | elif 'Manganese' == name: 36 | return 'Mn' 37 | elif 'Zinc' == name: 38 | return 'Zn' 39 | elif 'Calcium' == name: 40 | return 'Ca' 41 | elif 'Iron' == name: 42 | return 'Fe' 43 | elif 'Boron' == name: 44 | return 'B' 45 | else: 46 | return 'X' 47 | 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('input',type=str,help='gninatypes file') 51 | parser.add_argument('output', default='-',nargs='?',type=argparse.FileType('w'),help='output xyz') 52 | 53 | args = parser.parse_args() 54 | 55 | 56 | struct_fmt = 'fffi' 57 | struct_len = struct.calcsize(struct_fmt) 58 | struct_unpack = struct.Struct(struct_fmt).unpack_from 59 | 60 | with open(args.input,'rb') as tfile: 61 | results = [struct_unpack(chunk) for chunk in iter(partial(tfile.read, struct_len), b'')] 62 | 63 | args.output.write('%d\n'%len(results)) # number atoms 64 | args.output.write(args.input+'\n') #comment 65 | for x,y,z,t in results: 66 | args.output.write('%s\t%f\t%f\t%f\n'%(elem(t),x,y,z)) 67 | 68 | -------------------------------------------------------------------------------- /types_extender.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | This script will generate the new types file with the lines from generate_counterexample_typeslines.py 4 | 5 | Assumptions 6 | i) The data structure is // 7 | ii) The name of the file containing the types lines to add is for each pocket in the types file. 8 | iii) the input types file has / from which to parse the needed pockets from. 
9 | 10 | INPUT 11 | i) Original types file 12 | ii) New types filename 13 | iii) Name of file in Pocket that contains the lines to add 14 | iv) The ROOT of the data directory 15 | 16 | OUTPUT 17 | i) The new types file -- note that the lines of the new types file will not necessarily be in order. 18 | ''' 19 | 20 | import argparse, os, re, glob 21 | 22 | def check_exists(filename): 23 | if os.path.isfile(filename) and os.path.getsize(filename)>0: 24 | return True 25 | else: 26 | return False 27 | 28 | parser=argparse.ArgumentParser(description='Add lines to types file and create a new one. Assumes data file structure is ROOT/POCKET/FILES.') 29 | parser.add_argument('-i','--input',type=str,required=True,help='Types file you will be extending.') 30 | parser.add_argument('-o','--output',type=str,required=True,help='Name of the extended types file.') 31 | parser.add_argument('-n','--name',type=str,required=True,help='Name of the file containing the lines to add for a given pocket. This is the output of generate_counterexample_typeslines.py.') 32 | parser.add_argument('-r','--root',default='',help='Root of the data directory. Defaults to current working directory.') 33 | args=parser.parse_args() 34 | 35 | completed=set() 36 | with open(args.output,'w') as outfile: 37 | with open(args.input) as infile: 38 | for line in infile: 39 | outfile.write(line) 40 | m=re.search(r' (\S+)/',line) 41 | pocket=m.group(1) if m else None #guard against blank or malformed lines with no POCKET/ path 42 | 43 | if pocket and pocket not in completed: 44 | completed.add(pocket) 45 | with open(os.path.join(args.root,pocket,args.name)) as linesfile: 46 | for line2 in linesfile: 47 | outfile.write(line2) 48 | --------------------------------------------------------------------------------
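The pocket lookup in types_extender.py above rests entirely on the `' (\S+)/'` search: the first whitespace-separated token that contains a slash is taken to be a POCKET/FILE path (relative to --root), and everything up to that token's last slash becomes the pocket name. A small illustration, using a hypothetical pocket `1A2B` and a made-up types line in a typical `label affinity rmsd receptor ligand` layout:

```python
import re

# hypothetical types line following the POCKET/FILE path layout types_extender.py assumes
line = '1 6.42 0.00 1A2B/1A2B_rec_0.gninatypes 1A2B/1A2B_lig_tt_docked_3.gninatypes\n'

m = re.search(r' (\S+)/', line)
print(m.group(1))  # -> 1A2B : the receptor path is the first slash-containing token
```

With the pocket in hand the script opens the --name file inside that pocket directory once and writes its lines right after the first input line that mentions the pocket, which is why the docstring warns that the output is not necessarily in the input's order.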
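Further up in this dump, reduce_data.py renames the subsampled cross-validation files with `reduced_file()`: the regex splits a fold file name into its prefix and its `train{i}.types`/`test{i}.types` tail and splices `_reduced` between them, so the fold index survives the rename. A minimal sanity check, assuming the repository root is on the Python path and using a hypothetical `small` prefix:

```python
# sketch: expected names from reduce_data.py's reduced_file(); the 'small' prefix is just an example
from reduce_data import reduced_file

assert reduced_file('smalltrain0.types') == 'small_reducedtrain0.types'
assert reduced_file('smalltest2.types') == 'small_reducedtest2.types'
assert reduced_file('small.types') == 'small_reduced.types'  # the --allfolds file
```

Running the script with, say, `-p small -f 10 -a` would therefore write tenth-sized `small_reduced*.types` copies alongside the originals and print their names.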
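Both the make_points() helper earlier in this dump and types2xyz.py handle .gninatypes atoms with the same packed layout: three floats for x, y, z followed by one int for the gnina atom-type index, written with struct's 'fffi' format in native byte order. A self-contained round-trip sketch; the file name is made up, and type index 6 corresponds to Nitrogen in the hardcoded map shown above:

```python
import struct

# pack one atom record: x, y, z as floats, then the integer type index
record = struct.pack('fffi', 1.5, -2.0, 0.25, 6)
with open('example_0.gninatypes', 'wb') as f:   # hypothetical file name
    f.write(record)

# read it back with the same 'fffi' format string types2xyz.py uses
fmt = struct.Struct('fffi')
with open('example_0.gninatypes', 'rb') as f:
    x, y, z, t = fmt.unpack_from(f.read())
print(x, y, z, t)  # -> 1.5 -2.0 0.25 6
```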