├── .gitignore
├── Adaptive-MPO
│   ├── Task1_Adaptive_MPO_HITL.py
│   ├── adapt-mpo.yml
│   ├── data
│   │   └── best_matching_params_modifiedQED_config.json
│   ├── scripts
│   │   ├── REINVENTconfig.py
│   │   ├── acquisition.py
│   │   ├── evaluation.py
│   │   ├── lowSlogPQED.py
│   │   └── write.py
│   └── usermodel_doublesigmoid.stan
├── Chemists-Component
│   ├── Task2_Chemists_Component.ipynb
│   ├── acquisition.py
│   ├── cc_env_hitl.yml
│   ├── cc_env_reinvent.yml
│   ├── data
│   │   ├── drd2.pkl
│   │   ├── drd2_regression.test.csv
│   │   └── drd2_regression.train.csv
│   ├── evaluate_results_Task2.ipynb
│   ├── kernels.py
│   ├── load_data.py
│   ├── models.py
│   ├── molecular_features.py
│   ├── query.py
│   └── write.py
├── LICENSE
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ -------------------------------------------------------------------------------- /Adaptive-MPO/Task1_Adaptive_MPO_HITL.py: -------------------------------------------------------------------------------- 1 | # load dependencies 2 | import sys 3 | import os 4 | import re 5 | import json 6 | import tempfile 7 | import pandas as pd 8 | import csv 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import rdkit.Chem as Chem 12 | import math 13 | import reinvent_scoring 14 | from numpy.random import default_rng 15 | 16 | from scripts.lowSlogPQED import expert_qed, get_adsParameters 17 | from scripts.write import write_REINVENT_config 18 | from scripts.acquisition import select_query 19 | from scripts.REINVENTconfig import parse_config_file 20 | 21 | import pystan 22 | import stan_utility 23 | import pickle 24 | 25 | 26 | def do_run(acquisition, seed): 27 | ############################## 28 | # Quick options 29 | FIT_MODEL = True # whether to fit a Stan model or not 30 | LOAD_MODEL = False # load Stan model from disc instead of fitting it 31 | SUBSAMPLE = True # Reduce the size of the pool of unlabeled molecules to reduce computation time 32 | ############################## 33 | 34 | jobid = 'demo_Task1' 35 | jobname = "Learn the parameters of MPO" 36 | np.random.seed(seed) 37 | rng = default_rng(seed) 38 | 39 | ########### HITL setup ################# 40 | T = 10 # numer of HITL iterations (T in the paper) 41 | n = 10 # number of molecules shown to the simulated chemist at each iteration (n_batch in the paper) 42 | n0 = 10 # number of molecules shown to the expert at initialization (N_0 in the paper) 43 | K = 3 # number of REINVENT runs (K=R+1 in the paper): usage: K=2 for one HITL round (T*n+n0 queries); K=3 for two HITL rounds (2*(T*n+n0) queries) 44 | ######################################## 45 | 46 | stanfile = 'usermodel_doublesigmoid' 47 | model_identifier = stanfile + '_' + str(jobid) 48 | 49 | # --------- change these path variables as required 50 | reinvent_dir = os.path.expanduser("/path/to/Reinvent") 51 | reinvent_env = os.path.expanduser("/path/to/conda_environment") 52 | output_dir = os.path.expanduser("/path/to/result/directory/{}_seed{}".format(jobid,seed)) 53 | print("Running MPO experiment with N0={}, T={}, R={}, n_batch={}, seed={}. 
\n Results will be saved at {}".format(n0, T, K-1, n, seed, output_dir)) 54 | 55 | expert_score = [] 56 | conf_filename='config.json' 57 | 58 | for REINVENT_iteration in np.arange(1,K): 59 | # if required, generate a folder to store the results 60 | READ_ONLY = False # if folder exists, do not overwrite results there 61 | try: 62 | os.mkdir(output_dir) 63 | except FileExistsError: 64 | READ_ONLY = True 65 | print("Reading REINVENT results from file, no re-running.") 66 | pass 67 | 68 | if(not READ_ONLY): 69 | # write the configuration file to disc 70 | configuration_JSON_path = write_REINVENT_config(reinvent_dir, reinvent_env, output_dir, jobid, jobname) 71 | # run REINVENT 72 | os.system(reinvent_env + '/bin/python ' + reinvent_dir + '/input.py ' + configuration_JSON_path + '&> ' + output_dir + '/run.err') 73 | 74 | with open(os.path.join(output_dir, "results/scaffold_memory.csv"), 'r') as file: 75 | data = pd.read_csv(file) 76 | 77 | colnames = list(data) 78 | smiles = data['SMILES'] 79 | total_score = data['total_score'] 80 | high_scoring_threshold = 0.4 81 | high_scoring_idx = total_score >= high_scoring_threshold 82 | 83 | # Scoring component values 84 | scoring_component_names = [s.split("raw_")[1] for s in colnames if "raw_" in s] 85 | print("scoring components:") 86 | print(scoring_component_names) 87 | x = np.array(data[scoring_component_names]) 88 | print('Scoring component matrix dimensions:') 89 | print(x.shape) 90 | 91 | # Only analyse highest scoring molecules 92 | x = x[high_scoring_idx,:] 93 | smiles = smiles[high_scoring_idx] 94 | total_score = total_score[high_scoring_idx] 95 | print('{} molecules'.format(len(smiles))) 96 | 97 | # Expert values (modified QED) 98 | s_mqed = np.zeros(len(smiles)) 99 | for i in np.arange(len(smiles)): 100 | try: 101 | cur_mol = Chem.MolFromSmiles(smiles[i]) 102 | s_mqed[i] = expert_qed(cur_mol) 103 | except: 104 | print("INVALID MOLECULE in scaffold memory") 105 | s_mqed[i] = 0 106 | expert_score += [s_mqed] 107 | print("Average modified QED in REINVENT output") 108 | print(np.mean(expert_score[REINVENT_iteration-1])) 109 | 110 | raw_scoring_component_names = ["raw_"+name for name in scoring_component_names] 111 | x_raw = data[raw_scoring_component_names].to_numpy() 112 | x = data[scoring_component_names].to_numpy() 113 | if(SUBSAMPLE): 114 | N0 = x_raw.shape[0] 115 | N = 10000 # Maximum number of molecules in U 116 | N = min(N0, N) 117 | sample_idx = rng.choice(N0, N, replace=False) 118 | x_raw = x_raw[sample_idx,:] 119 | x = x[sample_idx,:] 120 | smiles = smiles[sample_idx] 121 | try: 122 | user_score = expert_score[REINVENT_iteration-1][sample_idx] 123 | except IndexError: 124 | user_score = np.array([expert_qed(Chem.MolFromSmiles(si)) for si in smiles]) 125 | else: 126 | try: 127 | user_score = expert_score[REINVENT_iteration-1] 128 | except IndexError: 129 | user_score = np.array([expert_qed(Chem.MolFromSmiles(si)) for si in smiles]) 130 | 131 | N = x_raw.shape[0] # total number of molecules 132 | print("N_U={}".format(N)) 133 | k = x_raw.shape[1] # number of scoring functions 134 | w_e = np.ones(k)/k # equal weights 135 | 136 | # Generate simulated chemist's feedback 137 | y_all = rng.binomial(1, user_score) 138 | # Select indices of feedback molecules at initialization (=iteration 0) 139 | selected_feedback = rng.choice(N, n0, replace=False) 140 | 141 | # Read desirability function (=score transformation) parameters from config file; 142 | # they will be used as prior means in the user-model 143 | conf0 = 
parse_config_file(os.path.join(output_dir, conf_filename), scoring_component_names) 144 | low0 = conf0['low'] 145 | high0 = conf0['high'] 146 | print("Prior means of low:") 147 | print(low0) 148 | print("Prior means of high:") 149 | print(high0) 150 | 151 | # fixed double sigmoid params from config file: 152 | coef_div = conf0['coef_div'] 153 | coef_si = conf0['coef_si'] 154 | coef_se = conf0['coef_se'] 155 | 156 | # Read true values from a ground-truth config file 157 | gt_config = parse_config_file(os.path.join('./data/best_matching_params_modifiedQED_config.json'), scoring_component_names) 158 | low_true = gt_config['low'] 159 | high_true = gt_config['high'] 160 | 161 | mask = np.ones(N, dtype=bool) 162 | mask[selected_feedback] = False 163 | 164 | data_doublesigmoid = { 165 | 'n': n0, 166 | 'k': k, 167 | 'x_raw': x_raw[selected_feedback,:], 168 | 'y': y_all[selected_feedback], 169 | 'weights': w_e, 170 | "coef_div": coef_div, 171 | "coef_si": coef_si, 172 | "coef_se": coef_se, 173 | 'high0': high0, 174 | 'low0': low0, 175 | 'npred': N-len(selected_feedback), 176 | 'xpred': x_raw[mask,:] 177 | } 178 | 179 | model_savefile = output_dir + '/{}_iteration_{}.pkl'.format(model_identifier, REINVENT_iteration-1) 180 | if(FIT_MODEL): 181 | print("compiling the Stan model") 182 | model = stan_utility.compile_model('./' + stanfile + '.stan', model_name=model_identifier) 183 | print("sampling") 184 | fit = model.sampling(data=data_doublesigmoid, seed=8453462, chains=4, iter=1000, n_jobs=1) 185 | print("Saving the fitted model to {}".format(model_savefile)) 186 | pickle.dump({'model': model, 'fit': fit}, open(model_savefile, 'wb' ), protocol=-1) 187 | if(LOAD_MODEL): 188 | print("Loading the fit") 189 | data_dict = pickle.load(open(model_savefile, 'rb')) 190 | fit = data_dict['fit'] 191 | model = data_dict['model'] 192 | la = fit.extract(permuted=True) # return a dictionary of arrays for each model parameter 193 | # compute errors in learned limits 194 | low = np.mean(la['lows'],axis=0) 195 | high = np.mean(la['highs'],axis=0) 196 | parameter_order = ['low{}'.format(i) for i in np.arange(len(low0))] + ['high{}'.format(i) for i in np.arange(len(high0))] 197 | thetas = np.hstack((low,high)) 198 | thetas_true = np.hstack((low_true, high_true)) 199 | errs = [(thetas_true - thetas)**2] # MSE 200 | mean_limits =[thetas] 201 | 202 | # Diagnostic tests 203 | stan_utility.check_all_diagnostics(fit) 204 | warning0 = stan_utility.check_all_diagnostics(fit, quiet=True) 205 | 206 | print("highs") 207 | for i in np.arange(7): 208 | print(high[i]) 209 | print("lows") 210 | for i in np.arange(7): 211 | print(low[i]) 212 | 213 | y = y_all[selected_feedback] 214 | n_accept = [sum(y)] # number of accepted molecules at each iteration 215 | warning_code = [warning0] 216 | 217 | ########################### HITL rounds ###################################### 218 | for t in np.arange(T): 219 | print("iteration t={}".format(t)) 220 | # query selection 221 | new_query = select_query(N, n, fit, selected_feedback, acquisition, rng) 222 | # get simulated chemist's responses 223 | new_y = rng.binomial(1, user_score[new_query]) 224 | n_accept += [sum(new_y)] 225 | print("Feedback idx at iteration {}:".format(t)) 226 | print(new_query) 227 | print("Number of accepted molecules at iteration {}: {}".format(t,n_accept[t])) 228 | # append feedback 229 | selected_feedback = np.hstack((selected_feedback, new_query)) 230 | y = np.hstack((y, new_y)) 231 | mask = np.ones(N, dtype=bool) 232 | mask[selected_feedback] = False 233 | 
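# rebuild the Stan data with all feedback collected so far; npred/xpred shrink as queried molecules leave the unlabeled pool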
data_doublesigmoid = { 234 | 'n': len(selected_feedback), 235 | 'k': k, 236 | 'x_raw': x_raw[selected_feedback,:], 237 | 'y': y, 238 | 'weights': w_e, 239 | "coef_div": coef_div, 240 | "coef_si": coef_si, 241 | "coef_se": coef_se, 242 | 'high0': high0, 243 | 'low0': low0, 244 | 'npred': N-len(selected_feedback), 245 | 'xpred': x_raw[mask,:] 246 | } 247 | # re-fit model 248 | fit = model.sampling(data=data_doublesigmoid, seed=8453462, chains=4, iter=1000, n_jobs=1) 249 | stan_utility.check_all_diagnostics(fit) 250 | code = stan_utility.check_all_diagnostics(fit, quiet=True) 251 | warning_code += [code] 252 | la = fit.extract(permuted=True) 253 | low = np.mean(la['lows'],axis=0) 254 | high = np.mean(la['highs'],axis=0) 255 | thetas = np.hstack((low,high)) 256 | errs += [(thetas_true - thetas)**2] 257 | mean_limits += [thetas] 258 | 259 | # Posterior mean of parameters 260 | lows = la['lows'] 261 | highs = la['highs'] 262 | m_high = np.mean(highs, axis=0) 263 | m_low = np.mean(lows, axis=0) 264 | 265 | x = np.arange(T+1) 266 | true = np.hstack((low_true, high_true)) 267 | rerrs = np.absolute(mean_limits - true) / np.absolute(true) 268 | plt.plot(x, rerrs) 269 | plt.ylabel("relative error: |error|/true") 270 | plt.xlabel("Number of iterations") 271 | plt.legend(parameter_order) 272 | plt.title("Relative errors in learned limits") 273 | plt.savefig(os.path.join(output_dir, '{}_relative_abs_error_{}.png'.format(jobid, acquisition)), bbox_inches='tight') 274 | plt.close() 275 | 276 | plt.plot(x, np.mean(rerrs, axis=1)) 277 | plt.title("Mean relative error in learned limits") 278 | plt.ylabel("relative error: |error|/true") 279 | plt.xlabel("Number of iterations") 280 | plt.savefig(os.path.join(output_dir, '{}_relative_abs_error_mean_{}.png'.format(jobid, acquisition)), bbox_inches='tight') 281 | plt.close() 282 | 283 | #### SAVE RESULTS ### 284 | dat_save = {'mean limits': mean_limits, 'true limits': true, 'rerrs': rerrs} 285 | filename = output_dir + '/log_{}_it{}.p'.format(acquisition,T) 286 | with open(filename , 'wb') as f: 287 | pickle.dump(dat_save, f) 288 | 289 | print("Check convergence diagnostics of Stan: bits from right to left: n_eff, r_hat, div, treedepth, energy") 290 | for t in np.arange(len(warning_code)): 291 | print("t={}".format(t)) 292 | print(bin(warning_code[t])) 293 | 294 | 295 | # Define directory for the next round 296 | output_dir_iter = os.path.join(output_dir, "iteration{}_{}".format(REINVENT_iteration, acquisition)) 297 | READ_ONLY = False 298 | # if required, generate a folder to store the results 299 | try: 300 | os.mkdir(output_dir_iter) 301 | except FileExistsError: 302 | READ_ONLY = True 303 | print("Reading REINVENT results from file, no re-running.") 304 | pass 305 | 306 | def set_scoring_component_parameters(configuration, params): 307 | # modify data structure for easy access to components by their name 308 | scc = {} 309 | for comp in configuration["parameters"]["scoring_function"]["parameters"]: 310 | scc[comp["name"]] = comp 311 | 312 | for name, p in params.items(): 313 | for key, value in p.items(): 314 | print("Writing component " + name + ": " + key + "=" + str(value)) 315 | scc[name]["specific_parameters"]["transformation"][key] = value 316 | 317 | # modify parameters of the score transformations 318 | configuration = json.load(open(os.path.join(output_dir, conf_filename))) 319 | params = {} 320 | for i, comp in enumerate(scoring_component_names): 321 | params[comp] = {'high': m_high[i], 'low': m_low[i]} 322 | 
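# write the learned posterior means (m_low, m_high) into the low/high limits of every scoring component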
set_scoring_component_parameters(configuration, params) 323 | print(configuration) 324 | 325 | # modify log and result paths in configuration 326 | configuration["logging"]["logging_path"] = os.path.join(output_dir_iter, "progress.log") 327 | configuration["logging"]["result_folder"] = os.path.join(output_dir_iter, "results") 328 | 329 | if(not READ_ONLY): 330 | conf_filename = "iteration{}_config.json".format(REINVENT_iteration) 331 | configuration_JSON_path = os.path.join(output_dir_iter, conf_filename) 332 | # write the updated configuration file to the disc 333 | with open(configuration_JSON_path, 'w') as f: 334 | json.dump(configuration, f, indent=4, sort_keys=True) 335 | 336 | # Run REINVENT again 337 | if(not READ_ONLY): 338 | os.system(reinvent_env + '/bin/python ' + reinvent_dir + '/input.py ' + configuration_JSON_path + '&> ' + output_dir_iter + '/run.err') 339 | 340 | with open(os.path.join(output_dir_iter, "results/scaffold_memory.csv"), 'r') as file: 341 | data_it1 = pd.read_csv(file) 342 | 343 | # Last round: analyze results 344 | if REINVENT_iteration == K-1: 345 | # extract SMILES from scaffold memory 346 | smiles_it1 = data_it1['SMILES'] 347 | total_score_it1 = data_it1['total_score'] 348 | high_scoring_idx_it1 = total_score_it1 >= high_scoring_threshold 349 | 350 | scoring_component_names = [s.split("raw_")[1] for s in colnames if "raw_" in s] 351 | x_it1 = np.array(data_it1[scoring_component_names]) 352 | 353 | # Only analyse highest scoring molecules 354 | x_it1 = x_it1[high_scoring_idx_it1,:] 355 | smiles_it1 = smiles_it1[high_scoring_idx_it1] 356 | total_score_it1 = total_score_it1[high_scoring_idx_it1] 357 | print('{} molecules'.format(len(smiles_it1))) 358 | 359 | # Expert values (modified QED) 360 | s_mqed = np.zeros(len(smiles_it1)) 361 | for i in np.arange(len(smiles_it1)): 362 | try: 363 | cur_mol = Chem.MolFromSmiles(smiles_it1[i]) 364 | s_mqed[i] = expert_qed(cur_mol) 365 | except: 366 | s_mqed[i] = 0 367 | expert_score += [s_mqed] 368 | 369 | for i in np.arange(len(expert_score)): 370 | print("Iteration " + str(i)) 371 | print("Average modified QED in REINVENT output") 372 | print(np.mean(expert_score[i])) 373 | print("Number of molecules with modified QED score > 0.8") 374 | print(np.sum([int(sc >= 0.8) for sc in expert_score[i]])) 375 | print("Number of molecules with modified QED score > 0.9") 376 | print(np.sum([int(sc >= 0.9) for sc in expert_score[i]])) 377 | 378 | dat_save = {'mean limits': mean_limits, 'true limits': true, 'rerrs': rerrs, 'expert_score': expert_score} 379 | filename = output_dir + '/log_{}_it{}.p'.format(acquisition,T) 380 | with open(filename , 'wb') as f: 381 | pickle.dump(dat_save, f) 382 | 383 | # Set output dir and configuration file name of the next round: 384 | output_dir = output_dir_iter 385 | conf_filename = "iteration{}_config.json".format(REINVENT_iteration) 386 | 387 | # Plot the final result 388 | r = np.arange(len(expert_score)) 389 | m_score = [np.mean(expert_score[i]) for i in r] 390 | plt.plot(r, m_score) 391 | plt.title("Performance of {}".format(acquisition)) 392 | plt.ylabel("Average of modified QED score in REINVENT output") 393 | plt.xlabel("Rounds") 394 | plt.savefig(os.path.join(output_dir, '{}_mQED_{}.png'.format(jobid, acquisition)), bbox_inches='tight') 395 | plt.close() 396 | 397 | 398 | if __name__ == "__main__": 399 | print(sys.argv) 400 | acquisition = sys.argv[1] # acquisition: 'uncertainty', 'random', 'thompson', 'greedy' 401 | seed = int(sys.argv[2]) 402 | do_run(acquisition, seed) 
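# ---------------------------------------------------------------------------
# Example invocation (a sketch, not part of the original script; it assumes
# reinvent_dir, reinvent_env and output_dir above have been edited to point at
# a local REINVENT checkout, its conda environment and a writable results
# directory):
#
#   python Task1_Adaptive_MPO_HITL.py uncertainty 1234
#
# argv[1] selects the acquisition criterion ('uncertainty', 'random',
# 'thompson' or 'greedy'); argv[2] sets the random seed.
#
# For reference, a NumPy-style sketch of the double-sigmoid desirability whose
# log is computed by log_double_sigmoid in usermodel_doublesigmoid.stan:
#
#   def double_sigmoid(x, low, high, coef_div, coef_si, coef_se):
#       s_low = 10.0**(coef_se * x / coef_div) / (
#           10.0**(coef_se * x / coef_div) + 10.0**(coef_se * low / coef_div))
#       s_high = 10.0**(coef_si * x / coef_div) / (
#           10.0**(coef_si * x / coef_div) + 10.0**(coef_si * high / coef_div))
#       return max(s_low - s_high, 0.0)  # ~1 inside [low, high], ~0 outside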
-------------------------------------------------------------------------------- /Adaptive-MPO/adapt-mpo.yml: -------------------------------------------------------------------------------- 1 | name: adapt_mpo 2 | channels: 3 | - rdkit 4 | - pytorch 5 | - openeye 6 | - conda-forge 7 | - anaconda 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1 11 | - _pytorch_select=0.1 12 | - _tflow_select=2.3.0 13 | - aiohttp=3.6.2 14 | - argon2-cffi=20.1.0 15 | - astunparse=1.6.3 16 | - async-timeout=3.0.1 17 | - async_generator=1.10 18 | - attrs=19.3.0 19 | - backcall=0.2.0 20 | - blas=1.0 21 | - bleach=3.2.1 22 | - blinker=1.4 23 | - brotlipy=0.7.0 24 | - bzip2=1.0.8 25 | - c-ares=1.17.1 26 | - ca-certificates=2021.4.13 27 | - cachetools=4.2.2 28 | - cairo=1.14.12 29 | - certifi=2020.12.5 30 | - cffi=1.14.3 31 | - chardet=3.0.4 32 | - click=8.0.0 33 | - cmarkgfm=0.4.2 34 | - colorama=0.4.4 35 | - coverage=5.5 36 | - cryptography=3.2.1 37 | - cudatoolkit=10.2.89 38 | - cycler=0.10.0 39 | - cython=0.29.23 40 | - dacite=1.5.1 41 | - dataclasses=0.7 42 | - dbus=1.13.12 43 | - decorator=5.0.7 44 | - defusedxml=0.7.1 45 | - dill=0.3.1.1 46 | - docutils=0.16 47 | - entrypoints=0.3 48 | - expat=2.2.9 49 | - fontconfig=2.13.0 50 | - freetype=2.9.1 51 | - glib=2.63.1 52 | - google-auth=1.28.0 53 | - google-auth-oauthlib=0.4.1 54 | - google-pasta=0.2.0 55 | - gst-plugins-base=1.14.0 56 | - gstreamer=1.14.0 57 | - h5py=2.10.0 58 | - hdf5=1.10.6 59 | - icu=58.2 60 | - idna=2.10 61 | - importlib-metadata=3.10.0 62 | - importlib_metadata=1.5.0 63 | - intel-openmp=2020.0 64 | - ipykernel=5.3.4 65 | - ipython=7.22.0 66 | - ipython_genutils=0.2.0 67 | - ipywidgets=7.6.3 68 | - jedi=0.17.0 69 | - jeepney=0.6.0 70 | - jinja2=2.11.3 71 | - joblib=0.15.1 72 | - jpeg=9b 73 | - jsonschema=3.2.0 74 | - jupyter=1.0.0 75 | - jupyter_client=6.1.12 76 | - jupyter_console=6.4.0 77 | - jupyter_core=4.7.1 78 | - jupyterlab_pygments=0.1.2 79 | - jupyterlab_widgets=1.0.0 80 | - keras-preprocessing=1.1.2 81 | - keyring=21.5.0 82 | - kiwisolver=1.1.0 83 | - ld_impl_linux-64=2.33.1 84 | - libboost=1.67.0 85 | - libedit=3.1.20181209 86 | - libffi=3.2.1 87 | - libgcc-ng=9.1.0 88 | - libgfortran-ng=7.3.0 89 | - libpng=1.6.37 90 | - libprotobuf=3.14.0 91 | - libsodium=1.0.18 92 | - libstdcxx-ng=9.1.0 93 | - libtiff=4.1.0 94 | - libuuid=1.0.3 95 | - libuv=1.40.0 96 | - libxcb=1.13 97 | - libxml2=2.9.9 98 | - markupsafe=1.1.1 99 | - matplotlib=3.1.3 100 | - matplotlib-base=3.1.3 101 | - mistune=0.8.4 102 | - mkl=2020.0 103 | - mkl-service=2.3.0 104 | - mkl_fft=1.0.15 105 | - mkl_random=1.1.0 106 | - more-itertools=8.2.0 107 | - multidict=4.7.3 108 | - multiprocess=0.70.9 109 | - nbclient=0.5.3 110 | - nbconvert=6.0.7 111 | - nbformat=5.1.3 112 | - ncurses=6.2 113 | - nest-asyncio=1.5.1 114 | - ninja=1.9.0 115 | - notebook=6.3.0 116 | - numpy=1.18.1 117 | - numpy-base=1.18.1 118 | - oauthlib=3.1.0 119 | - olefile=0.46 120 | - openeye-toolkits=2020.2.2 121 | - openssl=1.1.1k 122 | - packaging=20.3 123 | - pandas=1.0.3 124 | - pandoc=2.12 125 | - pandocfilters=1.4.3 126 | - parso=0.8.2 127 | - pathos=0.2.5 128 | - pcre=8.43 129 | - pexpect=4.8.0 130 | - pickleshare=0.7.5 131 | - pillow=7.0.0 132 | - pip=20.0.2 133 | - pixman=0.38.0 134 | - pkginfo=1.6.1 135 | - pluggy=0.13.1 136 | - pox=0.2.7 137 | - ppft=1.6.6.1 138 | - prometheus_client=0.10.1 139 | - prompt-toolkit=3.0.17 140 | - prompt_toolkit=3.0.17 141 | - ptyprocess=0.7.0 142 | - py=1.8.1 143 | - py-boost=1.67.0 144 | - py4j=0.10.7 145 | - pyasn1=0.4.8 146 | - 
pyasn1-modules=0.2.8 147 | - pycparser=2.20 148 | - pydantic=1.8.2 149 | - pygments=2.7.2 150 | - pyjwt=1.7.1 151 | - pyopenssl=19.1.0 152 | - pyparsing=2.4.6 153 | - pyqt=5.9.2 154 | - pyrsistent=0.17.3 155 | - pysocks=1.7.1 156 | - pytest=5.4.1 157 | - python=3.7.7 158 | - python-dateutil=2.8.1 159 | - python_abi=3.7 160 | - pytorch=1.7.1 161 | - pytz=2019.3 162 | - pyzmq=20.0.0 163 | - qt=5.9.7 164 | - qtconsole=5.0.3 165 | - qtpy=1.9.0 166 | - rdkit=2020.03.3.0 167 | - readline=8.0 168 | - readme_renderer=27.0 169 | - requests=2.25.0 170 | - requests-oauthlib=1.3.0 171 | - requests-toolbelt=0.9.1 172 | - rfc3986=1.4.0 173 | - rsa=4.7.2 174 | - scikit-learn=0.21.3 175 | - scipy=1.4.1 176 | - secretstorage=3.2.0 177 | - send2trash=1.5.0 178 | - setuptools=46.1.1 179 | - sip=4.19.8 180 | - six=1.14.0 181 | - sqlite=3.31.1 182 | - tensorboard-plugin-wit=1.6.0 183 | - termcolor=1.1.0 184 | - terminado=0.9.4 185 | - testpath=0.4.4 186 | - tk=8.6.8 187 | - tornado=6.0.4 188 | - tqdm=4.43.0 189 | - traitlets=5.0.5 190 | - twine=3.2.0 191 | - typing_extensions=3.7.4.3 192 | - urllib3=1.25.11 193 | - wcwidth=0.1.9 194 | - webencodings=0.5.1 195 | - wheel=0.34.2 196 | - widgetsnbextension=3.5.1 197 | - wrapt=1.12.1 198 | - xz=5.2.4 199 | - yarl=1.4.2 200 | - zeromq=4.3.4 201 | - zipp=2.2.0 202 | - zlib=1.2.11 203 | - zstd=1.3.7 204 | - pip: 205 | - absl-py==0.9.0 206 | - astor==0.8.1 207 | - future==0.18.2 208 | - gast==0.2.2 209 | - grpcio==1.27.2 210 | - keras-applications==1.0.8 211 | - markdown==3.2.1 212 | - opt-einsum==3.2.0 213 | - protobuf==3.11.3 214 | - reinvent-chemistry==0.0.51 215 | - reinvent-models==0.0.15 216 | - reinvent-scoring==0.0.73 217 | - tensorboard==1.15.0 218 | - tensorflow==1.15.2 219 | - tensorflow-estimator==1.15.1 220 | - termcolor==1.1.0 221 | - werkzeug==1.0.0 222 | - pystan==2.19.1.1 223 | - stan-utility==0.1.2 224 | -------------------------------------------------------------------------------- /Adaptive-MPO/data/best_matching_params_modifiedQED_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "logging": { 3 | "job_id": "", 4 | "job_name": "", 5 | "logging_frequency": 0, 6 | "logging_path": "", 7 | "recipient": "local", 8 | "result_folder": "", 9 | "sender": "http://127.0.0.1" 10 | }, 11 | "parameters": { 12 | "diversity_filter": { 13 | "bucket_size": 25, 14 | "minscore": 0.4, 15 | "minsimilarity": 0.4, 16 | "name": "IdenticalMurckoScaffold" 17 | }, 18 | "inception": { 19 | "memory_size": 20, 20 | "sample_size": 5, 21 | "smiles": [] 22 | }, 23 | "reinforcement_learning": { 24 | "agent": "", 25 | "batch_size": 128, 26 | "learning_rate": 0.0001, 27 | "margin_threshold": 50, 28 | "n_steps": 300, 29 | "prior": "", 30 | "reset": 0, 31 | "reset_score_cutoff": 0.5, 32 | "sigma": 128 33 | }, 34 | "scoring_function": { 35 | "name": "custom_product", 36 | "parallel": true, 37 | "parameters": [ 38 | { 39 | "component_type": "molecular_weight", 40 | "name": "Molecular weight", 41 | "specific_parameters": { 42 | "transformation": { 43 | "coef_div": 175.77, 44 | "coef_se": 2, 45 | "coef_si": 2, 46 | "high": 421.98, 47 | "low": 200.4, 48 | "transformation_type": "double_sigmoid" 49 | } 50 | }, 51 | "weight": 1 52 | }, 53 | { 54 | "component_type": "slogp", 55 | "name": "SlogP", 56 | "specific_parameters": { 57 | "transformation": { 58 | "coef_div": 3.0, 59 | "coef_se": 2, 60 | "coef_si": 2, 61 | "high": 3.48, 62 | "low": -1.47, 63 | "transformation_type": "double_sigmoid" 64 | } 65 | }, 66 | "weight": 1 67 | }, 68 | { 69 | 
"component_type": "num_hbd_lipinski", 70 | "name": "HB-donors (Lipinski)", 71 | "specific_parameters": { 72 | "transformation": { 73 | "coef_div": 2.41, 74 | "coef_se": 2, 75 | "coef_si": 2, 76 | "high": 2.79, 77 | "low": -0.298, 78 | "transformation_type": "double_sigmoid" 79 | } 80 | }, 81 | "weight": 1 82 | }, 83 | { 84 | "component_type": "num_hba_lipinski", 85 | "name": "HB-acceptors (Lipinski)", 86 | "specific_parameters": { 87 | "transformation": { 88 | "coef_div": 4.42, 89 | "coef_se": 4.4, 90 | "coef_si": 2, 91 | "high": 6.21, 92 | "low": 1.3, 93 | "transformation_type": "double_sigmoid" 94 | } 95 | }, 96 | "weight": 1 97 | }, 98 | { 99 | "component_type": "tpsa", 100 | "name": "PSA", 101 | "specific_parameters": { 102 | "transformation": { 103 | "coef_div": 75.34, 104 | "coef_se": 2, 105 | "coef_si": 2, 106 | "high": 120.68, 107 | "low": 13.28, 108 | "transformation_type": "double_sigmoid" 109 | } 110 | }, 111 | "weight": 1 112 | }, 113 | { 114 | "component_type": "num_rotatable_bonds", 115 | "name": "Number of rotatable bonds", 116 | "specific_parameters": { 117 | "transformation": { 118 | "coef_div": 5.69, 119 | "coef_se": 2, 120 | "coef_si": 2, 121 | "high": 7.56, 122 | "low": 0.269, 123 | "transformation_type": "double_sigmoid" 124 | } 125 | }, 126 | "weight": 1 127 | }, 128 | { 129 | "component_type": "num_rings", 130 | "name": "Number of aromatic rings", 131 | "specific_parameters": { 132 | "transformation": { 133 | "coef_div": 2.28, 134 | "coef_se": 2, 135 | "coef_si": 2, 136 | "high": 2.69, 137 | "low": -0.083, 138 | "transformation_type": "double_sigmoid" 139 | } 140 | }, 141 | "weight": 1 142 | } 143 | ] 144 | } 145 | }, 146 | "run_type": "reinforcement_learning", 147 | "version": 3 148 | } -------------------------------------------------------------------------------- /Adaptive-MPO/scripts/REINVENTconfig.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | def parse_config_file(f, scoring_component_names): 5 | configuration = json.load(open(f)) 6 | configuration_scoring_components = {} 7 | for comp in configuration["parameters"]["scoring_function"]["parameters"]: 8 | configuration_scoring_components[comp["name"]] = comp 9 | low = np.array([configuration_scoring_components[comp]["specific_parameters"]["transformation"]["low"] for comp in scoring_component_names]) 10 | high = np.array([configuration_scoring_components[comp]["specific_parameters"]["transformation"]["high"] for comp in scoring_component_names]) 11 | # double sigmoid params: 12 | coef_div = np.ones(high.shape) 13 | coef_si = np.ones(high.shape) 14 | coef_se = np.ones(high.shape) 15 | for i, comp in enumerate(scoring_component_names): 16 | try: 17 | coef_div[i] = configuration_scoring_components[comp]["specific_parameters"]["transformation"]["coef_div"] 18 | coef_si[i] = configuration_scoring_components[comp]["specific_parameters"]["transformation"]["coef_si"] 19 | coef_se[i] = configuration_scoring_components[comp]["specific_parameters"]["transformation"]["coef_se"] 20 | except KeyError: 21 | # In case original transformations were not double sigmoid, use default values 22 | coef_div[i] = high[i] 23 | coef_si[i] = 10 24 | coef_se[i] = 10 25 | return {'low': low, 'high': high, 'coef_div': coef_div, 'coef_si': coef_si, 'coef_se': coef_se} 26 | -------------------------------------------------------------------------------- /Adaptive-MPO/scripts/acquisition.py: 
-------------------------------------------------------------------------------- 1 | # Acquisitions for HITL interaction 2 | 3 | import numpy as np 4 | 5 | def local_idx_to_fulldata_idx(N, selected_feedback, idx): 6 | all_idx = np.arange(N) 7 | mask = np.ones(N, dtype=bool) 8 | mask[selected_feedback] = False 9 | pred_idx = all_idx[mask] 10 | return pred_idx[idx] 11 | 12 | def uncertainty_sampling(N,n,fit,selected_feedback, rng, t=None): 13 | la = fit.extract(permuted=True) 14 | score_pred = np.mean(la['score_pred'],axis=0) 15 | utility = np.absolute(score_pred - 0.5) 16 | query_idx = np.argsort(utility)[:n] 17 | return local_idx_to_fulldata_idx(N, selected_feedback, query_idx) 18 | 19 | def posterior_sampling(N,n,fit,selected_feedback, rng, t=None): 20 | print("Posterior sampling at iteration t={}".format(t)) 21 | la = fit.extract(permuted=True) 22 | n_samples = la['score_pred'].shape[0] # number of posterior draws 23 | sample_idx = rng.choice(n_samples, n, replace=True) # draw n random beliefs from the posterior, with replacement 24 | query_idx = np.zeros(n) 25 | for i in np.arange(0,len(sample_idx)): 26 | score_pred = la['score_pred'][sample_idx[i],:] 27 | cq = np.random.choice(np.flatnonzero(score_pred == score_pred.max())) # breaks ties at random 28 | k = len(score_pred)-1 29 | while(cq in query_idx[:i]): 30 | k = k-1 # take second best option 31 | if np.sum(score_pred == score_pred.max()) > 1: 32 | print("More than one maximizer") 33 | cq = np.argsort(score_pred)[k] # take the k:th highest value 34 | query_idx[i] = cq 35 | return local_idx_to_fulldata_idx(N, selected_feedback, query_idx.astype(int)) 36 | 37 | def exploitation(N,n,fit,selected_feedback, rng, t=None): 38 | la = fit.extract(permuted=True) 39 | score_pred = np.mean(la['score_pred'],axis=0) 40 | query_idx = np.argsort(score_pred)[::-1][:n] # get the n highest 41 | return local_idx_to_fulldata_idx(N, selected_feedback, query_idx) 42 | 43 | def random_selection(N,n,fit,selected_feedback, rng, t=None): 44 | selected = rng.choice(N-len(selected_feedback),n, replace=False) 45 | return local_idx_to_fulldata_idx(N, selected_feedback, selected) 46 | 47 | 48 | def select_query(N,n,fit,selected_feedback, acquisition='random', rng=None, t=None): 49 | ''' 50 | Parameters 51 | ---------- 52 | N : Size of the unlabeled data before acquisition 53 | n : number of queries to select. 54 | model : Current prediction model. 55 | acquisition : acquisition method, optional. The default is 'random' 56 | rng : random generator to be used 57 | 58 | Returns 59 | ------- 60 | int idx: 61 | Index of the query 62 | 63 | ''' 64 | # select acquisition: 65 | if acquisition == 'uncertainty': 66 | acq = uncertainty_sampling 67 | elif acquisition == 'thompson': 68 | acq = posterior_sampling 69 | elif acquisition == 'greedy': 70 | acq = exploitation 71 | elif acquisition == 'random': 72 | acq = random_selection 73 | else: 74 | print("Warning: unknown acquisition criterion. 
Using random sampling.") 75 | acq = random_selection 76 | return acq(N, n, fit, selected_feedback, rng, t) 77 | -------------------------------------------------------------------------------- /Adaptive-MPO/scripts/evaluation.py: -------------------------------------------------------------------------------- 1 | from scripts.write import write_sample_file 2 | import os 3 | 4 | def sample_mols_from_agent(jobid, jobname, agent_dir, reinvent_env, reinvent_dir, N=1000): 5 | print("Sampling from agent " + os.path.join(agent_dir, "Agent.ckpt")) 6 | conf_file = write_sample_file(jobid, jobname, agent_dir, N) 7 | os.system(reinvent_env + '/bin/python ' + reinvent_dir + '/input.py ' + conf_file + '&> ' + agent_dir + '/sampling.err') 8 | 9 | 10 | -------------------------------------------------------------------------------- /Adaptive-MPO/scripts/lowSlogPQED.py: -------------------------------------------------------------------------------- 1 | # Modified QED objective 2 | # 3 | # Modified from original: https://github.com/rdkit/rdkit/blob/master/rdkit/Chem/QED.py 4 | # 5 | # Copyright (c) 2009-2017, Novartis Institutes for BioMedical Research Inc. 6 | # All rights reserved. 7 | # 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, are permitted provided that the following conditions are 10 | # met: 11 | # 12 | # * Redistributions of source code must retain the above copyright 13 | # notice, this list of conditions and the following disclaimer. 14 | # * Redistributions in binary form must reproduce the above 15 | # copyright notice, this list of conditions and the following 16 | # disclaimer in the documentation and/or other materials provided 17 | # with the distribution. 18 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 19 | # nor the names of its contributors may be used to endorse or promote 20 | # products derived from this software without specific prior written permission. 21 | # 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 26 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
33 | # 34 | 35 | import rdkit.Chem.QED as qd 36 | import math 37 | 38 | shift = -1.5 # The shift of the peak in logP score transformation 39 | logp_modified_ads_params = qd.ADSparameter(A=3.172690585, B=137.8624751, C=2.534937431+shift, D=4.581497897, E=0.822739154, F=0.576295591, DMAX=131.3186604) 40 | 41 | adsParameters_modified = { 42 | 'MW': qd.ADSparameter(A=2.817065973, B=392.5754953, C=290.7489764, D=2.419764353, E=49.22325677, 43 | F=65.37051707, DMAX=104.9805561), 44 | 'ALOGP': logp_modified_ads_params, 45 | 'HBA': qd.ADSparameter(A=2.948620388, B=160.4605972, C=3.615294657, D=4.435986202, E=0.290141953, 46 | F=1.300669958, DMAX=148.7763046), 47 | 'HBD': qd.ADSparameter(A=1.618662227, B=1010.051101, C=0.985094388, D=0.000000001, E=0.713820843, 48 | F=0.920922555, DMAX=258.1632616), 49 | 'PSA': qd.ADSparameter(A=1.876861559, B=125.2232657, C=62.90773554, D=87.83366614, E=12.01999824, 50 | F=28.51324732, DMAX=104.5686167), 51 | 'ROTB': qd.ADSparameter(A=0.010000000, B=272.4121427, C=2.558379970, D=1.565547684, E=1.271567166, 52 | F=2.758063707, DMAX=105.4420403), 53 | 'AROM': qd.ADSparameter(A=3.217788970, B=957.7374108, C=2.274627939, D=0.000000001, E=1.317690384, 54 | F=0.375760881, DMAX=312.3372610), 55 | 'ALERTS': qd.ADSparameter(A=0.010000000, B=1199.094025, C=-0.09002883, D=0.000000001, E=0.185904477, 56 | F=0.875193782, DMAX=417.7253140), 57 | } 58 | adsParameters_names = { 59 | "Molecular weight": 'MW', 60 | "SlogP": 'ALOGP', 61 | "HB-donors (Lipinski)":'HBD', 62 | "HB-acceptors (Lipinski)": 'HBA', 63 | "PSA": 'PSA', 64 | "Number of rotatable bonds": 'ROTB', 65 | "Number of aromatic rings": 'AROM', 66 | 'ALERTS': 'ALERTS' 67 | } 68 | 69 | # Compute modified QED 70 | WEIGHT_MAX = qd.QEDproperties(0.50, 0.25, 0.00, 0.50, 0.00, 0.50, 0.25, 1.00) 71 | WEIGHT_MEAN = qd.QEDproperties(0.66, 0.46, 0.05, 0.61, 0.06, 0.65, 0.48, 0.95) 72 | WEIGHT_NONE = qd.QEDproperties(1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00) 73 | WEIGHT_MEAN_NOALERTS = qd.QEDproperties(0.66, 0.46, 0.05, 0.61, 0.06, 0.65, 0.48, 0.0) 74 | def expert_qed(mol, w=WEIGHT_MEAN_NOALERTS, qedProperties=None): 75 | if qedProperties is None: 76 | qedProperties = qd.properties(mol) 77 | d = [qd.ads(pi, adsParameters_modified[name]) for name, pi in qedProperties._asdict().items()] 78 | t = sum(wi * math.log(di) for wi, di in zip(w, d)) 79 | return math.exp(t / sum(w)) 80 | 81 | def get_adsParameters(): 82 | return adsParameters_names 83 | 84 | 85 | def qed_properties_from_physchem(scoring_component_names, physchem_properties): 86 | raw_properties = {} 87 | for i, sc in enumerate(scoring_component_names): 88 | raw_properties[adsParameters_names[sc]] = physchem_properties[i] 89 | 90 | qedProperties = qd.QEDproperties(raw_properties['MW'], raw_properties['ALOGP'], raw_properties['HBA'], raw_properties['HBD'], raw_properties['PSA'], raw_properties['ROTB'], raw_properties['AROM'], 0) 91 | return qedProperties 92 | 93 | 94 | -------------------------------------------------------------------------------- /Adaptive-MPO/scripts/write.py: -------------------------------------------------------------------------------- 1 | # Scripts for writing and modifying configuration json files of REINVENT 2 | 3 | import os 4 | import json 5 | 6 | def write_REINVENT_config(reinvent_dir, reinvent_env, output_dir, jobid, jobname): 7 | 8 | diversity_filter = { 9 | "name": "IdenticalMurckoScaffold", 10 | "bucket_size": 25, 11 | "minscore": 0.4, 12 | "minsimilarity": 0.4 13 | } 14 | 15 | inception = { 16 | "memory_size": 20, 17 | "sample_size": 5, 
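# inception memory starts empty: no seed SMILES are provided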
18 | "smiles": [] 19 | } 20 | 21 | component_mv = { 22 | "component_type": "molecular_weight", 23 | "name": "Molecular weight", 24 | "weight": 1, 25 | "specific_parameters": { 26 | "transformation": { 27 | "transformation_type": "double_sigmoid", 28 | "high": 700, 29 | "low": 50, 30 | "coef_div": 175.77, 31 | "coef_si": 2, 32 | "coef_se": 2 33 | } 34 | } 35 | } 36 | 37 | component_slogp = { 38 | "component_type": "slogp", 39 | "name": "SlogP", 40 | "weight": 1, 41 | "specific_parameters": { 42 | "transformation": { 43 | "transformation_type": "double_sigmoid", 44 | "high": 10, 45 | "low": 3, 46 | "coef_div": 3.0, 47 | "coef_si": 2, 48 | "coef_se": 2 49 | } 50 | } 51 | } 52 | 53 | component_hba = { 54 | "component_type": "num_hba_lipinski", 55 | "name": "HB-acceptors (Lipinski)", 56 | "weight": 1, 57 | "specific_parameters": { 58 | "transformation": { 59 | "transformation_type": "double_sigmoid", 60 | "high": 11, 61 | "low": 2, 62 | "coef_div": 4.42, 63 | "coef_si": 2, 64 | "coef_se": 4.4 65 | } 66 | } 67 | } 68 | 69 | component_hbd = { 70 | "component_type": "num_hbd_lipinski", 71 | "name": "HB-donors (Lipinski)", 72 | "weight": 1, 73 | "specific_parameters": { 74 | "transformation": { 75 | "transformation_type": "double_sigmoid", 76 | "high": 8, 77 | "low": 1, 78 | "coef_div": 2.41, 79 | "coef_si": 2, 80 | "coef_se": 2 81 | } 82 | } 83 | } 84 | 85 | component_psa = { 86 | "component_type": "tpsa", 87 | "name": "PSA", 88 | "weight": 1, 89 | "specific_parameters": { 90 | "transformation": { 91 | "transformation_type": "double_sigmoid", 92 | "high": 300, 93 | "low": 100, 94 | "coef_div": 75.34, 95 | "coef_si": 2, 96 | "coef_se": 2 97 | } 98 | } 99 | } 100 | 101 | component_rotatable_bonds = { 102 | "component_type": "num_rotatable_bonds", 103 | "name": "Number of rotatable bonds", 104 | "weight": 1, 105 | "specific_parameters": { 106 | "transformation": { 107 | "transformation_type": "double_sigmoid", 108 | "high": 20, 109 | "low": 5, 110 | "coef_div": 5.69, 111 | "coef_si": 2, 112 | "coef_se": 2 113 | } 114 | } 115 | } 116 | 117 | component_num_rings = { 118 | "component_type": "num_rings", 119 | "name": "Number of aromatic rings", 120 | "weight": 1, 121 | "specific_parameters": { 122 | "transformation": { 123 | "transformation_type": "double_sigmoid", 124 | "high": 10, 125 | "low": 1, 126 | "coef_div": 2.28, 127 | "coef_si": 2, 128 | "coef_se": 2 129 | } 130 | } 131 | } 132 | 133 | scoring_function = { 134 | "name": "custom_product", 135 | "parallel": True, 136 | "parameters": [ 137 | component_mv, 138 | component_slogp, 139 | component_hbd, 140 | component_hba, 141 | component_psa, 142 | component_rotatable_bonds, 143 | component_num_rings 144 | ] 145 | } 146 | 147 | configuration = { 148 | "version": 3, 149 | "run_type": "reinforcement_learning", 150 | "model_type": "default", 151 | "parameters": { 152 | "scoring_function": scoring_function 153 | } 154 | } 155 | 156 | configuration["parameters"]["diversity_filter"] = diversity_filter 157 | configuration["parameters"]["inception"] = inception 158 | 159 | configuration["parameters"]["reinforcement_learning"] = { 160 | "prior": os.path.join(reinvent_dir, "data/random.prior.new"), 161 | "agent": os.path.join(reinvent_dir, "data/random.prior.new"), 162 | "n_steps": 300, 163 | "sigma": 128, 164 | "learning_rate": 0.0001, 165 | "batch_size": 128, 166 | "reset": 0, 167 | "reset_score_cutoff": 0.5, 168 | "margin_threshold": 50 169 | } 170 | 171 | configuration["logging"] = { 172 | "sender": "http://127.0.0.1", 173 | "recipient": "local", 174 | 
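# these paths default to output_dir; Task1 redirects logging_path and result_folder into a per-round subdirectory before each re-run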
"logging_frequency": 0, 175 | "logging_path": os.path.join(output_dir, "progress.log"), 176 | "result_folder": os.path.join(output_dir, "results"), 177 | "job_name": jobname, 178 | "job_id": jobid 179 | } 180 | 181 | # write the configuration file to disc 182 | configuration_JSON_path = os.path.join(output_dir, "config.json") 183 | with open(configuration_JSON_path, 'w') as f: 184 | json.dump(configuration, f, indent=4, sort_keys=True) 185 | 186 | return configuration_JSON_path 187 | 188 | 189 | def write_sample_file(jobid, jobname, agent_dir, N): 190 | configuration={ 191 | "logging": { 192 | "job_id": jobid, 193 | "job_name": "sample_agent_{}".format(jobname), 194 | "logging_path": os.path.join(agent_dir, "sampling.log"), 195 | "recipient": "local", 196 | "sender": "http://127.0.0.1" 197 | }, 198 | "parameters": { 199 | "model_path": os.path.join(agent_dir, "Agent.ckpt"), 200 | "output_smiles_path": os.path.join(agent_dir, "sampled_N_{}.csv".format(N)), 201 | "num_smiles": N, 202 | "batch_size": 128, 203 | "with_likelihood": False 204 | }, 205 | "run_type": "sampling", 206 | "version": 2 207 | } 208 | conf_filename = os.path.join(agent_dir, "evaluate_agent_config.json") 209 | with open(conf_filename, 'w') as f: 210 | json.dump(configuration, f, indent=4, sort_keys=True) 211 | return conf_filename 212 | -------------------------------------------------------------------------------- /Adaptive-MPO/usermodel_doublesigmoid.stan: -------------------------------------------------------------------------------- 1 | functions { 2 | real log_double_sigmoid(real value, real low, real high, real coef_div, real coef_si, real coef_se){ 3 | real A; 4 | real B; 5 | real C; 6 | A = 10^(coef_se * (value / coef_div)); 7 | B = (10^(coef_se * (value / coef_div)) + 10^(coef_se * (low / coef_div))); 8 | C = (10^(coef_si * (value / coef_div)) / (10^(coef_si * (value / coef_div)) + 10^(coef_si * (high / coef_div)))); 9 | if((A / B) - C < 0) 10 | return 0; 11 | else 12 | return log((A / B) - C); 13 | } 14 | 15 | real aggregated_score_prod(vector weights, vector x_raw, vector lows, vector deltas, vector coef_div, vector coef_si, vector coef_se, int k) { 16 | real score; 17 | score = 0; 18 | for (l in 1:k) { 19 | score = score + weights[l]*log_double_sigmoid(x_raw[l],lows[l], lows[l]+deltas[l], coef_div[l], coef_si[l], coef_se[l]); 20 | } 21 | score = exp(score); 22 | return score; 23 | } 24 | } 25 | 26 | data { 27 | int n; // number of data points 28 | int k; // number of scoring components 29 | vector[k] x_raw[n]; // raw scores 30 | int y[n]; // binary response variable 31 | vector[k] weights; // Non-negative weights of the user-model (fixed, sum 1) 32 | vector[k] coef_div; // params of double sigmoid (fixed) 33 | vector[k] coef_si; // params of double sigmoid (fixed) 34 | vector[k] coef_se; // params of double sigmoid (fixed) 35 | vector[k] high0; // initial guesses of the score transform variables 36 | vector[k] low0; 37 | int npred; 38 | vector[k] xpred[npred]; // all molecules (for active learning) 39 | } 40 | 41 | parameters { 42 | vector[k] lows; 43 | vector[k] deltas; 44 | } 45 | model { 46 | for (i in 1:k) { 47 | lows[i] ~ normal(low0[i], (high0[i]-low0[i])/8); 48 | deltas[i] ~ normal(high0[i]-low0[i], (high0[i]-low0[i])/8); 49 | } 50 | // observation model 51 | for (j in 1:n) { 52 | y[j] ~ bernoulli(aggregated_score_prod(weights, x_raw[j], lows, deltas, coef_div, coef_si, coef_se, k)); 53 | } 54 | } 55 | generated quantities { 56 | vector[k] highs; 57 | vector[npred] score_pred; 58 | for (i in 1:k) 59 | 
highs[i] = lows[i]+deltas[i]; 60 | for (j in 1:npred){ 61 | score_pred[j] = aggregated_score_prod(weights, xpred[j], lows, deltas, coef_div, coef_si, coef_se, k); 62 | } 63 | } -------------------------------------------------------------------------------- /Chemists-Component/Task2_Chemists_Component.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "from multiprocessing import Pool\n", 11 | "import pandas as pd\n", 12 | "import pickle\n", 13 | "import json\n", 14 | "from datetime import datetime\n", 15 | "import rdkit, scipy, sklearn\n", 16 | "import tensorflow as tf\n", 17 | "from rdkit.Chem.Draw import IPythonConsole\n", 18 | "from rdkit.Chem import PandasTools\n", 19 | "from rdkit.Chem import AllChem as Chem\n", 20 | "from rdkit.Chem import DataStructs, Descriptors\n", 21 | "from rdkit.Chem import MACCSkeys\n", 22 | "from rdkit import Avalon\n", 23 | "from rdkit.Avalon import pyAvalonTools\n", 24 | "from sklearn.metrics import accuracy_score, cohen_kappa_score, matthews_corrcoef, recall_score, precision_score,f1_score\n", 25 | "\n", 26 | "\n", 27 | "from load_data import load_data\n", 28 | "from write import write_config_file, write_query_to_csv, write_runs_sh, write_idx, write_training_data, write_sample_file, write_run_sample\n", 29 | "from acquisition import select_query\n", 30 | "from models import Tanimoto_model\n", 31 | "from query import query\n", 32 | "\n", 33 | "\n", 34 | "from scipy.spatial.distance import *\n", 35 | "import numpy as np\n", 36 | "\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "%matplotlib inline\n", 39 | "\n", 40 | "\n", 41 | "import pandas as pd\n", 42 | "\n", 43 | "import gpflow\n", 44 | "from gpflow.utilities import positive\n", 45 | "from gpflow.utilities.ops import broadcasting_elementwise\n", 46 | "from gpflow.mean_functions import Constant\n", 47 | "import tensorflow as tf\n", 48 | "import reinvent_scoring\n", 49 | "\n", 50 | "\n", 51 | "from sklearn.model_selection import train_test_split\n", 52 | "from sklearn.preprocessing import StandardScaler\n", 53 | "import time" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# configuration" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "experiment = 1 # options: numbers 1-5 that define which seed is used in user demo\n", 70 | "\n", 71 | "seeds = [1718, 1896, 3975, 8355, 9774] \n", 72 | "seed = seeds[experiment-1]\n", 73 | "np.random.seed(seed)\n", 74 | "tf.random.set_seed(seed)\n", 75 | "\n", 76 | "acquisition='thompson' #options: 'thompson', 'uncertainty', 'random', 'greedy'\n", 77 | "READ_ONLY = False # Use 'True' to playback an existing experiment (reads queries and feedback from files instead of using the algorithm)\n", 78 | "usecase = 'drd2'\n", 79 | "simulated_human = True\n", 80 | "verbose = False # print out details during interaction\n", 81 | "loop, last_job=0, None # Do not change" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def parse_usecase(usecase):\n", 91 | " usecase_params = {\n", 92 | " 'drd2': {\n", 93 | " 'train_data': 'drd2_regression.train.csv',\n", 94 | " 'test_data': 'drd2_regression.test.csv',\n", 95 | " 'y_field': 'activity'\n", 96 | " }\n", 97 | " 
}\n", 98 | "\n", 99 | " train_data_file = usecase_params[usecase]['train_data']\n", 100 | " test_data_file = usecase_params[usecase]['test_data']\n", 101 | " y_field = usecase_params[usecase]['y_field']\n", 102 | " return train_data_file , test_data_file, y_field\n", 103 | "\n", 104 | "def load_config(acquisition,seed,loop, jobid=None):\n", 105 | " if not jobid:\n", 106 | " jobid=datetime.now().strftime(\"%d-%m-%Y\")\n", 107 | " else:\n", 108 | " jobid=jobid\n", 109 | " jobname = 'Task2_demo_{}'.format(acquisition)\n", 110 | "\n", 111 | " N0=10 # size of initial training data L\n", 112 | " n_restarts=2 # parameters of GP optimization\n", 113 | " n_batch=10 # number of molecules shown at each iteration\n", 114 | " n_iteration=10 # the number of iteration\n", 115 | " fpdim=1024 #dimension of morgan fingerprint\n", 116 | " step=1\n", 117 | " cwd=os.getcwd()\n", 118 | " \n", 119 | " reinvent_dir = os.path.expanduser(\"/path/to/Reinvent\")\n", 120 | " reinvent_env = os.path.expanduser(\"/path/to/conda_environment/for/Reinvent\")\n", 121 | " output_dir=os.path.join(cwd,\"./results/{}_{}_seed_{}\".format(jobname, jobid, seed))\n", 122 | " \n", 123 | " if not os.path.exists(output_dir):\n", 124 | " os.mkdir(output_dir)\n", 125 | " output_dir=os.path.join(output_dir,\"./loop{}\".format(loop))\n", 126 | " if not os.path.exists(output_dir):\n", 127 | " os.mkdir(output_dir)\n", 128 | " lastloop_dir=None\n", 129 | " return jobid, jobname, N0, n_restarts, n_batch, n_iteration, fpdim, step, reinvent_dir, reinvent_env, output_dir, lastloop_dir\n", 130 | "\n", 131 | "def load_data_config(usecase, lastloop_dir=None):\n", 132 | " train_num=10000 # sample 10.000 molecules from training dataset to U\n", 133 | " test_num=2000 # sample 2000 molecules from testing dataset to the test set\n", 134 | " data_dir= './data/'\n", 135 | " train_data_file , test_data_file, y_field=parse_usecase(usecase)\n", 136 | "\n", 137 | " test_data_path=os.path.join(data_dir,test_data_file) if test_data_file else None\n", 138 | "\n", 139 | " if lastloop_dir:\n", 140 | " train_data_path=os.path.join(lastloop_dir,train_data_file)\n", 141 | " else:\n", 142 | " train_data_path=os.path.join(data_dir,train_data_file)\n", 143 | " return train_num, test_num, train_data_path, test_data_path, y_field" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def train(X_updated, y_updated, jobid, acquisition,seed,loop, iteration):\n", 153 | " jobid, jobname, N0, n_restarts, n_batch, _, fpdim, step,\\\n", 154 | " reinvent_dir, reinvent_env, output_dir, _ =load_config(acquisition,seed,loop, jobid)\n", 155 | " start=time.time()\n", 156 | " gpc=Tanimoto_model(X_updated, y_updated)\n", 157 | " model_dir = output_dir+\"/models/model_{}\".format(iteration)\n", 158 | " if (iteration % step)==0:\n", 159 | " gpc.predict_y_compiled = tf.function(gpc.predict_f, input_signature=[tf.TensorSpec(shape=[None, fpdim], dtype=tf.float64)])\n", 160 | " tf.saved_model.save(gpc, model_dir)\n", 161 | " print('save model at path {}'.format(model_dir))\n", 162 | " conf_filename = write_config_file(jobid, jobname, reinvent_dir, reinvent_env, output_dir, fpdim, loop, iteration, model_dir, seed)\n", 163 | " write_sample_file(jobid, jobname, output_dir, loop, iteration)\n", 164 | " print('training model at iteration {} use {} s'.format(iteration, time.time()-start))\n", 165 | " return gpc" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | 
"metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "def get_labels(p):\n", 175 | " label=p.copy()\n", 176 | " label[label>0.5]=1\n", 177 | " label[label<0.5]=0\n", 178 | " label[label==0.5]=np.random.binomial(1, 0.5,sum(label==0.5))\n", 179 | " return label.astype(int)\n", 180 | "\n", 181 | "def evaluate(model,X,Y,accuracy,recall, precision, f1, MCC, Kappa, rmse):\n", 182 | " mean, var = model.predict_f(X)\n", 183 | " rmse+= [np.sqrt(np.mean((mean-Y)**2))]\n", 184 | " pred=get_labels(mean.numpy())\n", 185 | " Y=get_labels(Y)\n", 186 | " recall+=[recall_score(Y,pred)]\n", 187 | " accuracy += [accuracy_score(Y, pred)]\n", 188 | " precision+=[precision_score(Y,pred)]\n", 189 | " f1+=[f1_score(Y,pred)]\n", 190 | " MCC += [matthews_corrcoef(Y, pred)]\n", 191 | " Kappa += [cohen_kappa_score(Y, pred)]\n", 192 | " return accuracy, recall, precision, f1, MCC, Kappa, rmse" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "# HITL" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# load configuration \n", 209 | "jobid, jobname, N0, n_restarts, n_batch, n_iteration, fpdim, step, reinvent_dir, reinvent_env, output_dir, lastloop_dir = load_config(acquisition,seed, loop, last_job)\n", 210 | "\n", 211 | "# load training data configuration\n", 212 | "train_num, test_num, train_data_path, test_data_path, y_field = load_data_config(usecase,lastloop_dir)\n", 213 | "\n", 214 | "# load data\n", 215 | "X_train, X_test, y_train, y_test, smiles_train, id_train, X_L, X_U, y_L, y_U, L , U = load_data(train_data_path, test_data_path, output_dir, train_num, N0, y_field, id_field=None, ext='csv', sampling=True, normalization=False)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "X_updated=X_L\n", 225 | "y_updated=y_L\n", 226 | "# Fit model with D_0\n", 227 | "start=time.time()\n", 228 | "gpc=train(X_updated, y_updated, jobid, acquisition,seed, loop, 0)\n", 229 | "print(\"training spend {} s\".format(time.time()-start))\n", 230 | "\n", 231 | "# Evaluate performance in test data\n", 232 | "accuracy, recall, precision, f1, MCC , Kappa, rmse =[] ,[], [], [], [], [], []\n", 233 | "start=time.time()\n", 234 | "accuracy, recall, precision, f1, MCC , Kappa, rmse = evaluate(gpc, X_test, y_test, accuracy, recall, precision, f1, MCC, Kappa, rmse)\n", 235 | "print(\"evaluation spend {} s\".format(time.time()-start))\n", 236 | "if verbose:\n", 237 | " print('accuracy is {}'.format(accuracy))\n", 238 | " print('recall is {}'.format(recall))\n", 239 | " print('MCC is {}'.format(MCC))\n", 240 | " print('Kappa is {}'.format(Kappa))\n", 241 | " print('rmse is {}'.format(rmse))\n", 242 | "\n", 243 | "# Baseline prediction: mean of training data\n", 244 | "pred_baseline = np.repeat(1, len(y_test))\n", 245 | "label_test=get_labels(y_test)\n", 246 | "accuracy_baseline = accuracy_score(label_test, pred_baseline)\n", 247 | "recall_baseline=recall_score(label_test, pred_baseline)\n", 248 | "MCC_baseline = matthews_corrcoef(label_test, pred_baseline)\n", 249 | "Kappa_baseline = cohen_kappa_score(label_test, pred_baseline)\n", 250 | "\n", 251 | "if verbose: \n", 252 | " print(\"Baseline accuracy: {}\".format(accuracy_baseline))" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | 
"print(\"start iteration\")\n", 262 | "positive_molecules_num=np.array([np.sum(y_updated)],dtype=np.int)\n", 263 | "idx_query=np.array([],dtype='i')\n", 264 | "all_smiles_query=[]\n", 265 | "D0_smiles = smiles_train[L]\n", 266 | "\n", 267 | "for iteration in np.arange(1,n_iteration+1):\n", 268 | " # CREATE QUERY\n", 269 | " print('it: ' + str(iteration))\n", 270 | " queryfile_identifier = 'query_it{}'.format(iteration) \n", 271 | " if not READ_ONLY:\n", 272 | " # Elicit n_batch feedback\n", 273 | " t = iteration\n", 274 | " i_query = select_query(U, X_train, n_batch, gpc, acquisition, X_updated, y_updated, L, t)\n", 275 | " elif READ_ONLY:\n", 276 | " query_csv = pd.read_csv(output_dir + '/query/{}.csv'.format(queryfile_identifier))\n", 277 | " sq = query_csv['SMILES']\n", 278 | " i_query = np.array([list(smiles_train).index(s) for s in sq])\n", 279 | "\n", 280 | " idx_query=np.append(idx_query,i_query)\n", 281 | " U = np.setdiff1d(U,i_query)\n", 282 | " if verbose:\n", 283 | " print(\"Feedback on n={}\".format(i_query))\n", 284 | " print(\"Size of pool: {}\".format(len(U)))\n", 285 | "\n", 286 | " query_smiles =smiles_train[i_query]\n", 287 | " all_smiles_query += [query_smiles]\n", 288 | "\n", 289 | " # Write the query to a csv file\n", 290 | " if not READ_ONLY:\n", 291 | " write_query_to_csv(query_smiles, id_train, i_query, output_dir + '/query/{}.csv'.format(queryfile_identifier), output_dir)\n", 292 | "\n", 293 | " if not simulated_human:\n", 294 | " input(\"Press Enter to continue...\")\n", 295 | " # READ RESPONSE QUERY\n", 296 | " print(\"Read response at iteration {}\".format(iteration))\n", 297 | " # Response\n", 298 | " if simulated_human:\n", 299 | " y_response = y_train[i_query]\n", 300 | " unkown_idx=np.where(y_response==-1)[0] \n", 301 | " if(len(unkown_idx)>0): # If simulated feedback not pre-computed\n", 302 | " y_response[unkown_idx]=query(query_smiles[unkown_idx])\n", 303 | " #update y_train\n", 304 | " y_train[i_query[unkown_idx]]=y_response[unkown_idx]\n", 305 | " else:\n", 306 | " y_csv = pd.read_csv(output_dir + '/query/{}_with_ratings.csv'.format(queryfile_identifier)) \n", 307 | " smiles_response = list(y_csv['SMILES'])\n", 308 | " responses = y_csv['rating'] \n", 309 | " order = [smiles_response.index(smiles_train[i]) for i in i_query]\n", 310 | " y_response = np.array([responses[i] for i in order], dtype=float) # match responses using smiles\n", 311 | " y_response=y_response.reshape(len(y_response),1)\n", 312 | " # filter out molecules without feedback (rating=0)\n", 313 | " got_feedback = y_response != 0\n", 314 | " y_response = y_response[got_feedback]\n", 315 | " i_query = i_query[np.squeeze(got_feedback)]\n", 316 | " # parse numerical feedback from columns 1=0, 5=1 and the rest in between\n", 317 | " y_response = (y_response - 1)/4.0\n", 318 | " y_response=y_response.reshape(len(y_response),1)\n", 319 | " if verbose:\n", 320 | " print(\"response is\")\n", 321 | " print(y_response)\n", 322 | "\n", 323 | " # Fit a new model, evaluate performance in test data\n", 324 | " X_updated = np.vstack((X_updated, X_train[i_query,:]))\n", 325 | " y_updated = np.concatenate((y_updated, y_response))\n", 326 | " gpc=train(X_updated, y_updated,jobid, acquisition,seed,loop, iteration)\n", 327 | " positive_molecules_num=np.append(positive_molecules_num,np.sum(y_updated))\n", 328 | " accuracy, recall, precision, f1, MCC , Kappa, rmse = evaluate(gpc, X_test, y_test, accuracy, recall, precision, f1, MCC, Kappa, rmse)\n", 329 | "\n", 330 | " if verbose:\n", 331 | " 
print(accuracy[iteration])\n", 332 | " print(precision[iteration])\n", 333 | " print(f1[iteration])\n", 334 | " print(recall[iteration])\n", 335 | " print(MCC[iteration])\n", 336 | " print(Kappa[iteration])\n", 337 | " print(rmse[iteration])\n", 338 | "\n", 339 | "\n", 340 | "\n", 341 | "dat_save = {\n", 342 | " 'hitl params': {'N0': N0, 'T': iteration, 'n_batch': n_batch, 'step': step, 'acquisition': acquisition},\n", 343 | " 'accuracy': accuracy,\n", 344 | " 'recall': recall,\n", 345 | " 'precision':precision,\n", 346 | " 'f1':f1,\n", 347 | " 'MCC': MCC,\n", 348 | " 'Kappa':Kappa,\n", 349 | " 'rmse':rmse,\n", 350 | " 'baseline accuracy': accuracy_baseline,\n", 351 | " 'baseline_recall':recall_baseline,\n", 352 | " 'baseline_MCC':MCC_baseline,\n", 353 | " 'baseline_Kappa':Kappa_baseline,\n", 354 | " 'idx_query':idx_query,\n", 355 | " 'smiles_D0': D0_smiles,\n", 356 | " 'smiles_query':all_smiles_query,\n", 357 | " 'positive_molecules_num':positive_molecules_num\n", 358 | "}\n", 359 | "filename = output_dir + '/log_{}_it{}.p'.format(acquisition,iteration)\n", 360 | "with open(filename , 'wb') as f:\n", 361 | " pickle.dump(dat_save, f)\n", 362 | "\n", 363 | "if verbose:\n", 364 | " print('accuracy is {}'.format(accuracy))\n", 365 | " print('recall is {}'.format(recall))\n", 366 | " print('MCC is {}'.format(MCC))\n", 367 | " print('Kappa is {}'.format(Kappa))\n", 368 | " print('rmse is {}'.format(rmse))\n", 369 | "\n", 370 | "# Create shell scripts for evaluating performance of Chemist's component as Reinvent scoring function: \n", 371 | "# Script for reinforcement learning training of agent at each iteration\n", 372 | "write_runs_sh(seed, output_dir, reinvent_env, reinvent_dir, step, n_iteration)\n", 373 | "# Script for sampling molecules to evaluate the quality of Reinvent output once the agents are trained\n", 374 | "write_run_sample(seed, output_dir, reinvent_env, reinvent_dir, step, n_iteration)" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "# Visualize performance of Chemist's component as predictive model in the test data\n", 384 | "plt.plot(MCC)\n", 385 | "plt.xlabel('iteration')\n", 386 | "plt.ylabel('MCC')\n", 387 | "plt.title(\"Chemist's-component performance in the test set (drd2)\")" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "plt.plot(recall)\n", 397 | "plt.xlabel('iteration')\n", 398 | "plt.ylabel('recall')\n", 399 | "plt.title(\"Chemist's-component performance in the test set (drd2)\")" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "plt.plot(accuracy)\n", 409 | "plt.xlabel('iteration')\n", 410 | "plt.ylabel('accuracy')\n", 411 | "plt.title(\"Chemist's-component performance in the test set (drd2)\")" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "plt.plot(Kappa)\n", 421 | "plt.xlabel('iteration')\n", 422 | "plt.ylabel('Kappa')\n", 423 | "plt.title(\"Chemist's-component performance in the test set (drd2)\")" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [] 432 | } 433 | ], 434 | "metadata": { 435 | "kernelspec": { 436 | "display_name": "chemist-component", 437 | 
"language": "python", 438 | "name": "cc_env" 439 | }, 440 | "language_info": { 441 | "codemirror_mode": { 442 | "name": "ipython", 443 | "version": 3 444 | }, 445 | "file_extension": ".py", 446 | "mimetype": "text/x-python", 447 | "name": "python", 448 | "nbconvert_exporter": "python", 449 | "pygments_lexer": "ipython3", 450 | "version": "3.7.3" 451 | } 452 | }, 453 | "nbformat": 4, 454 | "nbformat_minor": 4 455 | } 456 | -------------------------------------------------------------------------------- /Chemists-Component/acquisition.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import norm 3 | import time 4 | 5 | 6 | 7 | def random_sampling(U, X_train, n, model, acquisition, X_updated, y_updated, L=None,t=None): 8 | ret = np.random.choice(U, n, replace=False) 9 | return ret 10 | 11 | def uncertainty_sampling(U, X_train, n, model, acquisition, X_updated, y_updated, L, t=None): 12 | X_pool = X_train[U,:] 13 | preds_p, var = model.predict_f(X_pool) 14 | idx = np.argsort(np.squeeze(var))[-n:] 15 | return U[idx] 16 | 17 | # exploitation 18 | def greedy_algorithm(U, X_train, n, model, acquisition, X_updated, y_updated, L=None, t=None): 19 | X_pool = X_train[U,:] 20 | preds_p,_ = model.predict_f(X_pool) 21 | preds_p=np.squeeze(preds_p) 22 | idx = np.argsort(preds_p)[::-1] 23 | assert len(idx)>0 24 | selected = idx[:n] 25 | return U[selected] 26 | 27 | 28 | def thompson_sampling(U, X_U, n, model, acquisition, X_updated, y_updated, L=None, t=None): 29 | X_pool =X_U[U,:] 30 | # a sample from posterior 31 | preds_p = model.predict_f_samples(X_pool) 32 | preds_p = np.squeeze(preds_p) 33 | # greedily maximize with respect to the randomly sampled belief 34 | idx = np.argsort(preds_p)[::-1] 35 | assert len(idx)>0 36 | selected = idx[:n] 37 | return U[selected] 38 | 39 | def select_query(U, X_train, n, model, acquisition='random', X_updated=None, y_updated=None, L=None, t=None): 40 | ''' 41 | Parameters 42 | ---------- 43 | U : Indices of X_train that are unlabeled. 44 | X_train : Original training data (labeled and unlabeled). Current pool of unlabeled samples: X_train[U,:] 45 | n : number of queries to select. 46 | model : Current prediction model. 47 | acquisition : acquisition method, optional. The default is 'random'. 48 | 49 | Returns 50 | ------- 51 | int idx: 52 | Index of the query; idx \in U. Features of the query: X_train[idx,:] 53 | 54 | ''' 55 | # select acquisition: 56 | if acquisition == 'uncertainty': 57 | acq = uncertainty_sampling 58 | elif acquisition == 'thompson': 59 | acq = thompson_sampling 60 | elif acquisition == 'random': 61 | acq = random_sampling 62 | elif acquisition == 'greedy': 63 | acq = greedy_algorithm 64 | else: 65 | print("Warning: Unknown acquisition function. 
Using random sampling.") 66 | acq = random_sampling 67 | return acq(U, X_train, n, model, acquisition, X_updated, y_updated, L, t) -------------------------------------------------------------------------------- /Chemists-Component/cc_env_hitl.yml: -------------------------------------------------------------------------------- 1 | name: cc_env_hitl 2 | channels: 3 | - rdkit 4 | - openeye 5 | - omnia 6 | - anaconda 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1 11 | - _openmp_mutex=4.5 12 | - _py-xgboost-mutex=2.0 13 | - _tflow_select=2.3.0 14 | - absl-py=0.8.1 15 | - aiohttp=3.6.2 16 | - asn1crypto=1.2.0 17 | - astor=0.7.1 18 | - async-timeout=3.0.1 19 | - atomicwrites=1.3.0 20 | - attrs=19.3.0 21 | - backcall=0.1.0 22 | - blas=2.14 23 | - bleach=3.1.0 24 | - bzip2=1.0.8 25 | - c-ares=1.15.0 26 | - ca-certificates=2021.4.13 27 | - cairo=1.16.0 28 | - certifi=2020.12.5 29 | - cffi=1.13.2 30 | - chardet=3.0.4 31 | - cryptography=2.8 32 | - cudatoolkit=10.1.243 33 | - cudnn=7.6.4 34 | - cycler=0.10.0 35 | - dbus=1.13.6 36 | - decorator=4.4.1 37 | - deepsmiles=1.0.1 38 | - defusedxml=0.6.0 39 | - dill=0.3.1.1 40 | - entrypoints=0.3 41 | - expat=2.2.5 42 | - ffmpeg=4.0 43 | - fontconfig=2.13.1 44 | - freeglut=3.0.0 45 | - freetype=2.10.0 46 | - future=0.18.2 47 | - gast=0.3.2 48 | - gettext=0.19.8.1 49 | - glib=2.58.3 50 | - google-pasta=0.1.8 51 | - graphite2=1.3.13 52 | - gst-plugins-base=1.14.5 53 | - gstreamer=1.14.5 54 | - harfbuzz=1.8.8 55 | - hdf5=1.10.2 56 | - icu=58.2 57 | - idna=2.8 58 | - importlib_metadata=0.23 59 | - intel-openmp=2019.4 60 | - ipykernel=5.1.3 61 | - ipython=7.9.0 62 | - ipython_genutils=0.2.0 63 | - ipywidgets=7.5.1 64 | - jasper=2.0.14 65 | - jedi=0.15.1 66 | - jinja2=2.10.3 67 | - joblib=0.14.0 68 | - jpeg=9c 69 | - jsonschema=3.2.0 70 | - jupyter=1.0.0 71 | - jupyter_client=5.3.3 72 | - jupyter_console=6.0.0 73 | - jupyter_core=4.6.1 74 | - keras-applications=1.0.8 75 | - kiwisolver=1.1.0 76 | - libblas=3.8.0 77 | - libboost=1.67.0 78 | - libcblas=3.8.0 79 | - libedit=3.1.20170329 80 | - libffi=3.2.1 81 | - libgcc-ng=9.1.0 82 | - libgfortran=3.0.0 83 | - libgfortran-ng=7.3.0 84 | - libglu=9.0.0 85 | - libiconv=1.15 86 | - liblapack=3.8.0 87 | - liblapacke=3.8.0 88 | - libopenblas=0.3.7 89 | - libopencv=3.4.2 90 | - libopus=1.3.1 91 | - libpng=1.6.37 92 | - libprotobuf=3.10.1 93 | - libsodium=1.0.17 94 | - libstdcxx-ng=9.1.0 95 | - libtiff=4.1.0 96 | - libuuid=2.32.1 97 | - libvpx=1.7.0 98 | - libwebp-base=1.1.0 99 | - libxcb=1.13 100 | - libxgboost=0.90 101 | - libxml2=2.9.9 102 | - llvm-openmp=10.0.1 103 | - lz4-c=1.8.3 104 | - markdown=3.1.1 105 | - markupsafe=1.1.1 106 | - mistune=0.8.4 107 | - mkl=2019.5 108 | - mkl-service=2.3.0 109 | - mkl_fft=1.0.15 110 | - mkl_random=1.1.0 111 | - more-itertools=7.2.0 112 | - multidict=4.7.5 113 | - multiprocess=0.70.9 114 | - nbconvert=5.6.1 115 | - nbformat=4.4.0 116 | - ncurses=6.1 117 | - nest-asyncio=1.3.3 118 | - ninja=1.9.0 119 | - notebook=6.0.1 120 | - numpy=1.17.3 121 | - numpy-base=1.17.3 122 | - olefile=0.46 123 | - opencv=3.4.2 124 | - openeye-toolkits=2019.10.2 125 | - openssl=1.1.1g 126 | - packaging=19.2 127 | - pandas=0.25.3 128 | - pandoc=2.7.3 129 | - pandocfilters=1.4.2 130 | - parso=0.5.1 131 | - pathos=0.2.5 132 | - pcre=8.43 133 | - pexpect=4.7.0 134 | - pickleshare=0.7.5 135 | - pillow=6.2.1 136 | - pip=19.3.1 137 | - pixman=0.38.0 138 | - pluggy=0.13.0 139 | - pox=0.2.7 140 | - ppft=1.6.6.1 141 | - prometheus_client=0.7.1 142 | - prompt_toolkit=2.0.10 143 | - 
protobuf=3.10.1 144 | - psutil=5.6.7 145 | - pthread-stubs=0.4 146 | - ptyprocess=0.6.0 147 | - py=1.8.0 148 | - py-boost=1.67.0 149 | - py-opencv=3.4.2 150 | - py-xgboost=0.90 151 | - pycparser=2.19 152 | - pygments=2.4.2 153 | - pyopenssl=19.0.0 154 | - pyparsing=2.4.5 155 | - pyqt=5.9.2 156 | - pyrsistent=0.15.5 157 | - pysocks=1.7.1 158 | - pytest=5.3.0 159 | - python=3.7.3 160 | - python-dateutil=2.8.1 161 | - python_abi=3.7 162 | - pytz=2019.3 163 | - pyzmq=18.1.1 164 | - qt=5.9.7 165 | - qtconsole=4.6.0 166 | - rdkit=2019.09.1.0 167 | - readline=8.0 168 | - requests=2.22.0 169 | - scikit-learn=0.21.2 170 | - scipy=1.3.2 171 | - send2trash=1.5.0 172 | - setuptools=41.6.0 173 | - sip=4.19.8 174 | - six=1.13.0 175 | - sqlite=3.30.1 176 | - termcolor=1.1.0 177 | - terminado=0.8.3 178 | - testpath=0.4.4 179 | - tk=8.6.9 180 | - tornado=6.0.3 181 | - tqdm=4.38.0 182 | - traitlets=4.3.3 183 | - typing=3.6.4 184 | - urllib3=1.25.7 185 | - wcwidth=0.1.7 186 | - webencodings=0.5.1 187 | - werkzeug=0.16.0 188 | - wheel=0.33.6 189 | - widgetsnbextension=3.5.1 190 | - wrapt=1.11.2 191 | - xgboost=0.90 192 | - xorg-fixesproto=5.0 193 | - xorg-inputproto=2.3.2 194 | - xorg-kbproto=1.0.7 195 | - xorg-libice=1.0.10 196 | - xorg-libsm=1.2.3 197 | - xorg-libx11=1.6.9 198 | - xorg-libxau=1.0.9 199 | - xorg-libxdmcp=1.1.3 200 | - xorg-libxext=1.3.4 201 | - xorg-libxfixes=5.0.3 202 | - xorg-libxi=1.7.10 203 | - xorg-libxrender=0.9.10 204 | - xorg-renderproto=0.11.1 205 | - xorg-xextproto=7.3.0 206 | - xorg-xproto=7.0.31 207 | - xz=5.2.4 208 | - yarl=1.4.2 209 | - zeromq=4.3.2 210 | - zipp=0.6.0 211 | - zlib=1.2.11 212 | - zstd=1.4.4 213 | - pip: 214 | - astunparse==1.6.3 215 | - cached-property==1.5.2 216 | - cachetools==4.2.4 217 | - cloudpickle==2.0.0 218 | - cython==0.29.21 219 | - dacite==1.5.1 220 | - dataclasses==0.6 221 | - deprecated==1.2.13 222 | - dm-tree==0.1.6 223 | - flatbuffers==2.0 224 | - google-auth==2.3.3 225 | - google-auth-oauthlib==0.4.6 226 | - gpflow==2.3.1 227 | - grpcio==1.43.0 228 | - h5py==3.6.0 229 | - keras==2.7.0 230 | - keras-preprocessing==1.1.2 231 | - libclang==12.0.0 232 | - multipledispatch==0.6.0 233 | - oauthlib==3.1.1 234 | - opt-einsum==3.3.0 235 | - pyasn1==0.4.8 236 | - pyasn1-modules==0.2.8 237 | - pydantic==1.8.2 238 | - reinvent-chemistry==0.0.51 239 | - reinvent-models==0.0.15rc1 240 | - reinvent-scoring==0.0.73 241 | - requests-oauthlib==1.3.0 242 | - rsa==4.8 243 | - seaborn==0.11.1 244 | - tabulate==0.8.9 245 | - tensorboard==2.8.0 246 | - tensorboard-data-server==0.6.1 247 | - tensorboard-plugin-wit==1.8.1 248 | - tensorflow==2.7.0 249 | - tensorflow-estimator==2.7.0 250 | - tensorflow-io-gcs-filesystem==0.23.1 251 | - tensorflow-probability==0.15.0 252 | - torch==1.7.1 253 | - typing-extensions==4.0.1 254 | -------------------------------------------------------------------------------- /Chemists-Component/cc_env_reinvent.yml: -------------------------------------------------------------------------------- 1 | name: cc_env_reinvent 2 | channels: 3 | - rdkit 4 | - pytorch 5 | - openeye 6 | - conda-forge 7 | - anaconda 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1 11 | - _openmp_mutex=4.5 12 | - _pytorch_select=0.1 13 | - _tflow_select=2.3.0 14 | - aiohttp=3.6.2 15 | - argon2-cffi=20.1.0 16 | - astunparse=1.6.3 17 | - async-timeout=3.0.1 18 | - async_generator=1.10 19 | - attrs=19.3.0 20 | - backcall=0.2.0 21 | - blas=1.0 22 | - bleach=3.2.1 23 | - blinker=1.4 24 | - brotlipy=0.7.0 25 | - bzip2=1.0.8 26 | - c-ares=1.17.1 27 | - 
ca-certificates=2021.4.13 28 | - cachetools=4.2.2 29 | - cairo=1.14.12 30 | - certifi=2020.12.5 31 | - cffi=1.14.3 32 | - chardet=3.0.4 33 | - click=8.0.0 34 | - cmarkgfm=0.4.2 35 | - colorama=0.4.4 36 | - coverage=5.5 37 | - cryptography=3.2.1 38 | - cudatoolkit=10.2.89 39 | - cycler=0.10.0 40 | - cython=0.29.23 41 | - dacite=1.5.1 42 | - dbus=1.13.12 43 | - decorator=5.0.7 44 | - defusedxml=0.7.1 45 | - dill=0.3.1.1 46 | - docutils=0.16 47 | - entrypoints=0.3 48 | - expat=2.2.9 49 | - fontconfig=2.13.0 50 | - freetype=2.9.1 51 | - glib=2.63.1 52 | - google-auth=1.28.0 53 | - google-auth-oauthlib=0.4.1 54 | - google-pasta=0.2.0 55 | - gst-plugins-base=1.14.0 56 | - gstreamer=1.14.0 57 | - h5py=2.10.0 58 | - hdf5=1.10.6 59 | - icu=58.2 60 | - idna=2.10 61 | - importlib-metadata=3.10.0 62 | - importlib_metadata=1.5.0 63 | - intel-openmp=2020.0 64 | - ipykernel=5.3.4 65 | - ipython=7.22.0 66 | - ipython_genutils=0.2.0 67 | - ipywidgets=7.6.3 68 | - jedi=0.17.0 69 | - jeepney=0.6.0 70 | - jinja2=2.11.3 71 | - joblib=0.15.1 72 | - jpeg=9b 73 | - jsonschema=3.2.0 74 | - jupyter=1.0.0 75 | - jupyter_client=6.1.12 76 | - jupyter_console=6.4.0 77 | - jupyter_core=4.7.1 78 | - jupyterlab_pygments=0.1.2 79 | - jupyterlab_widgets=1.0.0 80 | - keras-preprocessing=1.1.2 81 | - keyring=21.5.0 82 | - kiwisolver=1.1.0 83 | - ld_impl_linux-64=2.33.1 84 | - libblas=3.8.0 85 | - libboost=1.67.0 86 | - libcblas=3.8.0 87 | - libedit=3.1.20181209 88 | - libffi=3.2.1 89 | - libgcc-ng=9.1.0 90 | - libgfortran-ng=7.3.0 91 | - libgomp=9.3.0 92 | - libiconv=1.16 93 | - liblapack=3.8.0 94 | - libpng=1.6.37 95 | - libprotobuf=3.14.0 96 | - libsodium=1.0.18 97 | - libstdcxx-ng=9.1.0 98 | - libtiff=4.1.0 99 | - libuuid=1.0.3 100 | - libuv=1.40.0 101 | - libxcb=1.13 102 | - libxml2=2.9.9 103 | - libzlib=1.2.11 104 | - llvm-openmp=14.0.4 105 | - lz4=3.1.1 106 | - lz4-c=1.9.2 107 | - markupsafe=1.1.1 108 | - matplotlib=3.1.3 109 | - matplotlib-base=3.1.3 110 | - mistune=0.8.4 111 | - mkl=2020.0 112 | - mkl-service=2.3.0 113 | - mkl_fft=1.0.15 114 | - mkl_random=1.1.0 115 | - more-itertools=8.2.0 116 | - multidict=4.7.3 117 | - multiprocess=0.70.9 118 | - nbclient=0.5.3 119 | - nbconvert=6.0.7 120 | - nbformat=5.1.3 121 | - ncurses=6.2 122 | - nest-asyncio=1.5.1 123 | - ninja=1.9.0 124 | - notebook=6.3.0 125 | - numpy=1.18.1 126 | - oauthlib=3.1.0 127 | - olefile=0.46 128 | - openeye-toolkits=2020.2.2 129 | - openssl=1.1.1k 130 | - packaging=20.3 131 | - pandas=1.0.3 132 | - pandoc=2.12 133 | - pandocfilters=1.4.3 134 | - parso=0.8.2 135 | - pathos=0.2.5 136 | - pcre=8.43 137 | - pexpect=4.8.0 138 | - pickleshare=0.7.5 139 | - pillow=7.0.0 140 | - pip=20.0.2 141 | - pixman=0.38.0 142 | - pkginfo=1.6.1 143 | - pluggy=0.13.1 144 | - pox=0.2.7 145 | - ppft=1.6.6.1 146 | - prometheus_client=0.10.1 147 | - prompt-toolkit=3.0.17 148 | - prompt_toolkit=3.0.17 149 | - pthread-stubs=0.4 150 | - ptyprocess=0.7.0 151 | - py=1.8.1 152 | - py-boost=1.67.0 153 | - py4j=0.10.7 154 | - pyasn1=0.4.8 155 | - pyasn1-modules=0.2.8 156 | - pycparser=2.20 157 | - pydantic=1.8.2 158 | - pygments=2.7.2 159 | - pyjwt=1.7.1 160 | - pyopenssl=19.1.0 161 | - pyparsing=2.4.6 162 | - pyqt=5.9.2 163 | - pyrsistent=0.17.3 164 | - pysocks=1.7.1 165 | - pytest=5.4.1 166 | - python=3.7.7 167 | - python-dateutil=2.8.1 168 | - python_abi=3.7 169 | - pytorch=1.7.1 170 | - pytz=2019.3 171 | - pyzmq=20.0.0 172 | - qt=5.9.7 173 | - qtconsole=5.0.3 174 | - qtpy=1.9.0 175 | - rdkit=2020.03.3.0 176 | - readline=8.0 177 | - readme_renderer=27.0 178 | - requests=2.25.0 
179 | - requests-oauthlib=1.3.0 180 | - requests-toolbelt=0.9.1 181 | - rfc3986=1.4.0 182 | - rsa=4.7.2 183 | - scikit-learn=0.21.3 184 | - scipy=1.4.1 185 | - secretstorage=3.2.0 186 | - send2trash=1.5.0 187 | - setuptools=46.1.1 188 | - sip=4.19.8 189 | - six=1.14.0 190 | - sqlite=3.31.1 191 | - tensorboard-plugin-wit=1.6.0 192 | - termcolor=1.1.0 193 | - terminado=0.9.4 194 | - testpath=0.4.4 195 | - tk=8.6.8 196 | - tornado=6.0.4 197 | - tqdm=4.43.0 198 | - traitlets=5.0.5 199 | - twine=3.2.0 200 | - typing-extensions=3.7.4.3 201 | - typing_extensions=3.7.4.3 202 | - urllib3=1.25.11 203 | - wcwidth=0.1.9 204 | - webencodings=0.5.1 205 | - wheel=0.34.2 206 | - widgetsnbextension=3.5.1 207 | - wrapt=1.12.1 208 | - xorg-libxau=1.0.9 209 | - xorg-libxdmcp=1.1.3 210 | - xz=5.2.4 211 | - yarl=1.4.2 212 | - zeromq=4.3.4 213 | - zipp=2.2.0 214 | - zlib=1.2.11 215 | - zstd=1.3.7 216 | - pip: 217 | - absl-py==1.1.0 218 | - astor==0.8.1 219 | - cloudpickle==2.1.0 220 | - dataclasses==0.6 221 | - deprecated==1.2.13 222 | - dm-tree==0.1.7 223 | - flatbuffers==1.12 224 | - future==0.18.2 225 | - gast==0.3.2 226 | - gpflow==2.3.1 227 | - grpcio==1.27.2 228 | - keras==2.7.0 229 | - keras-applications==1.0.8 230 | - libclang==14.0.1 231 | - markdown==3.2.1 232 | - multipledispatch==0.6.0 233 | - opt-einsum==3.2.0 234 | - protobuf==3.11.3 235 | - reinvent-chemistry==0.0.51 236 | - reinvent-models==0.0.15rc1 237 | - tabulate==0.8.9 238 | - tensorboard==2.9.1 239 | - tensorboard-data-server==0.6.1 240 | - tensorflow==2.7.0 241 | - tensorflow-estimator==2.7.0 242 | - tensorflow-io-gcs-filesystem==0.26.0 243 | - tensorflow-probability==0.17.0 244 | - werkzeug==2.1.2 245 | -------------------------------------------------------------------------------- /Chemists-Component/data/drd2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/reinvent-hitl/784c11c4be402bfc52f2bb0ba2eca6ba1e1b6c86/Chemists-Component/data/drd2.pkl -------------------------------------------------------------------------------- /Chemists-Component/evaluate_results_Task2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# This is a notebook for evaluating and plotting the results of Task 2 experiments. 
\n", 10 | "# To run an experiment, use:\n", 11 | "# Task2_Chemists_component.ipynb and\n", 12 | "# run Reinvent 3.2 with config_tX.json and config_tX_sampling.json for X=0,...,T in the result folder (use runs.sh and run_sampling.sh in a computation cluster)\n", 13 | "\n", 14 | "import os\n", 15 | "import pickle\n", 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "import fnmatch\n", 19 | "import reinvent_scoring\n", 20 | "from multiprocessing import Pool\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "%matplotlib inline\n", 23 | "\n", 24 | "\n", 25 | "jobname = ['Task2_demo_thompson']\n", 26 | "result_dir ='./results/'\n", 27 | "save_result_dir = './'\n", 28 | "\n", 29 | "acquisitions = ['random', 'greedy', 'thompson', 'uncertainty']\n", 30 | "# options for plotting\n", 31 | "method_colors = {'uncertainty': 'C1',\n", 32 | " 'random': 'C2',\n", 33 | " 'greedy': 'C3',\n", 34 | " 'thompson': 'C0'}\n", 35 | "method_names = {'uncertainty': 'Uncertainty sampling',\n", 36 | " 'random': 'Random sampling',\n", 37 | " 'greedy': 'Pure exploitation',\n", 38 | " 'thompson': 'Thompson sampling'} " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# Oracle model:\n", 48 | "expert_model_path = './data/drd2.pkl'\n", 49 | "qsar_model = {\n", 50 | " \"component_type\": \"predictive_property\",\n", 51 | " \"name\": \"DRD2\",\n", 52 | " \"weight\": 1,\n", 53 | " \"specific_parameters\": {\n", 54 | " \"model_path\": expert_model_path,\n", 55 | " \"scikit\": \"classification\",\n", 56 | " \"descriptor_type\": \"ecfp\",\n", 57 | " \"size\": 2048,\n", 58 | " \"radius\": 3,\n", 59 | " \"use_counts\": True,\n", 60 | " \"use_features\": True,\n", 61 | " \"transformation\": {\n", 62 | " \"transformation_type\": \"no_transformation\"\n", 63 | " }\n", 64 | " }\n", 65 | " }\n", 66 | "\n", 67 | "scoring_function = {\n", 68 | " \"name\": \"custom_sum\",\n", 69 | " \"parallel\": False,\n", 70 | " \"parameters\": [\n", 71 | " qsar_model\n", 72 | " ]\n", 73 | "}\n", 74 | "scoring_function_parameters = reinvent_scoring.scoring.ScoringFunctionParameters(**scoring_function)\n", 75 | "expert_scoring_function = reinvent_scoring.scoring.ScoringFunctionFactory(scoring_function_parameters)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "def score_sampled_smiles(file):\n", 85 | " data = pd.read_csv(file).to_numpy()\n", 86 | " data =np.squeeze(data)\n", 87 | " summary = expert_scoring_function.get_final_score(data)\n", 88 | " return summary.total_score" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "result_dir_per_job={}\n", 98 | "print('reading {}'.format(result_dir))\n", 99 | "for fname in os.listdir(result_dir):\n", 100 | " for job in jobname:\n", 101 | " if fnmatch.fnmatch(fname,job+'*'):\n", 102 | " #take care the position of acquisition method in file name\n", 103 | " print(job)\n", 104 | " acquisition=job.split('_')[2]\n", 105 | " if acquisition in result_dir_per_job.keys():\n", 106 | " result_dir_per_job[acquisition].append(result_dir+fname+'/loop0/')\n", 107 | " else:\n", 108 | " result_dir_per_job[acquisition]=[result_dir+fname+'/loop0/']\n", 109 | "\n", 110 | "for key in result_dir_per_job.keys():\n", 111 | " assert key in acquisitions, \"take care the position of acquisition method in file name\"" 112 | ] 113 | }, 114 | { 115 | 
"cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "def parallel_processing(params):\n", 121 | " res={}\n", 122 | " key,folders=params[0],params[1]\n", 123 | " acquisition=key\n", 124 | " res[acquisition]={}\n", 125 | " for folder in folders:\n", 126 | " print(\"processing \" +folder)\n", 127 | " files=os.listdir(folder)\n", 128 | " for filename in files:\n", 129 | " if filename.endswith(\".p\"):\n", 130 | " dat_save = pickle.load(open(folder + filename, \"rb\" ))\n", 131 | " # collect results\n", 132 | " for key, value in dat_save.items():\n", 133 | " if key in res[acquisition]:\n", 134 | " #assume experiment should have the same 'hitl params'\n", 135 | " if 'baseline' in key:\n", 136 | " res[acquisition][key]=np.append(res[acquisition][key],value)\n", 137 | " elif key!='idx_query' and key!='hitl params':\n", 138 | " res[acquisition][key]=np.vstack((res[acquisition][key],value))\n", 139 | " else:\n", 140 | " if key=='hitl params':\n", 141 | " for k, v in value.items():\n", 142 | " res[acquisition][k]=v\n", 143 | " elif 'baseline' in key:\n", 144 | " res[acquisition][key]=np.array([value])\n", 145 | " elif key!='idx_query':\n", 146 | " res[acquisition][key] = np.array(value)\n", 147 | " res[acquisition]['Expert score REINVENT output'] = []\n", 148 | " sc=[]\n", 149 | " for it in np.arange(0,res[acquisition]['T']+1,res[acquisition]['step']):\n", 150 | " sampled_smiles = folder + 'results_t{}/sampled.csv'.format(it)\n", 151 | " print('processing ' + sampled_smiles)\n", 152 | " try:\n", 153 | " sc += [score_sampled_smiles(sampled_smiles)]\n", 154 | " except FileNotFoundError:\n", 155 | " print(\"MISSING RESULT {}, please sample molecules from your agent\".format(sampled_smiles)) \n", 156 | " sc += [np.nan]\n", 157 | " res[acquisition]['Expert score REINVENT output'] += [sc]\n", 158 | "\n", 159 | " return res" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "params=[[key, folders] for key, folders in result_dir_per_job.items() ]\n", 169 | "with Pool(len(jobname)) as p:\n", 170 | " mapped_pool=p.map(parallel_processing,params)\n", 171 | "res={}\n", 172 | "for i in range(len(mapped_pool)):\n", 173 | " res.update(mapped_pool[i])" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "def shadedplot(x, y, fill=True, label='', color=''):\n", 183 | " p = plt.plot(x, y[0,:], label=label, color=color)\n", 184 | " c = p[-1].get_color()\n", 185 | " if fill:\n", 186 | " plt.fill_between(x, y[1,:], y[2,:], color=c, alpha=0.25)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "# Plot Expert score in REINVENT output\n", 196 | "for acquisition in res.keys():\n", 197 | " print(acquisition)\n", 198 | " if acquisition not in acquisitions:\n", 199 | " continue\n", 200 | " rs = res[acquisition]['Expert score REINVENT output']\n", 201 | " r = np.array([[np.nanmean(scores) for scores in rs[i]] for i in np.arange(len(rs))])\n", 202 | " m = np.nanmean(r, axis=0)\n", 203 | " sd = np.nanstd(r, axis=0)/np.sqrt(r.shape[0]) #SEM\n", 204 | " x = np.arange(len(m))\n", 205 | " N0=res[acquisition]['N0']\n", 206 | " x[0] = N0\n", 207 | " step=res[acquisition]['step']\n", 208 | " n_batch=res[acquisition]['n_batch']\n", 209 | " for i in np.arange(1,len(x)):\n", 210 | 
" x[i]=x[i-1]+n_batch*step\n", 211 | " if(r.shape[0]>1): # If multiple repetitions\n", 212 | " shadedplot(x, np.array([m, m-sd,m+sd]), fill=True,label=method_names[acquisition], color=method_colors[acquisition])\n", 213 | " for line in np.arange(r.shape[0]):\n", 214 | " p = np.isnan(r[line,:]) # remove missing iterations\n", 215 | " p = [not b for b in p]\n", 216 | " plt.plot(x[p], r[line,p], color=method_colors[acquisition], label=method_names[acquisition])\n", 217 | "\n", 218 | "plt.legend()\n", 219 | "plt.xlabel('Number of queries to the simulated chemist')\n", 220 | "plt.ylabel('Average oracle score')\n", 221 | "plt.ylim([0,1])\n", 222 | "#plt.savefig(save_result_dir + \"{}_Expert_score_REINVENT_output.png\".format(jobname[0]),bbox_inches='tight')" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "chemist-component.v2", 236 | "language": "python", 237 | "name": "cc_env_v2" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.7.7" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 4 254 | } 255 | -------------------------------------------------------------------------------- /Chemists-Component/kernels.py: -------------------------------------------------------------------------------- 1 | import gpflow 2 | from gpflow.utilities import positive 3 | from gpflow.utilities.ops import broadcasting_elementwise 4 | import tensorflow as tf 5 | 6 | # from https://github.com/Ryan-Rhys/FlowMO/blob/master/GP/kernels.py 7 | 8 | class Tanimoto(gpflow.kernels.Kernel): 9 | def __init__(self): 10 | super().__init__() 11 | # We constrain the value of the kernel variance to be positive when it's being optimised 12 | self.variance = gpflow.Parameter(1.0, transform=positive()) 13 | 14 | def K(self, X, X2=None): 15 | """ 16 | Compute the Tanimoto kernel matrix σ² * (() / (||x||^2 + ||y||^2 - )) 17 | :param X: N x D array 18 | :param X2: M x D array. If None, compute the N x N kernel matrix for X. 
19 | :return: The kernel matrix of dimension N x M 20 | """ 21 | if X2 is None: 22 | X2 = X 23 | 24 | Xs = tf.reduce_sum(tf.square(X), axis=-1) # Squared L2-norm of X 25 | X2s = tf.reduce_sum(tf.square(X2), axis=-1) # Squared L2-norm of X2 26 | outer_product = tf.tensordot(X, X2, [[-1], [-1]]) # outer product of the matrices X and X2 27 | # Analogue of denominator in Tanimoto formula 28 | 29 | denominator = -outer_product + broadcasting_elementwise(tf.add, Xs, X2s) 30 | 31 | return self.variance * outer_product/denominator 32 | 33 | def K_diag(self, X): 34 | """ 35 | Compute the diagonal of the N x N kernel matrix of X 36 | :param X: N x D array 37 | :return: N x 1 array 38 | """ 39 | return tf.fill(tf.shape(X)[:-1], tf.squeeze(self.variance)) -------------------------------------------------------------------------------- /Chemists-Component/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from rdkit.Chem import AllChem as Chem 3 | from multiprocessing import Pool 4 | import pandas as pd 5 | import numpy as np 6 | import pickle 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.preprocessing import StandardScaler 9 | from molecular_features import descriptorsX, morganX 10 | import time 11 | class Dataset2D: 12 | def __init__(self, file, y_field=None, id_field=None, ext='sdf'): 13 | self.smiles = [] 14 | self.moles = [] 15 | self.Y = [] if y_field is not None else None 16 | self.id = [] 17 | temp_id = 1 18 | if ext == 'sdf': 19 | suppl = Chem.SDMolSupplier(file, strictParsing=False) 20 | for i in suppl: 21 | if i is None: 22 | continue 23 | smi = Chem.MolToSmiles(i, isomericSmiles=False) 24 | if smi is not None and smi != '': 25 | self.smiles.append(smi) 26 | self.moles.append(i) 27 | if y_field is not None: 28 | self.Y.append(i.GetProp(y_field)) 29 | if id_field is not None: 30 | self.id.append(i.GetProp(id_field)) 31 | else: 32 | self.id.append('id{:0>5}'.format(temp_id)) 33 | temp_id += 1 34 | 35 | elif ext == 'csv': 36 | # df = pd.read_csv(file) 37 | df=file 38 | try: 39 | df['moles'] = df['SMILES'].apply(lambda x: Chem.MolFromSmiles(x)) 40 | except KeyError: 41 | df['SMILES'] = df['canonical'] 42 | df['moles'] = df['SMILES'].apply(lambda x: Chem.MolFromSmiles(x)) 43 | df = df.dropna() 44 | self.smiles = df['SMILES'].tolist() 45 | self.moles = df['moles'].tolist() 46 | self.Y = df[y_field].tolist() if y_field is not None else None 47 | self.id = df[id_field].tolist() if id_field is not None else np.arange(len(self.smiles)) 48 | 49 | else: 50 | raise ValueError('file extension not supported!') 51 | 52 | 53 | assert(len(self.smiles) == len(self.moles) == len(self.id)) 54 | if self.Y is not None: 55 | assert(len(self.smiles) == len(self.Y)) 56 | self.Y = np.array(self.Y) 57 | 58 | 59 | def __getitem__(self, index): 60 | if self.Y is not None: 61 | ret = self.id[index], self.smiles[index], self.moles[index], self.Y[index] 62 | else: 63 | ret = self.id[index], self.smiles[index], self.moles[index] 64 | return ret 65 | 66 | def __len__(self): 67 | return len(self.smiles) 68 | 69 | def __add__(self, other): 70 | pass 71 | 72 | 73 | class DataStructure: 74 | def __init__(self, dataset, feat_fn, y_transforms=None, num_proc=1): 75 | self.dataset = dataset 76 | self.feat_fn = feat_fn 77 | self.Y = dataset.Y 78 | self.id = dataset.id 79 | self.num_proc = num_proc 80 | self.feat_names = [] 81 | self.name_to_idx = {} 82 | 83 | x_s = [] 84 | for fname in self.feat_fn.keys(): 85 | f = self.feat_fn[fname] 86 | with 
Pool(self.num_proc) as p: 87 | arr = np.array(p.map(f, self.dataset.moles)) 88 | x_s.append(arr) 89 | length = arr.shape[1] 90 | names = list('{}_{}'.format(fname, x+1) for x in range(length)) 91 | self.feat_names += names 92 | x_s = tuple(x_s) 93 | self.X_ = np.concatenate(x_s, axis=1) 94 | 95 | # remove any nan rows 96 | nans = np.isnan(self.X_) 97 | mask = np.any(nans, axis=1) 98 | self.X_ = self.X_[~mask, :] 99 | self.name_to_idx = dict(zip(self.feat_names, range(len(self.feat_names)))) 100 | self.id = list(self.id[j] for j in range(len(mask)) if not mask[j]) 101 | if self.Y is not None: 102 | self.Y = self.Y[~mask] 103 | if y_transforms is not None: 104 | for t in y_transforms: 105 | self.Y = np.array(list(map(t, self.Y))) 106 | 107 | def __len__(self): 108 | return self.X_.shape[0] 109 | 110 | @property 111 | def shape(self): 112 | return self.X_.shape 113 | 114 | def X(self, feats=None): 115 | ''' 116 | Use a list of feature names to select feature columns 117 | ''' 118 | if feats is None: 119 | return self.X_ 120 | else: 121 | mask = list(map(lambda x: self.name_to_idx[x], feats)) 122 | return self.X_[:, mask] 123 | 124 | 125 | 126 | 127 | def load_data(train_data_path, test_data_path, output_dir, train_num, N0, y_field, id_field=None, ext='csv', sampling=True, normalization=False): 128 | print(train_data_path) 129 | train_data=pd.read_csv(train_data_path) 130 | if test_data_path: 131 | test_data=pd.read_csv(test_data_path) 132 | train_data['class']=(train_data[y_field]>=0.5).astype(int) 133 | 134 | if sampling: 135 | print('sampling') 136 | # sample train_num samples from train_data 137 | train_data, _ =train_test_split(train_data, train_size=train_num/train_data.shape[0],stratify=train_data['class']) 138 | train_data=train_data.reset_index(drop=True) 139 | # sample test_num samples from test_data 140 | # _, test_data=train_test_split(test_data, test_size=test_num/test_data.shape[0],stratify=test_data['activity']) 141 | 142 | if test_data_path is None: 143 | print('splitting') 144 | train_data, test_data =train_test_split(train_data, test_size=0.2,stratify=train_data[y_field]) 145 | 146 | L=train_data.sample(n=int(N0)).index 147 | #print('training data have {} active'.format(train_data.loc[train_data[y_field]>=0.5].shape)) 148 | #print('training data have {} inactive'.format(train_data.loc[train_data[y_field]<0.5].shape)) 149 | #print('testing data have {} active'.format(test_data.loc[test_data[y_field]>=0.5].shape)) 150 | #print('testing data have {} inactive'.format(test_data.loc[test_data[y_field]<0.5].shape)) 151 | train_ds=Dataset2D(train_data, y_field=y_field, id_field=id_field, ext=ext) 152 | test_ds=Dataset2D(test_data, y_field=y_field, id_field=id_field, ext=ext) 153 | start=time.time() 154 | train_str = DataStructure(train_ds, dict(physchem=morganX), num_proc=8) 155 | test_str = DataStructure(test_ds, dict(physchem=morganX), num_proc=8) 156 | print("feature transformation took {} s".format(time.time()-start)) 157 | 158 | # X contains features 159 | X_train = train_str.X() 160 | y_train = train_str.Y 161 | y_train=y_train.reshape(-1, 1) 162 | X_test = test_str.X() 163 | y_test = test_str.Y 164 | y_test=y_test.reshape(-1,1) 165 | 166 | smiles_train= np.array([ item[1] for item in train_str.dataset]) 167 | id_train= np.array([item[0] for item in train_str.dataset]) 168 | 169 | if normalization: 170 | print('Normalizing...') 171 | # Normalization, to speed up the training process 172 | scaler = StandardScaler() 173 | X_train=scaler.fit_transform(X_train) 174 | X_test=scaler.transform(X_test) 175 | 
modelfile=output_dir+"/models/scaler.pkl" 176 | try: 177 | os.mkdir(output_dir + '/models') 178 | except FileExistsError: 179 | pass 180 | with open(modelfile, "wb+") as f: 181 | pickle.dump(scaler, f) 182 | U=np.setdiff1d(np.arange(X_train.shape[0]),L) 183 | X_L, X_U, y_L, y_U=X_train[L], X_train[U], y_train[L], y_train[U] 184 | return X_train, X_test, y_train, y_test, smiles_train, id_train, X_L, X_U, y_L, y_U, L , U -------------------------------------------------------------------------------- /Chemists-Component/models.py: -------------------------------------------------------------------------------- 1 | import gpflow 2 | from kernels import Tanimoto 3 | 4 | 5 | def Tanimoto_model(X,Y): 6 | ''' 7 | Gaussian process regression model with Tanimoto kernel 8 | X: (n_samples, n_features) 9 | Y: (n_samples,) , binary classification 10 | ''' 11 | m = gpflow.models.GPR((X, Y), kernel=Tanimoto(), mean_function=None, noise_variance=1) 12 | 13 | opt = gpflow.optimizers.Scipy() 14 | opt.minimize(m.training_loss, variables=m.trainable_variables) 15 | 16 | return m -------------------------------------------------------------------------------- /Chemists-Component/molecular_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit.Chem import DataStructs,Descriptors 3 | from rdkit.Chem import AllChem as Chem 4 | from rdkit.Chem import MACCSkeys 5 | from rdkit.Avalon import pyAvalonTools 6 | 7 | #morgan fingerprints 8 | def morganX(mol, bits=1024, radius=3): 9 | morgan = np.zeros((1, bits)) 10 | fp = Chem.GetMorganFingerprintAsBitVect(mol, radius, nBits=bits) 11 | DataStructs.ConvertToNumpyArray(fp, morgan) 12 | return morgan 13 | 14 | #MACCS 15 | def maccsX(mol): 16 | maccs = np.zeros((1,167)) 17 | fp = MACCSkeys.GenMACCSKeys(mol) 18 | DataStructs.ConvertToNumpyArray(fp, maccs) 19 | return maccs 20 | 21 | #Avalon fps 22 | def avalonX(mol, avbits=512): 23 | avalon = np.zeros((1, avbits)) 24 | fp = pyAvalonTools.GetAvalonFP(mol) 25 | DataStructs.ConvertToNumpyArray(fp, avalon) 26 | return avalon 27 | 28 | # physchem descriptors 29 | def descriptorsX(m): 30 | descr = [Descriptors.ExactMolWt(m), 31 | Descriptors.MolLogP(m), 32 | Descriptors.TPSA(m), 33 | Descriptors.NumHAcceptors(m), 34 | Descriptors.NumHDonors(m), 35 | Descriptors.NumRotatableBonds(m), 36 | Descriptors.NumHeteroatoms(m), 37 | Descriptors.NumAromaticRings(m), 38 | Descriptors.FractionCSP3(m)] 39 | return np.array(descr) -------------------------------------------------------------------------------- /Chemists-Component/query.py: -------------------------------------------------------------------------------- 1 | import reinvent_scoring 2 | 3 | import numpy as np 4 | 5 | def query(smiles): 6 | expert_model_path = './data/drd2.pkl' 7 | qsar_model ={ "component_type": "predictive_property", 8 | "name": "DRD2", 9 | "weight": 1, 10 | "model_path": expert_model_path, 11 | "smiles": [], 12 | "specific_parameters": { 13 | "transformation_type": "no_transformation", 14 | "scikit": "classification", 15 | "transformation": False, 16 | "descriptor_type": "ecfp", 17 | "size": 2048, 18 | "radius": 3 19 | } 20 | } 21 | scoring_function = { 22 | "name": "custom_sum", 23 | "parallel": True, 24 | "parameters": [ 25 | qsar_model 26 | ] 27 | } 28 | scoring_function_parameters = reinvent_scoring.scoring.ScoringFuncionParameters(**scoring_function) 29 | expert_scoring_function = reinvent_scoring.scoring.ScoringFunctionFactory(scoring_function_parameters) 30 | 
result=expert_scoring_function.get_final_score(smiles) 31 | return result.total_score -------------------------------------------------------------------------------- /Chemists-Component/write.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import pickle 5 | import numpy as np 6 | def write_sample_file(jobid, jobname, output_dir, loop, it): 7 | try: 8 | os.makedirs(output_dir) 9 | except FileExistsError: 10 | pass 11 | configuration={ 12 | "logging": { 13 | "job_id": jobid, 14 | "job_name": "{}_loop{}".format(jobname,loop), 15 | "logging_path": os.path.join(output_dir, "progress_t{}.log".format(it)), 16 | "recipient": "local", 17 | "sender": "http://127.0.0.1" 18 | }, 19 | "parameters": { 20 | "model_path": os.path.join(output_dir, "results_t{}/Agent.ckpt".format(it)), 21 | "output_smiles_path": os.path.join(output_dir, "results_t{}/sampled.csv".format(it)), 22 | "num_smiles": 1024, 23 | "batch_size": 128, 24 | "with_likelihood": False 25 | }, 26 | "run_type": "sampling", 27 | "model_type": "default", 28 | "version": 3 29 | } 30 | conf_filename = os.path.join(output_dir, "config_t{}_sampling.json".format(it)) 31 | with open(conf_filename, 'w') as f: 32 | json.dump(configuration, f, indent=4, sort_keys=True) 33 | return conf_filename 34 | 35 | def write_config_file(jobid, jobname, reinvent_dir, reinvent_env, output_dir, fpdim, loop, it, modelfile, seed): 36 | # if required, generate a folder to store the results 37 | try: 38 | os.makedirs(output_dir) 39 | except FileExistsError: 40 | pass 41 | 42 | diversity_filter = { 43 | "name": "IdenticalMurckoScaffold", 44 | "bucket_size": 25, 45 | "minscore": 0.2, 46 | "minsimilarity": 0.4 47 | } 48 | 49 | inception = { 50 | "memory_size": 20, 51 | "sample_size": 5, 52 | "smiles": [] 53 | } 54 | 55 | human_component = { 56 | "component_type": "predictive_property", 57 | "name": "Human-component", 58 | "weight": 1, 59 | "specific_parameters": { 60 | "model_path": modelfile, 61 | "gpflow": "regression", 62 | "descriptor_type": "ecfp", 63 | "size": fpdim, 64 | "container_type":"gpflow_container", 65 | "use_counts": True, 66 | "use_features": True, 67 | "transformation": { 68 | "transformation_type":"clipping", 69 | "low":0, 70 | "high":1 71 | } 72 | } 73 | } 74 | 75 | scoring_function = { 76 | "name": "custom_sum", 77 | "parallel": False, 78 | "parameters": [ 79 | human_component 80 | ] 81 | } 82 | 83 | configuration = { 84 | "version": 3, 85 | "run_type": "reinforcement_learning", 86 | "model_type": "default", 87 | "parameters": { 88 | "scoring_function": scoring_function 89 | } 90 | } 91 | 92 | configuration["parameters"]["diversity_filter"] = diversity_filter 93 | configuration["parameters"]["inception"] = inception 94 | 95 | configuration["parameters"]["reinforcement_learning"] = { 96 | "prior": os.path.join(reinvent_dir, "data/random.prior.new"), 97 | "agent": os.path.join(reinvent_dir, "data/random.prior.new"), 98 | "n_steps": 300, 99 | "sigma": 128, 100 | "learning_rate": 0.0001, 101 | "batch_size": 128, 102 | "reset": 0, 103 | "reset_score_cutoff": 0.5, 104 | "margin_threshold": 50 105 | } 106 | 107 | configuration["logging"] = { 108 | "sender": "http://127.0.0.1", 109 | "recipient": "local", 110 | "logging_frequency": 0, 111 | "logging_path": os.path.join(output_dir, "progress_t{}.log".format(it)), 112 | "result_folder": os.path.join(output_dir, "results_t{}".format(it)), 113 | "job_name": "{}_loop{}".format(jobname,loop), 114 | "job_id": jobid 115 | } 
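# Note: since the "custom_sum" scoring function above contains only the "Human-component" (weight 1), the reward Reinvent maximises at each RL step is the saved GP model's prediction, clipped to [0, 1] by the transformation specified above.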
116 | 117 | # write the configuration file to the disc 118 | conf_filename = os.path.join(output_dir, "config_t{}.json".format(it)) 119 | with open(conf_filename, 'w') as f: 120 | json.dump(configuration, f, indent=4, sort_keys=True) 121 | 122 | return conf_filename 123 | 124 | 125 | def write_query_to_csv(smiles, ids, query, file, output_dir): 126 | try: 127 | os.mkdir(output_dir + '/query') 128 | except FileExistsError: 129 | pass 130 | data = {'id': [ids[i] for i in query], 'SMILES': smiles} 131 | df = pd.DataFrame(data) 132 | df.to_csv(file, index = False, header=True) 133 | 134 | 135 | def write_run_sample(seed, output_dir, reinvent_env, reinvent_dir, step, n_iteration): 136 | runfile = output_dir + '/run_sampling.sh' 137 | array_num=int(np.ceil(n_iteration/step)) 138 | try: 139 | os.mkdir(output_dir + '/slurm') 140 | except FileExistsError: 141 | pass 142 | with open(runfile, 'w') as f: 143 | f.write("#!/bin/bash -l \n") 144 | f.write("#SBATCH --mem=500M \n") 145 | f.write('#SBATCH --time=00:05:00 \n') 146 | f.write('#SBATCH -o {}/slurm/out_{}_sampling_%a.out\n'.format(output_dir, seed)) 147 | f.write('#SBATCH --array=0-{}\n'.format(array_num)) 148 | f.write('\n') 149 | f.write('module purge\n') 150 | f.write('module load anaconda\n') 151 | f.write('source activate {}\n'.format(reinvent_env)) 152 | f.write('\n') 153 | f.write('config_index=$(($SLURM_ARRAY_TASK_ID*{}))\n'.format(step)) 154 | f.write('conf_filename="{}/config_t${{config_index}}_sampling.json"\n'.format(output_dir)) 155 | f.write('srun python {}/input.py $conf_filename\n'.format(reinvent_dir)) 156 | 157 | def write_runs_sh(seed, output_dir, reinvent_env, reinvent_dir, step, n_iteration): 158 | runfile = output_dir + '/runs.sh' 159 | array_num=int(np.ceil(n_iteration/step)) 160 | try: 161 | os.mkdir(output_dir + '/slurm') 162 | except FileExistsError: 163 | pass 164 | with open(runfile, 'w') as f: 165 | f.write("#!/bin/bash -l \n") 166 | f.write("#SBATCH --mem=25G \n") 167 | f.write('#SBATCH --time=02:00:00\n') 168 | f.write('#SBATCH -o {}/slurm/out_{}_%a.out\n'.format(output_dir, seed)) 169 | f.write('#SBATCH --array=0-{}\n'.format(array_num)) 170 | f.write('\n') 171 | f.write('module purge\n') 172 | f.write('module load anaconda\n') 173 | f.write('source activate {}\n'.format(reinvent_env)) 174 | f.write('\n') 175 | f.write('config_index=$(($SLURM_ARRAY_TASK_ID*{}))\n'.format(step)) 176 | f.write('conf_filename="{}/config_t$config_index.json"\n'.format(output_dir)) 177 | f.write('srun python {}/input.py $conf_filename\n'.format(reinvent_dir)) 178 | 179 | def write_idx(L0, U0, i_query, y_train, output_dir, loop, it): 180 | i_query=np.copy(i_query) 181 | y_response = y_train[i_query] 182 | i_generated_in_query=np.where(y_response==-1)[0] 183 | i_generated=i_query[i_generated_in_query] #indexes in i_query where points to generated data 184 | new_start_idx=np.sum(y_train!=-1) # the num of data with labels 185 | i_query[i_generated_in_query]= np.arange(new_start_idx,new_start_idx+len(i_generated)) 186 | L0=np.union1d(L0,i_query) 187 | U0 = np.setdiff1d(U0,i_query) 188 | 189 | dat_save = { 190 | 'L': L0, # indices of labeled data 191 | 'U': U0, 192 | } 193 | try: 194 | os.mkdir(output_dir + '/idx') 195 | except FileExistsError: 196 | pass 197 | filename = output_dir + '/idx/log_loop{}_it{}.p'.format(loop,it) 198 | with open(filename , 'wb') as f: 199 | pickle.dump(dat_save, f) 200 | f.close() 201 | return L0, U0 202 | 203 | 204 | def write_training_data(smiles, activity, id_train, output_dir, idx_query=None, 
num_original=None): 205 | 206 | 207 | if idx_query is not None: 208 | idx_original=np.arange(num_original,dtype='i') 209 | i_generated=idx_query[idx_query>=num_original] 210 | idx=np.append(idx_original,i_generated) 211 | else: 212 | idx=np.where(activity!=-1)[0] 213 | idx.astype(int) 214 | print('idx.shape is {}'.format(idx.shape)) 215 | 216 | smiles, activity, smiles_id=smiles[idx], activity[idx], id_train[idx] 217 | dataset=pd.DataFrame({'id':smiles_id, 'canonical': smiles, 'activity': activity}) 218 | dataset.to_csv(os.path.join(output_dir,'drd2.train.csv'),index=False) 219 | 220 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Human-in-the-Loop Assisted De Novo Molecular Design
2 | =================================================================================================================
3 | 
4 | This repository contains the source code of the methods presented in the paper:
5 | 
6 | I. Sundin, A. Voronov, H. Xiao, K. Papadopoulos, E. J. Bjerrum, M. Heinonen, A. Patronov, S. Kaski, O. Engkvist.
Human-in-the-Loop Assisted de Novo Molecular Design. J Cheminform 14, 86 (2022). https://doi.org/10.1186/s13321-022-00667-8
7 | 
8 | Since the paper was published, this repository has been in read-only archive mode; further development on this topic will continue in a separate repository.
9 | 
10 | Installation
11 | -------------
12 | 
13 | 1. Install [Conda](https://conda.io/projects/conda/en/latest/index.html)
14 | 2. Clone this Git repository
15 | 3. In a shell, go to the repository and create the Conda environment, e.g.
16 | 
17 |        $ conda env create -f adapt-mpo.yml
18 | 
19 | 4. Activate the environment:
20 | 
21 |        $ conda activate adapt_mpo
22 | 
23 | 5. Run the method (see 'Usage')
24 | 
25 | 
26 | System Requirements
27 | -------------------
28 | 
29 | Adaptive MPO:
30 | * Python 3.7
31 | * Reinvent 3.2 (https://github.com/MolecularAI/Reinvent)
32 | * reinvent-models 0.0.15rc1 (https://github.com/MolecularAI/reinvent-models)
33 | * reinvent-scoring 0.0.73
34 | * reinvent-chemistry 0.0.51
35 | * PyStan 2.19.1.1
36 | * This code has been tested on Linux
37 | 
38 | Chemist's Component:
39 | * Python 3.7
40 | * Reinvent 3.2 (https://github.com/MolecularAI/Reinvent)
41 | * reinvent-models 0.0.15rc1 (https://github.com/MolecularAI/reinvent-models)
42 | * reinvent-scoring with GPflow support, from this [fork](https://github.com/MolecularAI/reinvent-scoring-gpflow)
43 | * reinvent-chemistry 0.0.51
44 | * GPflow 2.3.1
45 | * ["Wall of molecules" (MolWall) GUI](https://github.com/MolecularAI/molwall)
46 | * This code has been tested on Linux
47 | 
48 | 
49 | Usage - Adaptive MPO
50 | --------------------------------------------
51 | 
52 | In Task1_Adaptive_MPO_HITL.py, modify the paths
53 | 
54 | ```
55 | reinvent_dir = "/path/to/Reinvent"
56 | reinvent_env = "/path/to/conda_environment"
57 | output_dir = "/path/to/result/directory/{}_seed{}".format(jobid, seed)
58 | ```
59 | 
60 | to match those on your system.
61 | 
62 | Add a prior agent to Reinvent:
63 | * create a directory "data" under /path/to/Reinvent
64 | * copy random.prior.new from [ReinventCommunity](https://github.com/MolecularAI/ReinventCommunity/tree/master/notebooks/models) to /path/to/Reinvent/data/
65 | 
66 | To run, execute:
67 | 
68 |     $ python Task1_Adaptive_MPO_HITL.py acquisition seed
69 | 
70 | where acquisition is one of the query selection strategies listed below (random, uncertainty, greedy, or thompson) and seed is a random seed used for reproducibility and for running multiple replicas with different initializations.
71 | 
72 | Supported query selection strategies (a minimal sketch of all four follows the list):
73 | * Random sampling: 'random'
74 | * Uncertainty sampling: 'uncertainty'
75 | * Pure exploitation: 'greedy'
76 | * Thompson sampling: 'thompson'
77 | 
78 | 
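For intuition, here is a minimal, hypothetical sketch of what these four strategies compute. It is not the repository's scripts/acquisition.py: the ```select_query``` signature and the ```scores``` array (posterior samples of each candidate molecule's predicted utility) are illustrative assumptions only.

```
import numpy as np

def select_query(scores, n, acquisition, rng=None):
    """Sketch: pick the indices of n candidate molecules to query.

    scores: array of shape (n_posterior_samples, n_candidates) holding
    posterior draws of each candidate's predicted utility (an assumed
    input; the actual scripts derive this from the fitted user model).
    """
    rng = np.random.default_rng() if rng is None else rng
    n_samples, n_candidates = scores.shape
    if acquisition == "random":
        # Random sampling: ignore the model entirely.
        return rng.choice(n_candidates, size=n, replace=False)
    if acquisition == "uncertainty":
        # Uncertainty sampling: query where the posterior is widest.
        return np.argsort(scores.std(axis=0))[::-1][:n]
    if acquisition == "greedy":
        # Pure exploitation: query the highest posterior means.
        return np.argsort(scores.mean(axis=0))[::-1][:n]
    if acquisition == "thompson":
        # Thompson sampling: each query slot takes one posterior draw
        # (assumes n_samples >= n) and nominates the best candidate
        # that has not been picked yet.
        picked = []
        for draw in rng.choice(n_samples, size=n, replace=False):
            ranked = np.argsort(scores[draw])[::-1]
            picked.append(next(i for i in ranked if i not in picked))
        return np.array(picked)
    raise ValueError("unknown acquisition: " + acquisition)
```

For example, ```select_query(scores, n=10, acquisition='thompson')``` would return the indices of ten molecules to show to the chemist next; 'greedy' and 'uncertainty' differ only in whether they rank candidates by posterior mean or posterior spread.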
79 | Usage - Chemist's Component
80 | --------------------------------------------
81 | 
82 | Chemist's component experiments consist of two phases:
83 | 1. Human-in-the-loop interaction: can be run either with a simulated human or with a GUI for user interaction.
84 |    - Files:
85 |      * Conda environment cc_env_hitl.yml
86 |      * Jupyter Notebook Task2_Chemists_Component.ipynb
87 | 2. Evaluating the performance of the resulting chemist's component as a Reinvent scoring function.
88 |    - Files:
89 |      * Conda environment cc_env_reinvent.yml
90 |      * Scripts created by Task2_Chemists_Component.ipynb in directory 'output_dir/loop0/'
91 |      * Jupyter Notebook evaluate_results_Task2.ipynb to analyze and plot the results
92 | 
93 | Preparing the setup
94 | - Create conda virtual environments cc_env_hitl and cc_env_reinvent from cc_env_hitl.yml and cc_env_reinvent.yml, respectively
95 | - Manual modifications to cc_env_reinvent:
96 |   * Build and install reinvent-scoring locally from this [fork](https://github.com/Augmented-Drug-Design-Human-in-the-Loop/reinvent-scoring-gpflow) to support GPflow models.
97 |   * Copy test_config.json to the environment if needed: in reinvent-scoring-gpflow, run
98 | 
99 |         $ scp ./reinvent_scoring/configs/test_config.json /path/to/env/lib/python3.7/site-packages/reinvent_scoring/configs/test_config.json
100 | 
101 | 
102 | To run:
103 | 
104 | Activate cc_env_hitl
105 | 
106 | 1. Open the 'Task2_Chemists_Component.ipynb' notebook.
107 |    - In the Configuration cell, set the id number of the experiment
108 |    - Set the paths to Reinvent (```reinvent_dir```), to the conda environment created from cc_env_reinvent.yml (```reinvent_env```), and to a directory for saving the results (```output_dir```), creating it if needed.
109 |    - Set the acquisition method: Thompson sampling: 'thompson'; random sampling: 'random'; uncertainty sampling: 'uncertainty'; or pure exploitation: 'greedy'.
110 |    - Select between the simulated chemist (```simulated_human=True```) and user interaction via the GUI (```simulated_human=False```)
111 | 2. Run the notebook; if ```simulated_human=False```, it will wait for input after writing the first query (query_it1.csv). If ```simulated_human=True```, the notebook will run the whole HITL interaction with the simulated chemist without stopping
112 |    - An output directory named demo_acquisition_YY-MM-DD-seed is created automatically
113 |    - If ```simulated_human=False```, continue with steps 3-6 to complete the experiment.
114 | 3. To use the GUI, upload the query_it1.csv file (saved to output_dir/loop0/query/) in MolWall
115 |    - In later iterations, refresh the browser to upload the next file to MolWall
116 | 4. Rate the molecules on the scale 1-5:
117 |    1 = not at all DRD2-active (read as 0 by the model), 5 = most certainly active (read as 1 by the model).
118 |    You may leave some molecules unscored. (A sketch of this mapping follows these steps.)
119 | 5. Download the file to output_dir/loop0/query/ (the file will be named "query_it1_with_ratings.csv"; do not modify it)
120 | 6. Press enter in the notebook to continue the script; a file query_it2.csv will be created, and so on
121 | 7. Steps 3-6 repeat 10 times; then the rest of the notebook runs to plot and save the results
122 | 
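As a concrete illustration of the rating step above, here is a hypothetical sketch of how the 1-5 ratings could be mapped onto the [0, 1] feedback values the model consumes. The notebook's actual preprocessing may differ, and the 'rating' column name is an assumption.

```
import pandas as pd

# Map ratings 1-5 linearly onto [0, 1]:
# 1 -> 0.0 (not at all DRD2-active), 5 -> 1.0 (most certainly active).
df = pd.read_csv("query_it1_with_ratings.csv")
df = df.dropna(subset=["rating"])          # molecules left unscored are skipped
df["feedback"] = (df["rating"] - 1) / 4.0
```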
123 | In phase 2, to evaluate the chemist's component's performance as a Reinvent scoring component, two bash scripts have been created in output_dir/loop0/:
124 | - runs.sh
125 |   * Launches the Reinvent reinforcement learning runs; for each iteration X, the chemist's component model has been saved and config_tX.json determines the corresponding Reinvent configuration. We recommend using a computing cluster with Slurm; the jobs can then be submitted with
126 | 
127 |       $ sbatch runs.sh
128 | 
129 | - run_sampling.sh
130 |   * Evaluates the resulting Reinvent agent by sampling molecules from it
131 |   * Once the jobs from runs.sh have completed, run it with
132 | 
133 |       $ sbatch run_sampling.sh
134 | 
135 | Collect, analyze, and plot the results using the notebook 'evaluate_results_Task2.ipynb' (uses cc_env_reinvent)
136 | - Use this notebook to plot the average oracle scores at each iteration after running the Reinvent configurations determined by runs.sh and run_sampling.sh
137 | 
138 | 
139 | 
140 | How to cite
141 | -------------------
142 | 
143 | Sundin, I., Voronov, A., Xiao, H. et al. Human-in-the-loop assisted de novo molecular design. J Cheminform 14, 86 (2022). https://doi.org/10.1186/s13321-022-00667-8
144 | 
--------------------------------------------------------------------------------