├── LICENSE ├── Makefile ├── README.md ├── bcpaff ├── data_processing │ ├── collect_affinity_data.py │ ├── data_processing.py │ ├── download.py │ ├── filtering.py │ ├── manual_structure_prep.py │ └── structure_prep_and_qm.py ├── ml │ ├── analysis.py │ ├── cluster_tools.py │ ├── generate_pickle.py │ ├── hparam_screen.py │ ├── ml_utils.py │ ├── net.py │ ├── net_utils.py │ ├── run_all_ml_experiments.py │ ├── scrambling.py │ ├── statsig.py │ ├── test.py │ └── train.py ├── qm │ ├── benchmark.py │ ├── compute_wfn_dftb.py │ ├── compute_wfn_psi4.py │ ├── compute_wfn_xtb.py │ ├── prepare_cluster_job.py │ └── xtb.inp ├── qtaim │ ├── critic2_tools.py │ ├── generate_critical_points.txt │ ├── multiwfn_commands.py │ ├── multiwfn_tools.py │ ├── qtaim_reader.py │ └── qtaim_viewer.py └── utils.py ├── env.yml ├── env_psi4.yml ├── hparam_files ├── hparams_bcp_atom_ids.csv ├── hparams_bcp_atom_ids_and_props.csv ├── hparams_bcp_feature_ablation.csv ├── hparams_bcp_props.csv ├── hparams_bcp_props_mini.csv ├── hparams_ncp_atom_ids.csv ├── hparams_ncp_atom_ids_and_props.csv └── hparams_ncp_props.csv └── img.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Clemens Isert, Kenneth Atz, Sereina Riniker, Gisbert Schneider (ETH Zurich) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | SHELL=/bin/bash 3 | CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate 4 | PROC_DATA_DIR:=$(CURDIR)/processed_data 5 | PDE10A_SPLITS = random temporal_2011 temporal_2012 temporal_2013 aminohetaryl_c1_amide c1_hetaryl_alkyl_c2_hetaryl aryl_c1_amide_c2_hetaryl 6 | 7 | # Commands 8 | conda=conda 9 | mamba=mamba 10 | python=python 11 | 12 | all: env multiwfn download 13 | 14 | make with_conda: env_conda multiwfn download 15 | 16 | env: 17 | ${mamba} env create -f env.yml 18 | ${mamba} env create -f env_psi4.yml 19 | 20 | env_conda: 21 | ${conda} env create -f env.yml 22 | ${conda} env create -f env_psi4.yml 23 | 24 | multiwfn: 25 | wget http://sobereva.com/multiwfn/misc/Multiwfn_3.8_dev_bin_Linux_noGUI.zip 26 | unzip Multiwfn_3.8_dev_bin_Linux_noGUI.zip 27 | mv Multiwfn_3.8_dev_bin_Linux_noGUI multiwfn 28 | chmod 770 ./multiwfn/Multiwfn_noGUI 29 | rm Multiwfn_3.8_dev_bin_Linux_noGUI.zip 30 | 31 | download: 32 | source activate bcpaff; ${python} -m bcpaff.data_processing.download 33 | python -m bcpaff.data_processing.collect_affinity_data 34 | 35 | data_processing: 36 | source activate bcpaff; ulimit -s unlimited; ${python} -m bcpaff.data_processing.data_processing --action all --test_run --cluster_options no_cluster 37 | 38 | data_processing_report: 39 | source activate bcpaff; ${python} -m bcpaff.data_processing.data_processing --action report 40 | 41 | ml_experiments: 42 | source activate bcpaff; ${python} -m bcpaff.ml.run_all_ml_experiments --cluster_options no_cluster 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bcpaff - Exploring protein-ligand binding affinity prediction with electron density-based geometric deep learning 2 | 3 | Details are described in our [paper](https://doi.org/10.26434/chemrxiv-2023-585vf). Please cite if you use this work. 4 | 5 | ![](img.png) 6 | 7 | To setup the conda environment, install [Multiwfn](http://sobereva.com/multiwfn/), and download the datasets, just run the following in your $CWD: 8 | ```bash 9 | cd bcpaff 10 | make 11 | ``` 12 | (this step uses [mamba](https://github.com/mamba-org/mamba), you can change it to [conda](https://docs.conda.io/en/latest/) by using `make with_conda` instead). 
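The `make` targets are thin wrappers around Python module entry points. If you prefer to run the dataset download step without `make`, the following is a minimal sketch of the equivalent Python calls (run inside the `bcpaff` environment; the functions are the ones defined in `bcpaff/data_processing/`):
```python
# Equivalent of `make download`, called directly from Python.
from bcpaff.data_processing.download import download_pdbbind_data, download_pde10a_data
from bcpaff.data_processing.collect_affinity_data import collect_pdb_affinity_data

download_pdbbind_data()      # PDBbind structures + affinity index (Volkov et al. preparation)
download_pde10a_data()       # PDE10A structures + affinity csv (Tosstorff et al.)
collect_pdb_affinity_data()  # writes pdbbind2019_affinity.csv with split assignments
```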
13 | 14 | Structure preparation and training (remove `--test_run` to run on all structures; remove `--cluster_options no_cluster` to run via Slurm): 15 | ```bash 16 | make data_processing 17 | ``` 18 | (basically running `bcpaff.data_processing.data_processing`) 19 | 20 | ML model training: 21 | ```bash 22 | make ml_experiments 23 | ``` 24 | 25 | To interactively visualize BCPs in Jupyter Notebook: 26 | ```python 27 | from bcpaff.qtaim.qtaim_viewer import QtaimViewer 28 | from bcpaff.qtaim.qtaim_reader import QtaimProps 29 | 30 | qp = QtaimProps(basepath="PATH_TO_COMPOUND_FOLDER") 31 | v = QtaimViewer(qp) 32 | v.show() 33 | ``` 34 | -------------------------------------------------------------------------------- /bcpaff/data_processing/collect_affinity_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | 7 | import pandas as pd 8 | 9 | from bcpaff.utils import DATA_PATH 10 | 11 | SPLIT_ASSIGNMENTS_BASEPATH = os.path.join(DATA_PATH, "pdbbind", "pdb_ids") 12 | 13 | 14 | def collect_affinity_as_dataframe(affinity_file: str) -> pd.DataFrame: 15 | with open(affinity_file, "r") as f: 16 | lines = f.readlines()[6:] 17 | info = [] 18 | for line in lines: 19 | tokens = line.split() 20 | pdb_id = tokens[0] 21 | act = float(tokens[3]) 22 | info.append([pdb_id, act]) 23 | df = pd.DataFrame(info, columns=["pdb_id", "aff"]) 24 | return df 25 | 26 | 27 | def assign_splits(df: pd.DataFrame) -> pd.DataFrame: 28 | split_names = ["training_set", "core_set", "validation_set", "hold_out_set"] 29 | split_assignments = {} 30 | for split_name in split_names: 31 | split_df = pd.read_csv(os.path.join(SPLIT_ASSIGNMENTS_BASEPATH, f"{split_name}.csv"), names=["pdb_id"]) 32 | for pdb_id in split_df.pdb_id.tolist(): 33 | assert pdb_id not in split_assignments 34 | split_assignments[pdb_id] = split_name 35 | 36 | df.loc[:, "random"] = df.pdb_id.apply( 37 | lambda x: split_assignments[x] if x in split_assignments else "no_assignment" 38 | ) 39 | df = df[df.random != "no_assignment"].reset_index(drop=True) 40 | return df 41 | 42 | 43 | def collect_pdb_affinity_data(): 44 | pdbbind_structure_path = os.path.join(DATA_PATH, "pdbbind") 45 | pdbbind_csv = os.path.join( 46 | pdbbind_structure_path, "PDBbind_v2019_plain_text_index/plain-text-index/index/INDEX_general_PL_data.2019" 47 | ) 48 | df = collect_affinity_as_dataframe(pdbbind_csv) 49 | df = assign_splits(df) 50 | df.to_csv(os.path.join(pdbbind_structure_path, "pdbbind2019_affinity.csv"), index=False) 51 | 52 | 53 | if __name__ == "__main__": 54 | collect_pdb_affinity_data() 55 | -------------------------------------------------------------------------------- /bcpaff/data_processing/data_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import glob 7 | import os 8 | import pickle 9 | import subprocess 10 | from typing import List, Optional, Tuple 11 | 12 | import pandas as pd 13 | 14 | from bcpaff.ml.cluster_tools import submit_job 15 | from bcpaff.utils import DATA_PATH, PROCESSED_DATA_PATH, REPORT_PATH 16 | 17 | DATASETS = ["pdbbind", "pde10a"] 18 | AFFINITY_DATA = { 19 | "pdbbind": os.path.join(DATA_PATH, "pdbbind", "pdbbind2019_affinity.csv"), 20 | "pde10a": os.path.join(DATA_PATH, "pde10a", "10822_2022_478_MOESM2_ESM.csv"), 21 | } 22 | INPUT_FOLDERS = { 23 | "pdbbind": os.path.join(DATA_PATH, "pdbbind", "dataset"), 24 | "pde10a": os.path.join(DATA_PATH, 
"pde10a", "pde-10_pdb_bind_format_blinded"), 25 | "d2dr": os.path.join(DATA_PATH, "d2dr", "D2DR_complexes_prepared"), 26 | } 27 | 28 | 29 | def generate_jobscript_structure_prep_and_qm( 30 | dataset: str, 31 | base_input_dir: str, 32 | base_output_dir: str, 33 | test_run: bool = True, 34 | esp: bool = False, 35 | qm_method: str = "xtb", 36 | cutoff: int = 6, 37 | keep_wfn: bool = True, 38 | ) -> Tuple[str, int]: 39 | """Generate file with python -m bcpaff.data_processing.structure_prep_and_qm commands for compounds from specified dataset. 40 | 41 | Parameters 42 | ---------- 43 | dataset : str 44 | name of dataset, either pdbbind or pde10a 45 | base_input_dir : str 46 | path of input data 47 | base_output_dir : str 48 | path to output folder 49 | test_run : bool, optional 50 | whether to only use the first 5 compounds for quick testing, by default True 51 | 52 | Returns 53 | ------- 54 | tuple (str, str) 55 | path to cmds file, number of lines in cmds file 56 | """ 57 | 58 | cmds_file_path = os.path.join(base_output_dir, "slurm_files", "cmds.txt") 59 | os.makedirs(os.path.dirname(cmds_file_path), exist_ok=True) 60 | folders = sorted(glob.glob(os.path.join(base_input_dir, "*"))) 61 | folders = [folder for folder in folders if os.path.isdir(folder)] 62 | 63 | structure_ids = [os.path.basename(folder) for folder in folders] 64 | 65 | if test_run: 66 | if dataset == "pdbbind": 67 | structure_ids = ["3zzf", "1w8l", "5eb2", "2r58", "3ao4"] # something from train/val/test set 68 | elif dataset == "pde10a": 69 | structure_ids = ["5sf8_0", "5sfr_1", "5sf4_2", "5se7_11"] # something from train/val/test set 70 | 71 | lines = [] 72 | for structure_id in structure_ids: 73 | line = f"python -m bcpaff.data_processing.structure_prep_and_qm {structure_id}" 74 | line += f" --dataset {dataset}" 75 | line += f" --basepath {base_input_dir}" 76 | line += f" --output_basepath {base_output_dir}" 77 | if qm_method.startswith("dftb"): 78 | line += f" --implicit_solvent water" 79 | elif qm_method == "xtb": 80 | line += f" --implicit_solvent alpb_water" 81 | line += f" --cutoff {cutoff}" 82 | line += f" --qm_method {qm_method}" 83 | if esp: 84 | line += f" --esp" 85 | if keep_wfn: 86 | line += f" --keep_wfn" 87 | lines.append(line) 88 | with open(cmds_file_path, "w") as f: 89 | f.write("\n".join(lines)) 90 | return cmds_file_path, len(lines) 91 | 92 | 93 | def generate_bash_script_structure_prep_and_qm( 94 | cmds_file_path: str, time: str = "04:00:00", num_cores: int = 4, memory: int = 16, qm_method: str = "xtb" 95 | ) -> str: 96 | """Generate the bash script which runs the job array for run_all function. 
97 | 98 | Parameters 99 | ---------- 100 | cmds_file_path : str 101 | path to cmds file for the job array 102 | time : str 103 | Slurm-formatted time specifier, by default 04:00:00 (increase for ESP) 104 | num_cores : int 105 | number of cores, by default 4 106 | memory : int 107 | total memory in GB (not MB/core; is being converted) 108 | 109 | Returns 110 | ------- 111 | str 112 | path to script file that runs the job array 113 | """ 114 | 115 | script_path = os.path.join(os.path.dirname(cmds_file_path), "jobscript.sh") 116 | out_files_folder = os.path.join(os.path.dirname(script_path), "out_files") 117 | os.makedirs(out_files_folder, exist_ok=True) 118 | conda_env = "bcpaff_psi4" if qm_method == "psi4" else "bcpaff" 119 | conda_env = "bcpaff" 120 | with open(script_path, "w") as f: 121 | f.write( 122 | f"""#!/bin/bash 123 | 124 | #SBATCH --job-name=bcpaff # Job name 125 | #SBATCH -n {num_cores} # Number of CPU cores 126 | #SBATCH --mem-per-cpu={int(memory/num_cores*1024)} # Memory per CPU in MB 127 | #SBATCH --time={time} # Maximum execution time (HH:MM:SS) 128 | #SBATCH --tmp=4000 # Total scratch for job in MB 129 | #SBATCH --output {os.path.join(out_files_folder, "structprep_qm_out_%A_%a.out")} # Standard output 130 | #SBATCH --error {os.path.join(out_files_folder, "structprep_qm_out_%A_%a.out")} # Standard error 131 | 132 | source ~/.bashrc; 133 | eval "$(conda shell.bash hook)"; conda activate {conda_env}; 134 | export CMDS_FILE_PATH={cmds_file_path} 135 | export cmd=$(head -$SLURM_ARRAY_TASK_ID $CMDS_FILE_PATH|tail -1) 136 | echo "=========SLURM_COMMAND=========" 137 | echo $cmd 138 | echo "=========SLURM_COMMAND=========" 139 | eval $cmd 140 | """ 141 | ) 142 | return script_path 143 | 144 | 145 | def submit_job_generate_pickle( 146 | search_path: str, 147 | affinity_data: str, 148 | dataset: str, 149 | qm_method: str, 150 | cluster_options: Optional[str], 151 | save_path: Optional[str] = None, 152 | nucleus_critical_points: bool = False, 153 | ): 154 | """Submit an sbatch job to generate pickle file based on input arguments. 
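    With cluster_options == "no_cluster" the pickle-generation command is executed directly in the current process; otherwise it is wrapped in an sbatch call and the resulting Slurm job id is returned.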
155 | 156 | Parameters 157 | ---------- 158 | search_path : str 159 | search path for pre-processed structures 160 | affinity_data : str 161 | path to affinity data for PDBbind or PDE10A datasets 162 | save_path : str, optional 163 | path where pickle file will be saved, by default None 164 | nucleus_critical_points : bool, optional 165 | whether or not to use BCP-centric (False) or NCP-centric (True) graph, by default False 166 | """ 167 | job_id = None 168 | if save_path is None: 169 | ncp_bcp = "ncp" if nucleus_critical_points else "bcp" 170 | save_path = os.path.join(search_path, f"qtaim_props_{ncp_bcp}.pkl") 171 | cmd_str = "python -m bcpaff.ml.generate_pickle" 172 | cmd_str += f" --search_path {search_path}" 173 | cmd_str += f" --dataset {dataset}" 174 | cmd_str += f" --qm_method {qm_method}" 175 | cmd_str += f" --affinity_data {affinity_data}" 176 | cmd_str += f" --save_path {save_path}" 177 | if nucleus_critical_points: 178 | cmd_str += " --nucleus_critical_points" 179 | if cluster_options == "no_cluster": 180 | completed_process = subprocess.run(cmd_str, shell=True) 181 | else: 182 | slurm_output_file = os.path.join(search_path, "slurm_files", "out_files", "generate_pickle_out_%A.out") 183 | if cluster_options is None: 184 | cluster_options = "" 185 | completed_process = subprocess.run( 186 | f'sbatch --job-name=bcpaff_generate_pickle -n 8 --time=04:00:00 --mem-per-cpu=8192 {cluster_options} --parsable --output={slurm_output_file} --wrap "{cmd_str}"', 187 | shell=True, 188 | universal_newlines=True, 189 | stdout=subprocess.PIPE, 190 | ) 191 | job_id = completed_process.stdout.rstrip("\n") 192 | if completed_process.returncode != 0: 193 | print(completed_process.returncode) 194 | return job_id 195 | 196 | 197 | def homogenize_pickles(dataset: str, output_basepath: str = None): 198 | # only keep complexes which are present in all pickle files (NCP/BCP etc.) 
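    # Both qtaim_props_bcp.pkl and qtaim_props_ncp.pkl are dicts keyed by structure id;
    # keys missing from either file are dropped from both, and the removed ids are
    # logged under REPORT_PATH/structure_prep/<dataset>/ for inspection.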
199 | if output_basepath is None: 200 | output_basepath = os.path.join(PROCESSED_DATA_PATH, "prepared_structures", dataset) 201 | savepath_bcp_pickle = os.path.join(output_basepath, "qtaim_props_bcp.pkl") 202 | savepath_ncp_pickle = os.path.join(output_basepath, "qtaim_props_ncp.pkl") 203 | pickle_files = [savepath_bcp_pickle, savepath_ncp_pickle] 204 | keys_per_pickle_file = [] 205 | all_data = {} 206 | for pickle_file in pickle_files: 207 | with open(pickle_file, "rb") as f: 208 | data = pickle.load(f) 209 | all_data[pickle_file] = data 210 | keys_per_pickle_file.append(set(data.keys())) 211 | keys_in_all_pickle_files = set.intersection(*keys_per_pickle_file) 212 | 213 | for pickle_file in pickle_files: 214 | keys_to_remove = sorted(list(set(all_data[pickle_file].keys()) - keys_in_all_pickle_files)) 215 | savepath_report = os.path.join( 216 | REPORT_PATH, "structure_prep", dataset, f"removed_ids_{os.path.basename(pickle_file)}.txt" 217 | ) 218 | os.makedirs(os.path.dirname(savepath_report), exist_ok=True) 219 | with open(savepath_report, "w") as f: 220 | f.write( 221 | f"These IDs were removed from {pickle_file} because they didn't exist in all of the following pickle_files:\n" 222 | ) 223 | f.write("\n".join([f" - {p}" for p in pickle_files]) + "\n") 224 | f.write("\n".join(keys_to_remove)) 225 | new_data = {key: all_data[pickle_file][key] for key in keys_in_all_pickle_files} 226 | with open(pickle_file, "wb") as f: 227 | pickle.dump(new_data, f) 228 | print(f"Wrote homogenized data to {pickle_file}", flush=True) 229 | 230 | 231 | def report_data_processing_outcome(search_path: str, dataset: str, affinity_data: str): 232 | """Generate a report on outcome of data processing (number of successfully cleaned structures, 233 | missing structures, radicals etc.) 234 | 235 | Parameters 236 | ---------- 237 | search_path : str 238 | path to output folder from structure preparation (contains cleaned files) 239 | dataset : str 240 | which dataset, "pdbbind" or "pde10a" 241 | affinity_data : str 242 | path to csv file with affinity data, needed to retrieve the total list of compounds we start with 243 | """ 244 | 245 | # analysis for structure cleaning 246 | col_name = "pdb_id" if dataset == "pdbbind" else "docking_folder" 247 | df = pd.read_csv(affinity_data, dtype={col_name: "str"}, parse_dates=False) 248 | paths_files = sorted(glob.glob(os.path.join(search_path, "*", "paths.pdb"))) 249 | completed_ids = [os.path.basename(os.path.dirname(x)) for x in paths_files] 250 | uhf_error_ids = [x for x in completed_ids if os.path.exists(os.path.join(search_path, x, "uhf_error.txt"))] 251 | all_ids = sorted(df[col_name].unique().astype(str).tolist()) 252 | missing_ids = sorted(set(all_ids) - set(completed_ids)) 253 | base_report_dir = os.path.join(REPORT_PATH, "structure_prep", dataset) 254 | os.makedirs(base_report_dir, exist_ok=True) 255 | with open(os.path.join(base_report_dir, "completed_ids.txt"), "w") as f: 256 | f.write("\n".join(completed_ids)) 257 | with open(os.path.join(base_report_dir, "missing_ids.txt"), "w") as f: 258 | f.write("\n".join(missing_ids)) 259 | with open(os.path.join(base_report_dir, "uhf_error_ids.txt"), "w") as f: 260 | f.write("\n".join(uhf_error_ids)) 261 | with open(os.path.join(base_report_dir, "overview.txt"), "w") as f: 262 | f.write(f"{len(completed_ids)} compounds for which structure preparation was successful.\n\n") 263 | f.write(f"Of those, {len(uhf_error_ids)} compounds had uhf_error. 
Those were NOT removed.\n\n") 264 | f.write(f"{len(missing_ids)} compounds are missing, {len(all_ids)} compounds exist in total.\n\n") 265 | 266 | # analysis for pickle generation 267 | for ncp_bcp in ["bcp", "ncp"]: 268 | ncp_bcp_base_report_dir = os.path.join(base_report_dir, ncp_bcp) 269 | os.makedirs(ncp_bcp_base_report_dir, exist_ok=True) 270 | pickle_file = os.path.join(search_path, f"qtaim_props_{ncp_bcp}.pkl") 271 | with open(pickle_file, "rb") as f: 272 | data = pickle.load(f) 273 | pickle_ids = [str(x) for x in sorted(list(data.keys()))] 274 | missing_pickle_ids = sorted(set(completed_ids) - set(pickle_ids)) 275 | with open(os.path.join(ncp_bcp_base_report_dir, f"pickle_missing_{ncp_bcp}.txt"), "w") as f: 276 | f.write("\n".join(missing_pickle_ids)) 277 | with open(os.path.join(ncp_bcp_base_report_dir, "overview.txt"), "w") as f: 278 | f.write(f"{pickle_file}\n") 279 | f.write( 280 | f"{len(missing_pickle_ids)} compounds for which structure preparation was successful didn't end up in the pickle file.\n\n" 281 | ) 282 | 283 | 284 | def submit_structure_prep_and_qm_jobs(args: argparse.Namespace, dataset: str, cluster_options=None) -> Optional[str]: 285 | base_input_dir = INPUT_FOLDERS[dataset] 286 | if args.output_basepath is None: 287 | base_output_dir = os.path.join(PROCESSED_DATA_PATH, "prepared_structures", dataset) 288 | else: 289 | base_output_dir = args.output_basepath 290 | cmds_file_path, num_lines = generate_jobscript_structure_prep_and_qm( 291 | dataset, 292 | base_input_dir, 293 | base_output_dir, 294 | test_run=args.test_run, 295 | esp=args.esp, 296 | qm_method=args.qm_method, 297 | cutoff=args.cutoff, 298 | keep_wfn=args.keep_wfn, 299 | ) 300 | script_path = generate_bash_script_structure_prep_and_qm( 301 | cmds_file_path, time="04:00:00", num_cores=1, memory=8, qm_method=args.qm_method 302 | ) 303 | job_id = submit_job(script_path, num_lines=num_lines, cluster_options=cluster_options) 304 | return job_id 305 | 306 | 307 | def submit_pickle_generation_jobs( 308 | dataset: str, cluster_options: Optional[str] = None, output_basepath: Optional[str] = None, qm_method: str = "xtb" 309 | ) -> Optional[str]: 310 | # generate pickle 311 | if output_basepath is None: 312 | search_path = os.path.join(PROCESSED_DATA_PATH, "prepared_structures", dataset) 313 | else: 314 | search_path = output_basepath 315 | print("=====================================") 316 | print(search_path) 317 | print("=====================================") 318 | 319 | job_ids = [] 320 | affinity_data = AFFINITY_DATA[dataset] 321 | savepath = os.path.join(search_path, "qtaim_props_bcp.pkl") 322 | job_id = submit_job_generate_pickle( 323 | search_path, 324 | affinity_data, 325 | dataset, 326 | qm_method=qm_method, 327 | nucleus_critical_points=False, 328 | cluster_options=cluster_options, 329 | save_path=savepath, 330 | ) 331 | job_ids.append(job_id) 332 | savepath = os.path.join(search_path, "qtaim_props_ncp.pkl") 333 | job_id = submit_job_generate_pickle( 334 | search_path, 335 | affinity_data, 336 | dataset, 337 | qm_method=qm_method, 338 | nucleus_critical_points=True, 339 | cluster_options=cluster_options, 340 | save_path=savepath, 341 | ) 342 | job_ids.append(job_id) 343 | if not all([j is None for j in job_ids]): 344 | return ",".join(job_ids) 345 | 346 | 347 | def submit_homogenize_pickle_jobs( 348 | dataset: str, cluster_options: Optional[str] = None, output_basepath: Optional[str] = None 349 | ): 350 | if output_basepath is None: 351 | search_path = os.path.join(PROCESSED_DATA_PATH, 
"prepared_structures", dataset) 352 | else: 353 | search_path = output_basepath 354 | 355 | cmd_str = f"""python -c 'from bcpaff.data_processing.data_processing import homogenize_pickles; homogenize_pickles(\"{dataset}\", \"{search_path}\")' """ 356 | cluster_options = "" if cluster_options is None else cluster_options 357 | if cluster_options != "no_cluster": 358 | slurm_output_file = os.path.join( 359 | PROCESSED_DATA_PATH, 360 | "prepared_structures", 361 | dataset, 362 | "slurm_files", 363 | "out_files", 364 | "homogenize_pickles_out_%A.out", 365 | ) 366 | cmd_str = f"""sbatch --parsable {cluster_options} --output={slurm_output_file} --wrap "{cmd_str}" """ 367 | completed_process = subprocess.run(cmd_str, shell=True, universal_newlines=True, stdout=subprocess.PIPE,) 368 | job_id = completed_process.stdout.rstrip("\n") 369 | return job_id 370 | 371 | 372 | if __name__ == "__main__": 373 | parser = argparse.ArgumentParser() 374 | parser.add_argument("--action", type=str, default="all") 375 | parser.add_argument("--test_run", action="store_true", default=False, dest="test_run") 376 | parser.add_argument("--dataset", type=str, default=None) 377 | parser.add_argument("--cutoff", type=float, default=6) 378 | parser.add_argument("--qm_method", type=str, default="xtb") 379 | parser.add_argument("--output_basepath", type=str, default=None) 380 | parser.add_argument("--esp", action="store_true", default=False, dest="esp") 381 | parser.add_argument("--cluster_options", type=str, default=None) 382 | parser.add_argument("--keep_wfn", action="store_true", default=False, dest="keep_wfn") 383 | 384 | args = parser.parse_args() 385 | 386 | datasets_to_run = DATASETS if args.dataset is None else [args.dataset] 387 | 388 | if args.action == "structure_prep_and_qm": 389 | for dataset in datasets_to_run: 390 | job_ids = submit_structure_prep_and_qm_jobs(args, dataset, cluster_options=args.cluster_options) 391 | 392 | elif args.action == "generate_pickle": 393 | for dataset in datasets_to_run: 394 | job_ids = submit_pickle_generation_jobs( 395 | dataset, 396 | cluster_options=args.cluster_options, 397 | output_basepath=args.output_basepath, 398 | qm_method=args.qm_method, 399 | ) 400 | 401 | elif args.action == "homogenize_pickles": 402 | # make sure that both NCP- and BCP-based pickle-files contain the same data 403 | for dataset in datasets_to_run: 404 | submit_homogenize_pickle_jobs( 405 | dataset, cluster_options=args.cluster_options, output_basepath=args.output_basepath 406 | ) 407 | 408 | elif args.action == "report": 409 | # do some reporting on the outcome of the data preparation 410 | # number of radicals, number of failed structures etc. 
411 | # (only xtb, no dftb+) 412 | for dataset, affinity_data in AFFINITY_DATA.items(): 413 | search_path = os.path.join(PROCESSED_DATA_PATH, "prepared_structures", dataset) 414 | report_data_processing_outcome(search_path, dataset, affinity_data) 415 | 416 | elif args.action == "all": 417 | for dataset in datasets_to_run: 418 | job_ids = submit_structure_prep_and_qm_jobs(args, dataset, cluster_options=args.cluster_options) 419 | 420 | cluster_options = ( 421 | f"--dependency=afterok:{job_ids}" if args.cluster_options is None else args.cluster_options 422 | ) # respect no_cluster 423 | job_ids = submit_pickle_generation_jobs(dataset, cluster_options=cluster_options) 424 | cluster_options = ( 425 | f"--dependency=afterok:{job_ids}" if args.cluster_options is None else args.cluster_options 426 | ) # respect no_cluster 427 | submit_homogenize_pickle_jobs(dataset, cluster_options=cluster_options) 428 | 429 | else: 430 | print("Unknown action!") 431 | -------------------------------------------------------------------------------- /bcpaff/data_processing/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | import shutil 7 | import tarfile 8 | import zipfile 9 | 10 | import requests 11 | from tqdm import tqdm 12 | 13 | from bcpaff.utils import DATA_PATH 14 | 15 | INDEX_URL = "http://www.pdbbind.org.cn/download/PDBbind_v2019_plain_text_index.tar.gz" 16 | 17 | PDBBIND_STUCTURES_URL = "http://bioinfo-pharma.u-strasbg.fr/labwebsite/downloads/pdbbind.tgz" # PDBbind processed by Volkov et al. (10.1021/acs.jmedchem.2c00487) 18 | PDE10A_STRUCTURES_URL = ( 19 | "https://figshare.com/ndownloader/files/37712256" # Tosstorff et al (10.1007/s10822-022-00478-x) 20 | ) 21 | PDE10A_AFFINITY_URL = "https://static-content.springer.com/esm/art%3A10.1007%2Fs10822-022-00478-x/MediaObjects/10822_2022_478_MOESM2_ESM.csv" 22 | 23 | 24 | def download(src: str, dest: str): 25 | """Simple requests.get with a progress bar 26 | Parameters 27 | ---------- 28 | src : str 29 | Remote path to be downloaded 30 | dest : str 31 | Local path for the download 32 | Returns 33 | ------- 34 | None 35 | """ 36 | r = requests.get(src, stream=True) 37 | tsize = int(r.headers.get("content-length", 0)) 38 | progress = tqdm(total=tsize, unit="iB", unit_scale=True, position=0, leave=False) 39 | 40 | with open(dest, "wb") as handle: 41 | progress.set_description(os.path.basename(dest)) 42 | for chunk in r.iter_content(chunk_size=1024): 43 | handle.write(chunk) 44 | progress.update(len(chunk)) 45 | 46 | 47 | def download_pdbbind_data(): 48 | """ 49 | Download PDBbind data as prepared by Volkov et al. 
50 | (10.1021/acs.jmedchem.2c00487) 51 | """ 52 | pdbbind_structure_path = os.path.join(DATA_PATH, "pdbbind") 53 | pdbbind_csv = os.path.join( 54 | pdbbind_structure_path, "PDBbind_v2019_plain_text_index/plain-text-index/index/INDEX_general_PL_data.2019" 55 | ) 56 | if not os.path.exists(pdbbind_csv): 57 | os.makedirs(pdbbind_structure_path, exist_ok=True) 58 | 59 | # download affinity data 60 | dest_archive = os.path.join(pdbbind_structure_path, os.path.basename(INDEX_URL)) 61 | dest_extract = dest_archive[: -len(".tar.gz")] 62 | download(INDEX_URL, dest_archive) 63 | with tarfile.open(dest_archive) as handle: 64 | handle.extractall(dest_extract) 65 | os.remove(dest_archive) 66 | 67 | # download structure data 68 | dest_archive = os.path.join(pdbbind_structure_path, os.path.basename(PDBBIND_STUCTURES_URL)) 69 | dest_extract = dest_archive[: -len(".tgz")] 70 | download(PDBBIND_STUCTURES_URL, dest_archive) 71 | print("Extracting, takes a few minutes...") 72 | with tarfile.open(dest_archive) as handle: 73 | handle.extractall(dest_extract) 74 | shutil.move(os.path.join(dest_extract, "dataset"), os.path.join(pdbbind_structure_path, "dataset")) 75 | shutil.move(os.path.join(dest_extract, "pdb_ids"), os.path.join(pdbbind_structure_path, "pdb_ids")) 76 | os.remove(dest_archive) 77 | shutil.rmtree(dest_extract) 78 | 79 | 80 | def download_pde10a_data(): 81 | """ 82 | Download PDE10A inhibitor data from Tosstorff et al. 83 | (10.1007/s10822-022-00478-x) 84 | """ 85 | pde10a_structure_path = os.path.join(DATA_PATH, "pde10a") 86 | pde10a_csv = os.path.join(pde10a_structure_path, "10822_2022_478_MOESM2_ESM.csv") 87 | if not os.path.exists(pde10a_csv): 88 | os.makedirs(pde10a_structure_path, exist_ok=True) 89 | 90 | download(PDE10A_AFFINITY_URL, pde10a_csv) 91 | 92 | dest_archive = os.path.join(pde10a_structure_path, "pde-10_pdb_bind_format_blinded.zip") 93 | download(PDE10A_STRUCTURES_URL, dest_archive) 94 | print("Extracting, takes a few minutes...") 95 | 96 | with zipfile.ZipFile(dest_archive, "r") as handle: 97 | handle.extractall(pde10a_structure_path) 98 | 99 | os.remove(dest_archive) 100 | if os.path.exists(os.path.join(pde10a_structure_path, "__MACOSX")): 101 | shutil.rmtree(os.path.join(pde10a_structure_path, "__MACOSX")) 102 | 103 | 104 | if __name__ == "__main__": 105 | download_pdbbind_data() 106 | download_pde10a_data() 107 | -------------------------------------------------------------------------------- /bcpaff/data_processing/filtering.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | 7 | import pandas as pd 8 | 9 | from bcpaff.utils import DATA_PATH 10 | 11 | other_binding_sites_txt = os.path.join(DATA_PATH, "pdbbind/verified_other_binding_sites.txt") 12 | OTHER_BINDING_SITES = pd.read_csv(other_binding_sites_txt, sep=";", header=4).pdb_id.tolist() 13 | 14 | 15 | ALLOSTERIC_BINDING_SITES = [] 16 | with open( 17 | os.path.join(DATA_PATH, "pdbbind/PDBbind_v2019_plain_text_index/plain-text-index/index/INDEX_general_PL.2019"), "r" 18 | ) as f: 19 | for line in f.readlines(): 20 | if "allosteric" in line.lower(): 21 | ALLOSTERIC_BINDING_SITES.append(line.split(" ")[0]) 22 | -------------------------------------------------------------------------------- /bcpaff/data_processing/structure_prep_and_qm.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | from rdkit import Chem 9 | 10 | 
from bcpaff.data_processing.manual_structure_prep import full_structure_prep 11 | from bcpaff.qtaim.critic2_tools import run_critic2_analysis 12 | from bcpaff.qtaim.multiwfn_tools import run_multiwfn_analysis 13 | from bcpaff.utils import DATA_PATH, PROCESSED_DATA_PATH 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("structure_id", type=str) 18 | parser.add_argument("--dataset", type=str, default="pdbbind") 19 | parser.add_argument("--cutoff", type=float, default=6) 20 | parser.add_argument("--qm_method", type=str, default="xtb") 21 | parser.add_argument("--num_cores", type=int, default=1) 22 | parser.add_argument("--basepath", type=str, default=None) 23 | parser.add_argument("--output_basepath", type=str, default=None) 24 | parser.add_argument("--implicit_solvent", type=str, default=None) 25 | parser.add_argument("--esp", action="store_true", default=False, dest="esp") 26 | parser.add_argument("--keep_wfn", action="store_true", default=False, dest="keep_wfn") 27 | 28 | args = parser.parse_args() 29 | if args.output_basepath is None: 30 | output_basepath = os.path.join(PROCESSED_DATA_PATH, "prepared_structures", args.dataset) 31 | else: 32 | output_basepath = args.output_basepath 33 | 34 | if args.basepath is None: 35 | if args.dataset == "pdbbind": 36 | basepath = os.path.join(DATA_PATH, args.dataset, "dataset") 37 | elif args.dataset == "pde10a": 38 | basepath = os.path.join(DATA_PATH, args.dataset, "pde-10_pdb_bind_format_blinded") 39 | elif args.dataset == "d2dr": 40 | basepath = os.path.join(DATA_PATH, args.dataset, "D2DR_complexes_prepared") 41 | else: 42 | basepath = args.basepath 43 | 44 | # 1) structure prep 45 | full_structure_prep( 46 | basepath, 47 | structure_id=args.structure_id, 48 | output_basepath=output_basepath, 49 | cutoff=args.cutoff, 50 | dataset=args.dataset, 51 | ) 52 | out_folder = os.path.join(output_basepath, args.structure_id) 53 | 54 | # 2) run xTB/Psi4 55 | if args.qm_method == "psi4": 56 | from bcpaff.qm.compute_wfn_psi4 import compute_wfn_psi4 57 | 58 | # different environments for psi4 and xTB, so need to import depending on use case 59 | 60 | psi4_input_pickle = os.path.join(out_folder, f"psi4_input.pkl") 61 | compute_wfn_psi4(psi4_input_pickle, num_cores=args.num_cores, memory=140) 62 | wfn_file = os.path.join(out_folder, "wfn.fchk") 63 | elif args.qm_method == "xtb": 64 | from bcpaff.qm.compute_wfn_xtb import compute_wfn_xtb 65 | 66 | xyz_path = os.path.join(out_folder, f"pl_complex.xyz") 67 | compute_wfn_xtb(xyz_path, args.implicit_solvent) 68 | wfn_file = os.path.join(out_folder, "molden.input") 69 | elif args.qm_method.startswith("dftb"): 70 | from bcpaff.qm.compute_wfn_dftb import compute_wfn_dftb 71 | 72 | xyz_path = os.path.join(out_folder, f"pl_complex.xyz") 73 | compute_wfn_dftb(xyz_path, args.qm_method, args.implicit_solvent) 74 | wfn_file = os.path.join(out_folder, "detailed.xml") 75 | else: 76 | raise ValueError("Invalid qm_method") 77 | 78 | # 3) run multiwfn/critic2 79 | ligand_sdf = os.path.join(out_folder, f"{args.structure_id}_ligand_with_hydrogens.sdf") 80 | if args.qm_method.startswith("dftb"): # run critic2 81 | output_cri = run_critic2_analysis(wfn_file) 82 | else: # run multiwfn 83 | cp_file, cpprop_file, paths_file = run_multiwfn_analysis( 84 | wfn_file=wfn_file, 85 | only_intermolecular=False, 86 | only_bcps=False, 87 | num_ligand_atoms=next(Chem.SDMolSupplier(ligand_sdf, removeHs=False, sanitize=False)).GetNumAtoms(), 88 | include_esp=args.esp, 89 | ) 90 | 91 | # 4) potential 
cleanup 92 | if not args.keep_wfn: 93 | os.remove(wfn_file) # to save space, those files get quite big 94 | -------------------------------------------------------------------------------- /bcpaff/ml/analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | from typing import Optional, Tuple 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | from matplotlib.lines import Line2D 12 | 13 | from bcpaff.ml import statsig 14 | from bcpaff.ml.generate_pickle import ESP_NAMES 15 | from bcpaff.ml.test import datasets_and_splits 16 | from bcpaff.utils import DATA_PATH, DEFAULT_PROPS, PROCESSED_DATA_PATH 17 | 18 | ANALYSIS_SAVEPATH = os.path.join(PROCESSED_DATA_PATH, "analysis") 19 | 20 | 21 | def get_mad_val(dataset: str, split_type: str) -> Tuple[float, float, float]: 22 | if dataset == "pdbbind": 23 | split_assignment_df = pd.read_csv(os.path.join(DATA_PATH, "pdbbind", "pdbbind2019_affinity.csv")) 24 | y_mad_val = split_assignment_df[split_assignment_df["split"] == "validation_set"].aff.values 25 | yhat_mad_val = split_assignment_df[split_assignment_df["split"] == "training_set"].aff.mean() * np.ones( 26 | y_mad_val.shape 27 | ) 28 | elif dataset == "pde10a": 29 | split_assignment_df = pd.read_csv(os.path.join(DATA_PATH, "pde10a", "10822_2022_478_MOESM2_ESM.csv")) 30 | split_type = split_type + "_split" 31 | y_mad_val = split_assignment_df[split_assignment_df[split_type] == "val"].pic50.values 32 | yhat_mad_val = split_assignment_df[split_assignment_df[split_type] == "train"].pic50.mean() * np.ones( 33 | y_mad_val.shape 34 | ) 35 | rmse_mad_val, le_mad_val, ue_mad_val = statsig.rmse(y_mad_val, yhat_mad_val) 36 | return rmse_mad_val, le_mad_val, ue_mad_val 37 | 38 | 39 | def plot_hparam_results( 40 | base_output_dir: str, 41 | dataset: str, 42 | split_type: str, 43 | savepath: Optional[str] = None, 44 | df: Optional[pd.DataFrame] = None, 45 | top_n: int = 50, 46 | ): 47 | if df is None: 48 | df = pd.read_csv(os.path.join(base_output_dir, "hparam_results.csv")) 49 | 50 | fig = plt.figure(figsize=(15, 7)) 51 | ax = fig.add_subplot(111) 52 | sub_df = df.sort_values(by="eval_rmse")[:top_n] 53 | colors = [] 54 | for _, row in sub_df.iterrows(): 55 | if not row.baseline_atom_ids: 56 | colors.append("white") 57 | else: 58 | if row.properties == "n" * (len(DEFAULT_PROPS) + len(ESP_NAMES)): 59 | colors.append("green") 60 | else: 61 | colors.append("orange") 62 | x = range(len(sub_df)) 63 | heights = sub_df.eval_rmse.tolist() 64 | ax.scatter(x, heights, color=colors, marker="D", s=100, zorder=5, edgecolor="black") 65 | errorbars = np.array([sub_df.rmse_le.tolist(), sub_df.rmse_ue.tolist()], dtype="float") 66 | ax.errorbar(x, heights, yerr=errorbars, linestyle="", color="black", zorder=4) 67 | ax.set_ylim(0.7, 1.8) 68 | ax.set_xlim(-1, len(sub_df)) 69 | ax.set_xlabel("Experiments ranked by validation set RMSE", fontsize=16) 70 | ax.set_ylabel("Validation set RMSE", fontsize=16) 71 | ax.set_xticks([]) 72 | ax.tick_params(axis="both", which="major", labelsize=14) 73 | custom_lines = [ 74 | Line2D([], [], color="green", lw=4, marker="D", markersize=10, markeredgecolor="black", linestyle="None"), 75 | Line2D([], [], color="white", lw=4, marker="D", markersize=10, markeredgecolor="black", linestyle="None"), 76 | Line2D([], [], color="orange", lw=4, marker="D", markersize=10, markeredgecolor="black", linestyle="None"), 77 | Line2D([0], [0], color="black", linestyle="--", lw=1), 
78 | ] 79 | labels = ["Atom-IDs", "BCPs", "Atom-IDs & \nBCPs", "MAD"] 80 | rmse_mad_val, le_mad_val, ue_mad_val = get_mad_val(dataset, split_type) 81 | ax.plot([min(x) - 10, max(x) + 10], [rmse_mad_val, rmse_mad_val], color="black", linestyle="--") 82 | ax.fill_between( 83 | [min(x) - 10, max(x) + 10], 84 | [rmse_mad_val - le_mad_val, rmse_mad_val - le_mad_val], 85 | [rmse_mad_val + ue_mad_val, rmse_mad_val + ue_mad_val], 86 | color="gray", 87 | alpha=0.2, 88 | ) 89 | ax.legend( 90 | custom_lines, 91 | labels, 92 | fontsize=16, 93 | ncol=1, 94 | bbox_to_anchor=(1, 0.5), 95 | bbox_transform=ax.transAxes, 96 | loc="center left", 97 | ) 98 | if savepath is not None: 99 | os.makedirs(os.path.dirname(savepath), exist_ok=True) 100 | fig.savefig(savepath, dpi=300, bbox_inches="tight") 101 | 102 | 103 | def plot_hparam_results_esp( 104 | base_output_dir: str, 105 | dataset: str, 106 | split_type: str, 107 | savepath: Optional[str] = None, 108 | df: Optional[pd.DataFrame] = None, 109 | top_n: int = 50, 110 | ): 111 | if df is None: 112 | df = pd.read_csv(os.path.join(base_output_dir, "hparam_results.csv")) 113 | 114 | fig = plt.figure(figsize=(15, 7)) 115 | ax = fig.add_subplot(111) 116 | sub_df = df.sort_values(by="eval_rmse")[:top_n] 117 | colors = [] 118 | h_esp = [] 119 | for _, row in sub_df.iterrows(): 120 | if not row.baseline_atom_ids: # only QM props 121 | colors.append("white") 122 | else: # baseline_atom_ids + QM props 123 | if row.properties == "n" * (len(DEFAULT_PROPS) + len(ESP_NAMES)): # no props 124 | colors.append("green") 125 | else: # some QM props 126 | colors.append("orange") 127 | 128 | h_esp.append(True) if "y" in row.properties[-3:] else h_esp.append(False) 129 | 130 | x = range(len(sub_df)) 131 | heights = sub_df.eval_rmse.tolist() 132 | for this_x, height, color, esp in zip(x, heights, colors, h_esp): 133 | ax.scatter(this_x, height, color=color, marker="D", s=100, zorder=5, edgecolor="black") 134 | if esp: 135 | ax.scatter(this_x, 1.75, color="black", marker="*", s=30) 136 | errorbars = np.array([sub_df.rmse_le.tolist(), sub_df.rmse_ue.tolist()], dtype="float") 137 | ax.errorbar(x, heights, yerr=errorbars, linestyle="", color="black", zorder=4) 138 | ax.set_ylim(0.7, 1.8) 139 | ax.set_xlim(-1, len(sub_df)) 140 | ax.set_xlabel("Experiments ranked by validation set RMSE", fontsize=16) 141 | ax.set_ylabel("Validation set RMSE", fontsize=16) 142 | ax.set_xticks([]) 143 | ax.tick_params(axis="both", which="major", labelsize=14) 144 | custom_lines = [ 145 | Line2D([], [], color="green", lw=4, marker="D", markersize=10, markeredgecolor="black", linestyle="None"), 146 | Line2D([], [], color="white", lw=4, marker="D", markersize=10, markeredgecolor="black", linestyle="None"), 147 | Line2D([], [], color="orange", lw=4, marker="D", markersize=10, markeredgecolor="black", linestyle="None"), 148 | Line2D([0], [0], color="black", linestyle="--", lw=1), 149 | ] 150 | labels = ["Atom-IDs", "BCPs", "Atom-IDs & \nBCPs", "MAD"] 151 | rmse_mad_val, le_mad_val, ue_mad_val = get_mad_val(dataset, split_type) 152 | ax.plot([min(x) - 10, max(x) + 10], [rmse_mad_val, rmse_mad_val], color="black", linestyle="--") 153 | ax.fill_between( 154 | [min(x) - 10, max(x) + 10], 155 | [rmse_mad_val - le_mad_val, rmse_mad_val - le_mad_val], 156 | [rmse_mad_val + ue_mad_val, rmse_mad_val + ue_mad_val], 157 | color="gray", 158 | alpha=0.2, 159 | ) 160 | ax.legend( 161 | custom_lines, 162 | labels, 163 | fontsize=16, 164 | ncol=1, 165 | bbox_to_anchor=(1, 0.5), 166 | bbox_transform=ax.transAxes, 167 | 
loc="center left", 168 | ) 169 | if savepath is not None: 170 | os.makedirs(os.path.dirname(savepath), exist_ok=True) 171 | fig.savefig(savepath, dpi=300, bbox_inches="tight") 172 | 173 | 174 | if __name__ == "__main__": 175 | # hparam visualization 176 | dataset = "pde10a" 177 | for ncp_bcp in ["bcp", "ncp"]: 178 | for split_type in datasets_and_splits[dataset]: 179 | base_output_dir = os.path.join(PROCESSED_DATA_PATH, "model_runs_esp", ncp_bcp, dataset, split_type) 180 | savepath = os.path.join( 181 | ANALYSIS_SAVEPATH, "hparam_visualization_esp", f"{dataset}_{split_type}_{ncp_bcp}.png" 182 | ) 183 | plot_hparam_results_esp( 184 | base_output_dir, dataset, split_type, savepath=savepath, df=None, top_n=99999 185 | ) # plot all 186 | -------------------------------------------------------------------------------- /bcpaff/ml/cluster_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | import subprocess 7 | from typing import List, Optional 8 | 9 | 10 | def generate_cmds_file( 11 | cmds_file_path: str, 12 | config_files: List[str], 13 | pickle_file: str, 14 | dataset: str = "pdbbind", 15 | split_type: str = "random", 16 | base_output_dir: Optional[str] = None, 17 | overwrite=False, 18 | hparam_file: Optional[str] = None, 19 | num_epochs: int = 300, 20 | ) -> int: 21 | lines = [] 22 | for config_file in config_files: 23 | python_cmd = "python -m bcpaff.ml.train" 24 | python_cmd += f" --pickle_file {pickle_file}" 25 | python_cmd += f" --config_file {config_file}" 26 | python_cmd += f" --dataset {dataset}" 27 | python_cmd += f" --split_type {split_type}" 28 | python_cmd += f" --num_epochs {num_epochs}" 29 | if base_output_dir is not None: 30 | python_cmd += f" --base_output_dir {base_output_dir}" 31 | if overwrite: 32 | python_cmd += " --overwrite" 33 | if hparam_file is not None: 34 | python_cmd += f" --hparam_file {hparam_file}" 35 | lines.append(python_cmd) 36 | with open(cmds_file_path, "w") as f: 37 | f.write("\n".join(lines)) 38 | return len(lines) 39 | 40 | 41 | def generate_bash_script( 42 | script_path: str, 43 | cmds_file_path: str, 44 | time: str = "04:00:00", 45 | num_cores: int = 4, 46 | memory: int = 16, 47 | ): 48 | out_files_folder = os.path.join(os.path.dirname(script_path), "out_files") 49 | os.makedirs(out_files_folder, exist_ok=True) 50 | with open(script_path, "w") as f: 51 | f.write( 52 | f"""#!/bin/bash 53 | 54 | #SBATCH --job-name=bcp_model_run # Job name 55 | #SBATCH -n {num_cores} # Number of CPU cores 56 | #SBATCH --mem-per-cpu={int(memory/num_cores*1024)} # Memory per CPU in MB 57 | #SBATCH --gpus=1 # Number of GPUs 58 | #SBATCH --time={time} # Maximum execution time (HH:MM:SS) 59 | #SBATCH --tmp=4000 # Total scratch for job in MB 60 | #SBATCH --output {os.path.join(out_files_folder, "out_%A_%a.out")} # Standard output 61 | #SBATCH --error {os.path.join(out_files_folder, "out_%A_%a.out")} # Standard error 62 | 63 | 64 | source ~/.bashrc; 65 | source activate bcpaff; 66 | export CMDS_FILE_PATH={cmds_file_path} 67 | export cmd=$(head -$SLURM_ARRAY_TASK_ID $CMDS_FILE_PATH|tail -1) 68 | echo "=========SLURM_COMMAND=========" 69 | echo $cmd 70 | echo "=========SLURM_COMMAND=========" 71 | eval $cmd 72 | """ 73 | ) 74 | 75 | 76 | def submit_job(script_path: str, num_lines: int = None, cluster_options: Optional[str] = None): 77 | job_id = None 78 | if cluster_options == "no_cluster": 79 | if num_lines is not None: # job array --> run iteratively 80 | env = os.environ 81 | for 
i in range(num_lines): 82 | env["SLURM_ARRAY_TASK_ID"] = f"{i + 1}" # not zero-indexed 83 | completed_process = subprocess.run(f"bash {script_path}", shell=True, env=env) 84 | else: # no job array --> can run directly 85 | completed_process = subprocess.run(f"bash {script_path}", shell=True, env=env) 86 | else: 87 | if num_lines is not None: 88 | if cluster_options is None: 89 | cmd_str = f"sbatch --parsable --array=1-{num_lines} < {script_path}" 90 | else: 91 | cmd_str = ( 92 | f"sbatch --dependency=after:{cluster_options}:+5 --parsable --array=1-{num_lines} < {script_path}" 93 | ) 94 | completed_process = subprocess.run(cmd_str, shell=True, stdout=subprocess.PIPE, universal_newlines=True) 95 | else: 96 | if cluster_options is None: 97 | cmd_str = f"sbatch --parsable < {script_path}" 98 | else: 99 | cmd_str = f"sbatch --dependency=after:{cluster_options}:+5 --parsable < {script_path}" 100 | completed_process = subprocess.run(cmd_str, shell=True, stdout=subprocess.PIPE, universal_newlines=True) 101 | job_id = completed_process.stdout.rstrip("\n") 102 | if completed_process.returncode != 0: 103 | print(completed_process.returncode) 104 | return job_id 105 | 106 | 107 | if __name__ == "__main__": 108 | submit_job(pickle_file="test", config_file="test2", no_cluster=True) 109 | -------------------------------------------------------------------------------- /bcpaff/ml/generate_pickle.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import glob 7 | import os 8 | import pickle 9 | from typing import Dict, Union 10 | 11 | import networkx as nx 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | from joblib import Parallel, delayed 16 | from rdkit import Chem 17 | from scipy.spatial.distance import cdist 18 | from tqdm import tqdm 19 | 20 | from bcpaff.qtaim.qtaim_reader import QtaimProps, QtaimPropsCritic2 21 | from bcpaff.utils import ATOM_NEIGHBOR_IDS, DATA_PATH, DEFAULT_PROPS, OTHER 22 | 23 | MAX_INTERACTION_DISTANCE = 3 # Angstrom 24 | 25 | COL_NAMES = { 26 | "pdbbind": {"id_col": "pdb_id", "aff_col": "aff"}, 27 | "pde10a": {"id_col": "docking_folder", "aff_col": "pic50"}, 28 | } 29 | NULL_COORDS = torch.FloatTensor([0.0] * 3) 30 | ESP_NAMES = ["esp", "esp_nuc", "esp_ele"] 31 | 32 | 33 | def get_structure_id(folder: str, dataset: str) -> int: 34 | return os.path.basename(folder) 35 | # if dataset == "pdbbind" or data: 36 | # elif dataset == "pde10a": 37 | # return int(os.path.basename(folder).split("_")[-1]) 38 | 39 | 40 | def get_bcp_graph_from_cp_prop_file(df: pd.DataFrame, cpprop_file: str, dataset: str, **kwargs) -> Union[Dict, None]: 41 | """Same as get_graph_from_cp_prop_file_intermolecular, but for BCP-based graphs. 
Uses only intermolecular BCPs 42 | 43 | Parameters 44 | ---------- 45 | df : pd.DataFrame 46 | DataFrame with critical-point information 47 | cpprop_file : str 48 | path to CPprop.txt file from Multiwfn output 49 | dataset : str 50 | type of dataset, either "pdbbind" or "pde10a" 51 | only_key_interactions: bool 52 | only for PDE10A dataset, use only interactions with Y693 & Q726 53 | only_short_interactions: bool 54 | use only interactions (BCPs) where the corresponding atoms are <= MAX_INTERACTION_DISTANCE apart 55 | no_c_c_interactions: bool 56 | remove BCPs between two carbon atoms 57 | 58 | Raises 59 | ------ 60 | ValueError 61 | in case of strange assigments of atom neighbors 62 | """ 63 | folder = os.path.dirname(cpprop_file) 64 | structure_id = get_structure_id(folder, dataset) 65 | docking_folder = os.path.basename(folder) 66 | try: 67 | ligand_sdf = os.path.join(folder, f"{os.path.basename(folder)}_ligand_with_hydrogens.sdf") 68 | pl_complex_xyz = os.path.join(folder, "pl_complex.xyz") 69 | cp_file = os.path.join(folder, "CPs.txt") 70 | paths_file = os.path.join(folder, "paths.txt") 71 | qtaim_props = QtaimProps( 72 | cp_file=cp_file, 73 | cpprop_file=cpprop_file, 74 | paths_file=paths_file, 75 | ligand_sdf=ligand_sdf, 76 | pl_complex_xyz=pl_complex_xyz, 77 | identifier=structure_id, 78 | ) 79 | include_esp = "esp" in qtaim_props.critical_points[0].props 80 | prop_list = DEFAULT_PROPS + ESP_NAMES if include_esp else DEFAULT_PROPS 81 | 82 | target = df[df[COL_NAMES[dataset]["id_col"]] == structure_id][COL_NAMES[dataset]["aff_col"]].values[0] 83 | 84 | G = nx.Graph(target=target, include_esp=include_esp) 85 | 86 | bcps = [cp for cp in qtaim_props.critical_points if cp.name == "bond_critical_point" and cp.intermolecular] 87 | 88 | if kwargs["only_key_interactions"]: # remove bcps that are not key interactions 89 | if dataset != "pde10a": 90 | raise ValueError(f"Choice only_key_interactions is not available for {dataset}") 91 | 92 | pdb_file = os.path.join( 93 | DATA_PATH, "pde10a", "pde-10_pdb_bind_format_blinded", docking_folder, f"{docking_folder}_protein.pdb" 94 | ) 95 | protein = Chem.rdmolfiles.MolFromPDBFile(pdb_file, sanitize=False) 96 | all_protein_positions = protein.GetConformer().GetPositions() 97 | protein_positions = [] 98 | # find contact atom in the key protein residues, e.g. 
693 --> take position 99 | for a in protein.GetAtoms(): 100 | if a.GetPDBResidueInfo().GetResidueNumber() == 693 or a.GetPDBResidueInfo().GetResidueNumber() == 726: 101 | protein_positions.append(all_protein_positions[a.GetIdx()]) 102 | 103 | # using this position, find the corresponding atom in qtaim_props 104 | distance_matrix = cdist(np.vstack(protein_positions), qtaim_props.pl_complex_coords) 105 | pl_complex_idxs = distance_matrix.argmin(axis=1)[ 106 | distance_matrix.min() < 0.01 107 | ] # not all atoms of those residues might have been included (cutting between backbone and sidechain) 108 | pl_complex_idxs = set(pl_complex_idxs.flatten().tolist()) 109 | 110 | # get the corresponding bcp and verify that it is connected to the correct ligand atom 111 | bcps = [bcp for bcp in bcps if len(pl_complex_idxs.intersection(set(bcp.atom_neighbors)))] 112 | 113 | # loop over the intermolecular BCPs and pull out their neighboring NCPs 114 | for bcp in bcps: 115 | path_length = np.linalg.norm(bcp.path_positions[0] - bcp.path_positions[-1]) 116 | if kwargs["only_short_interactions"] and path_length > MAX_INTERACTION_DISTANCE: 117 | continue # skip 118 | 119 | if len(bcp.atom_neighbors) > 2: 120 | raise ValueError(f"More than two neighbors: {structure_id}") 121 | # neighbors = [ncp for ncp in ncps if ncp.corresponding_atom_id in bcp.atom_neighbors] 122 | atom_neighbors_symbol = bcp.atom_neighbors_symbol 123 | if len(atom_neighbors_symbol) == 2: # BCP with two neighbors 124 | if kwargs["no_c_c_interactions"] and atom_neighbors_symbol == ["C", "C"]: 125 | continue # skip this interaction 126 | elif len(atom_neighbors_symbol) == 1: # BCP where only one neighbor was found 127 | atom_neighbors_symbol = atom_neighbors_symbol + ["*"] # append dummy 128 | elif len(atom_neighbors_symbol) == 0: # NCP 129 | atom_neighbors_symbol = ["*", "*"] # append 2x dummy 130 | atom_neighbors_type_id = [ATOM_NEIGHBOR_IDS.get(n, OTHER) for n in atom_neighbors_symbol] 131 | 132 | props = [bcp.props[prop] for prop in prop_list] 133 | coords = bcp.position 134 | G.add_node( 135 | bcp.idx, 136 | node_props=torch.FloatTensor(props), 137 | atom_type_id=atom_neighbors_type_id, 138 | node_coords=torch.FloatTensor(coords), 139 | ) 140 | return {structure_id: G} 141 | except KeyboardInterrupt: 142 | raise ValueError 143 | except: 144 | print(f"\n{docking_folder}\n", flush=True) 145 | return None 146 | 147 | 148 | def get_graph_from_cp_prop_file_intramolecular( 149 | df: pd.DataFrame, cpprop_file: str, dataset: str, **kwargs 150 | ) -> Union[Dict, None]: 151 | "kwargs unused, just for compatibility with get_bcp_graph_from_cp_prop_file" 152 | folder = os.path.dirname(cpprop_file) 153 | structure_id = get_structure_id(folder, dataset) 154 | try: 155 | ligand_sdf = os.path.join(folder, f"{os.path.basename(folder)}_ligand_with_hydrogens.sdf") 156 | pl_complex_xyz = os.path.join(folder, "pl_complex.xyz") 157 | cp_file = os.path.join(folder, "CPs.txt") 158 | paths_file = os.path.join(folder, "paths.txt") 159 | qtaim_props = QtaimProps( 160 | cp_file=cp_file, 161 | cpprop_file=cpprop_file, 162 | paths_file=paths_file, 163 | ligand_sdf=ligand_sdf, 164 | pl_complex_xyz=pl_complex_xyz, 165 | identifier=structure_id, 166 | ) 167 | include_esp = "esp" in qtaim_props.critical_points[0].props 168 | prop_list = DEFAULT_PROPS + ESP_NAMES if include_esp else DEFAULT_PROPS 169 | null_props = torch.FloatTensor([0.0] * len(prop_list)) 170 | target = df[df[COL_NAMES[dataset]["id_col"]] == structure_id][COL_NAMES[dataset]["aff_col"]].values[0] 171 | 172 | 
G = nx.Graph(target=target) 173 | 174 | bcps = [cp for cp in qtaim_props.critical_points if cp.name == "bond_critical_point"] 175 | ncps = [cp for cp in qtaim_props.critical_points if cp.name == "nucleus_critical_point"] 176 | atom_id_to_ncp = {ncp.corresponding_atom_id: ncp for ncp in ncps} 177 | 178 | # loop over the intermolecular BCPs and pull out their neighboring NCPs 179 | for bcp in bcps: 180 | if len(bcp.atom_neighbors) < 2: 181 | continue # incomplete path 182 | elif len(bcp.atom_neighbors) > 2: 183 | raise ValueError(f"More than two neighbors: {structure_id}") 184 | if not all([neighbor_id in atom_id_to_ncp for neighbor_id in bcp.atom_neighbors]): 185 | continue 186 | # in rare cases, NCPs do not map 1:1 to atoms, e.g. because two atoms are too close together 187 | # --> not included in graph construction; example: PDE10A, 5sfj_1153 188 | ncp_neighbors = atom_id_to_ncp[bcp.atom_neighbors[0]], atom_id_to_ncp[bcp.atom_neighbors[1]] 189 | intra_ligand = all([idx < qtaim_props.natoms_ligand for idx in bcp.atom_neighbors]) 190 | if not (bcp.intermolecular or intra_ligand): 191 | continue # only saving intermolecular BCPs and those for BCPs within the ligand (covalent & non-covalent) 192 | for ncp in ncp_neighbors: 193 | props = [ncp.props[prop] for prop in prop_list] 194 | atom_type_id = ATOM_NEIGHBOR_IDS.get(ncp.corresponding_atom_symbol, OTHER) 195 | coords = ncp.position 196 | is_ligand = ncp.corresponding_atom_id < qtaim_props.natoms_ligand 197 | G.add_node( 198 | ncp.idx, 199 | node_props=torch.FloatTensor(props), 200 | atom_type_id=atom_type_id, 201 | node_coords=torch.FloatTensor(coords), 202 | is_ligand=is_ligand, 203 | ) 204 | 205 | props = [bcp.props[prop] for prop in prop_list] 206 | coords = bcp.position 207 | distance = np.linalg.norm(ncp_neighbors[0].position - ncp_neighbors[1].position) 208 | G.add_edge( 209 | *(ncp.idx for ncp in ncp_neighbors), 210 | edge_props=torch.FloatTensor(props), 211 | edge_coords=torch.FloatTensor(coords), 212 | distance=distance, 213 | ) 214 | 215 | # add intra-ligand edges 216 | coords = nx.get_node_attributes(G, "node_coords") 217 | for n1 in G.nodes: 218 | for n2 in G.nodes: 219 | if G.nodes[n1]["is_ligand"] and G.nodes[n2]["is_ligand"]: 220 | if G.get_edge_data(n1, n2) is None: # don't overwrite existing ligand edges characterized by BCPs 221 | distance = np.linalg.norm(coords[n1] - coords[n2]) 222 | if distance > 0: # no self-loops 223 | G.add_edge( 224 | n1, 225 | n2, 226 | edge_props=null_props, 227 | edge_coords=NULL_COORDS, 228 | distance=distance, 229 | ) 230 | # need same edge attributes for all edges so we can convert from networkx to torch_geometric 231 | return {structure_id: G} 232 | except: 233 | print(f"\n{structure_id}\n", flush=True) 234 | return None 235 | 236 | 237 | def get_graph_from_cp_prop_file_intramolecular_critic2(*args, **kwargs): 238 | raise NotImplementedError 239 | 240 | 241 | def get_bcp_graph_from_cp_prop_file_critic2( 242 | df: pd.DataFrame, output_cri: str, dataset: str, **kwargs 243 | ) -> Union[Dict, None]: 244 | """Same as get_graph_from_cp_prop_file_intermolecular, but for BCP-based graphs. Uses only intermolecular BCPs and critic2. 245 | kwargs unused, just for compatibility. 
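    Each intermolecular BCP becomes a graph node whose features are the critic2 properties (electron density, Laplacian, gradient norm) plus the element types of its two neighbouring atoms.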
246 | 247 | Parameters 248 | ---------- 249 | df : pd.DataFrame 250 | DataFrame with critical-point information 251 | output_cri : str 252 | path to output.cri file from critic2 output 253 | dataset : str 254 | type of dataset, either "pdbbind" or "pde10a" 255 | 256 | Raises 257 | ------ 258 | ValueError 259 | in case of strange assigments of atom neighbors 260 | """ 261 | try: 262 | folder = os.path.dirname(output_cri) 263 | structure_id = get_structure_id(folder, dataset) 264 | ligand_sdf = os.path.join(folder, f"{os.path.basename(folder)}_ligand_with_hydrogens.sdf") 265 | pl_complex_xyz = os.path.join(folder, "pl_complex.xyz") 266 | qtaim_props = QtaimPropsCritic2( 267 | basepath=folder, output_cri=output_cri, pl_complex_xyz=pl_complex_xyz, ligand_sdf=ligand_sdf 268 | ) 269 | prop_list = ["density", "laplacian", "gradient_norm"] # critic2 270 | 271 | target = df[df[COL_NAMES[dataset]["id_col"]] == structure_id][COL_NAMES[dataset]["aff_col"]].values[0] 272 | 273 | G = nx.Graph(target=target, include_esp=False) # no esp in critic2? 274 | 275 | bcps = [cp for cp in qtaim_props.critical_points if cp.name == "bond_critical_point" and cp.intermolecular] 276 | 277 | # loop over the intermolecular BCPs and pull out their neighboring NCPs 278 | for bcp in bcps: 279 | if len(bcp.atom_neighbors) > 2: 280 | raise ValueError(f"More than two neighbors: {structure_id}") 281 | # neighbors = [ncp for ncp in ncps if ncp.corresponding_atom_id in bcp.atom_neighbors] 282 | atom_neighbors_symbol = bcp.atom_neighbors_symbol 283 | if len(atom_neighbors_symbol) == 2: # BCP with two neighbors 284 | pass 285 | elif len(atom_neighbors_symbol) == 1: # BCP where only one neighbor was found 286 | atom_neighbors_symbol = atom_neighbors_symbol + ["*"] # append dummy 287 | elif len(atom_neighbors_symbol) == 0: # NCP 288 | atom_neighbors_symbol = ["*", "*"] # append 2x dummy 289 | atom_neighbors_type_id = [ATOM_NEIGHBOR_IDS.get(n, OTHER) for n in atom_neighbors_symbol] 290 | 291 | props = [bcp.props[prop] for prop in prop_list] 292 | coords = bcp.position 293 | G.add_node( 294 | bcp.idx, 295 | node_props=torch.FloatTensor(props), 296 | atom_type_id=atom_neighbors_type_id, 297 | node_coords=torch.FloatTensor(coords), 298 | ) 299 | return {structure_id: G} 300 | except: 301 | return None 302 | 303 | 304 | def generate_pickle( 305 | search_path: str, 306 | affinity_data: str, 307 | save_path: str, 308 | qm_method: str = "xtb", 309 | dataset: str = "pdbbind", 310 | nucleus_critical_points=False, 311 | **kwargs, 312 | ): 313 | if qm_method == "xtb" or qm_method == "psi4": 314 | graph_fun = ( 315 | get_graph_from_cp_prop_file_intramolecular if nucleus_critical_points else get_bcp_graph_from_cp_prop_file 316 | ) 317 | cpprop_files = sorted(glob.glob(os.path.join(search_path, "*", "CPprop.txt"))) 318 | elif qm_method.startswith("dftb"): 319 | graph_fun = ( 320 | get_graph_from_cp_prop_file_intramolecular_critic2 321 | if nucleus_critical_points 322 | else get_bcp_graph_from_cp_prop_file_critic2 323 | ) 324 | cpprop_files = sorted(glob.glob(os.path.join(search_path, "*", "output.cri"))) 325 | else: 326 | raise ValueError("Unknown qm_method") 327 | 328 | df = pd.read_csv(affinity_data) 329 | 330 | res = {} 331 | 332 | res_par = Parallel(n_jobs=-1)( 333 | delayed(graph_fun)(df, cpprop_file, dataset, **kwargs) for cpprop_file in tqdm(cpprop_files) 334 | ) 335 | # res_par = [] 336 | # for cpprop_file in tqdm(cpprop_files): 337 | # res_par.append(graph_fun(df, cpprop_file, dataset, **kwargs)) 338 | 339 | for r in res_par: 340 | if r 
is not None: 341 | res.update(r) 342 | 343 | with open(save_path, "wb") as f: 344 | pickle.dump(res, f, protocol=pickle.HIGHEST_PROTOCOL) 345 | 346 | print(f"Saved pickle to {args.save_path}") 347 | 348 | 349 | if __name__ == "__main__": 350 | parser = argparse.ArgumentParser() 351 | parser.add_argument("--search_path", type=str, required=True) 352 | parser.add_argument("--affinity_data", type=str, required=True) 353 | parser.add_argument("--save_path", type=str, required=True) 354 | parser.add_argument("--only_key_interactions", action="store_true", dest="only_key_interactions", default=False) 355 | parser.add_argument( 356 | "--only_short_interactions", action="store_true", dest="only_short_interactions", default=False 357 | ) 358 | parser.add_argument("--no_c_c_interactions", action="store_true", dest="no_c_c_interactions", default=False) 359 | parser.add_argument("--dataset", type=str, default="pdbbind") 360 | parser.add_argument( 361 | "--nucleus_critical_points", action="store_true", dest="nucleus_critical_points", default=False 362 | ) 363 | parser.add_argument("--qm_method", type=str, default="xtb") 364 | 365 | args = parser.parse_args() 366 | 367 | generate_pickle( 368 | args.search_path, 369 | args.affinity_data, 370 | args.save_path, 371 | qm_method=args.qm_method, 372 | dataset=args.dataset, 373 | nucleus_critical_points=args.nucleus_critical_points, 374 | only_key_interactions=args.only_key_interactions, 375 | only_short_interactions=args.only_short_interactions, 376 | no_c_c_interactions=args.no_c_c_interactions, 377 | ) 378 | -------------------------------------------------------------------------------- /bcpaff/ml/hparam_screen.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import glob 7 | import json 8 | import os 9 | from typing import List, Optional 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from tqdm import tqdm 15 | 16 | from bcpaff.ml import statsig 17 | from bcpaff.ml.cluster_tools import generate_bash_script, generate_cmds_file, submit_job 18 | from bcpaff.ml.ml_utils import run_id_to_hparams 19 | from bcpaff.ml.test import HPARAM_KEYS 20 | from bcpaff.utils import HPARAMS, ROOT_PATH 21 | 22 | 23 | def generate_all_config_files(base_output_dir: str, hparam_file: Optional[str] = None) -> List[str]: 24 | config_files = [] 25 | if hparam_file is not None: 26 | hparam_df = pd.read_csv(hparam_file) 27 | else: 28 | hparam_df = HPARAMS 29 | for _, row in hparam_df.iterrows(): 30 | hparams = row[HPARAM_KEYS].to_dict() 31 | config_file = os.path.join(base_output_dir, row.run_id, "hparams.json") 32 | dirname = os.path.dirname(config_file) 33 | checkpoint_path = os.path.join(dirname, "checkpoint.pt") 34 | if os.path.exists(checkpoint_path): # already trained 35 | continue 36 | os.makedirs(dirname, exist_ok=True) 37 | with open(config_file, "w") as f: 38 | json.dump(hparams, f) 39 | config_files.append(config_file) 40 | return config_files 41 | 42 | 43 | def run_hparam_screen( 44 | pickle_file: str, 45 | hparam_file: str, 46 | base_output_dir: Optional[str] = None, 47 | dataset: str = "pdbbind", 48 | split_type: str = "random", 49 | overwrite: bool = False, 50 | cluster_options: Optional[str] = None, 51 | num_epochs: int = 300, 52 | ) -> Optional[int]: 53 | config_files = generate_all_config_files(base_output_dir=base_output_dir, hparam_file=hparam_file) 54 | cmds_file_path = os.path.join(base_output_dir, "slurm_files", "cmds.txt") 55 | 
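# (Illustrative sketch; the concrete values below are placeholders.)
# generate_all_config_files() above writes one hparams.json per run_id,
# containing exactly the HPARAM_KEYS imported above from bcpaff.ml.test.
# Such a file could be produced by, e.g.:
#
#   import json
#   demo_hparams = {"batch_size": 32, "kernel_dim": 64, "mlp_dim": 256,
#                   "cutoff": 6, "baseline_atom_ids": False, "aggr": "add",
#                   "pool": "mean", "properties": "yyyyyyyyy",
#                   "n_kernels": 3, "ncp_graph": False}
#   with open("hparams.json", "w") as f:
#       json.dump(demo_hparams, f)
#
# Each command written to cmds.txt below then points one training job at one
# of these config files.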
script_path = os.path.join(base_output_dir, "slurm_files", "jobscript.sh") 56 | os.makedirs(os.path.dirname(cmds_file_path), exist_ok=True) 57 | num_lines = generate_cmds_file( 58 | cmds_file_path, 59 | config_files, 60 | pickle_file, 61 | dataset=dataset, 62 | split_type=split_type, 63 | base_output_dir=base_output_dir, 64 | overwrite=overwrite, 65 | hparam_file=hparam_file, 66 | num_epochs=num_epochs, 67 | ) 68 | time = "04:00:00" if dataset == "pde10a" else "24:00:00" 69 | generate_bash_script(script_path, cmds_file_path, time=time, num_cores=4, memory=64) 70 | 71 | job_id = submit_job(script_path, num_lines=num_lines, cluster_options=cluster_options) 72 | return job_id 73 | # a = 2 74 | 75 | # for config_file in config_files: 76 | # dirname = os.path.dirname(config_file) 77 | # checkpoint_path = os.path.join(dirname, "checkpoint.pt") 78 | # if os.path.exists(checkpoint_path): # already trained 79 | # continue 80 | # else: 81 | # script_path = os.path.join(dirname, "train.sh") 82 | # generate_bash_script( 83 | # script_path, 84 | # pickle_file, 85 | # config_file, 86 | # dataset=dataset, 87 | # split_type=split_type, 88 | # num_cores=4, 89 | # memory=16, 90 | # base_output_dir=base_output_dir, 91 | # overwrite=overwrite, 92 | # ) 93 | # submit_job(script_path, no_cluster=no_cluster) 94 | 95 | # submit_job( 96 | # pickle_file, 97 | # config_file, 98 | # dataset=dataset, 99 | # split_type=split_type, 100 | # base_output_dir=base_output_dir, 101 | # overwrite=False, 102 | # num_cores=4, 103 | # memory=16, 104 | # no_cluster=no_cluster, 105 | # ) 106 | 107 | 108 | def collect_results(basepath, force_recompute=True, last_epoch=False, hparam_file=None): 109 | if last_epoch: 110 | hparam_results_csv = os.path.join(basepath, "hparam_results_last_epoch.csv") 111 | else: 112 | hparam_results_csv = os.path.join(basepath, "hparam_results.csv") 113 | have_existing = False 114 | if os.path.exists(hparam_results_csv) and not force_recompute: 115 | df = pd.read_csv(hparam_results_csv) 116 | have_existing = True 117 | 118 | folders = sorted(glob.glob(os.path.join(basepath, "run_*"))) 119 | keys = ["eval_mae", "eval_rmse", "train_mae", "train_rmse"] 120 | all_res = [] 121 | for folder in tqdm(folders): 122 | run_id = os.path.basename(folder) 123 | if last_epoch: 124 | checkpoint_file = os.path.join(folder, "last_epoch_checkpoint.pt") 125 | else: 126 | checkpoint_file = os.path.join(folder, "checkpoint.pt") 127 | if not os.path.exists(checkpoint_file): 128 | continue 129 | with open(os.path.join(folder, "hparams.json"), "r") as f: 130 | hparams = json.load(f) 131 | try: 132 | assert run_id_to_hparams(run_id, hparam_file=hparam_file) == hparams # sanity check 133 | except: 134 | print(run_id) 135 | if have_existing and len(df[df.run_id == run_id]): 136 | continue # already in the df 137 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 138 | res = {"run_id": run_id} 139 | res.update({k: checkpoint[k] for k in keys}) 140 | rmse, le, ue = statsig.rmse(checkpoint["y_eval"], checkpoint["yhat_eval"]) 141 | res["rmse_le"] = le 142 | res["rmse_ue"] = ue 143 | assert np.isclose(rmse, res["eval_rmse"]) 144 | hparams.update(res) 145 | all_res.append(hparams) 146 | df_update = pd.DataFrame(all_res) 147 | df = pd.concat([df, df_update]) if have_existing else df_update 148 | if len(df) == 0: 149 | print(f"No completed experiments, couldn't write results from hparam screen", flush=True) 150 | else: 151 | df = df[["run_id"] + [col for col in df.columns if col != "run_id"]] # move run_id column to front 152 | 
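# (Illustrative note.) The reindexing above is the usual pandas idiom for
# moving a single column to the front while keeping all others, e.g.:
#
#   import pandas as pd
#   demo = pd.DataFrame({"eval_rmse": [1.2], "run_id": ["run_0"], "eval_mae": [0.9]})
#   demo = demo[["run_id"] + [c for c in demo.columns if c != "run_id"]]
#   list(demo.columns)  # ['run_id', 'eval_rmse', 'eval_mae']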
df.to_csv(hparam_results_csv, index=False) 153 | print(f"Wrote results from hparam screen to {hparam_results_csv}", flush=True) 154 | return df 155 | 156 | 157 | if __name__ == "__main__": 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument( 160 | "--pickle_file", 161 | type=str, 162 | ) 163 | parser.add_argument("--dataset", type=str, default="pdbbind") 164 | parser.add_argument("--split_type", type=str, default="random") 165 | parser.add_argument("--base_output_dir", type=str, required=True) 166 | parser.add_argument("--collect_results", action="store_true", default=False, dest="collect_results") 167 | parser.add_argument("--force_recompute", action="store_true", default=False, dest="force_recompute") 168 | parser.add_argument("--cluster_options", type=str, default=None) 169 | parser.add_argument("--hparam_file", type=str, default=os.path.join(ROOT_PATH, "new_hparams.csv")) 170 | parser.add_argument("--last_epoch", action="store_true", default=False, dest="last_epoch") 171 | parser.add_argument("--scaler_name", type=str, default="qtaim_scaler") 172 | parser.add_argument("--num_epochs", type=int, default=1000) 173 | args = parser.parse_args() 174 | 175 | if args.collect_results: 176 | df = collect_results( 177 | args.base_output_dir, 178 | force_recompute=args.force_recompute, 179 | last_epoch=args.last_epoch, 180 | hparam_file=args.hparam_file, 181 | ) 182 | else: 183 | job_id = run_hparam_screen( 184 | args.pickle_file, 185 | hparam_file=args.hparam_file, 186 | base_output_dir=args.base_output_dir, 187 | dataset=args.dataset, 188 | split_type=args.split_type, 189 | overwrite=False, 190 | cluster_options=args.cluster_options, 191 | num_epochs=args.num_epochs, 192 | ) 193 | if job_id is not None: 194 | print(job_id) 195 | -------------------------------------------------------------------------------- /bcpaff/ml/ml_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | from typing import List, Optional 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from torch_geometric.loader import DataLoader 13 | 14 | from bcpaff.ml.net import EGNN, EGNN_NCP 15 | from bcpaff.ml.net_utils import QtaimDataBCP, QtaimDataNCP, QtaimScaler 16 | from bcpaff.utils import BASE_OUTPUT_DIR, HPARAMS 17 | 18 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | HPARAM_KEYS = [ 20 | "batch_size", 21 | "kernel_dim", 22 | "mlp_dim", 23 | "cutoff", 24 | "baseline_atom_ids", 25 | "aggr", 26 | "pool", 27 | "properties", 28 | "n_kernels", 29 | "ncp_graph", 30 | ] 31 | 32 | 33 | def run_id_to_hparams(run_id: str, hparam_file: Optional[str] = None): 34 | if hparam_file is not None: 35 | df = pd.read_csv(hparam_file) 36 | else: 37 | df = HPARAMS 38 | sub_df = df[df.run_id == run_id][HPARAM_KEYS] 39 | if len(sub_df) == 0: 40 | return None 41 | elif len(sub_df) == 1: 42 | return sub_df.iloc[0].to_dict() 43 | elif len(sub_df) > 1: 44 | raise ValueError(f">1 entry for {run_id}") 45 | 46 | 47 | def hparams_to_run_id(hparams: dict, hparam_file: Optional[str] = None): 48 | if hparam_file is not None: 49 | df = pd.read_csv(hparam_file) 50 | else: 51 | df = HPARAMS 52 | sub_df = df.loc[(df[list(hparams)] == pd.Series(hparams)).all(axis=1)] 53 | if len(sub_df) == 0: 54 | return None 55 | elif len(sub_df) == 1: 56 | return sub_df.run_id.values[0] 57 | elif len(sub_df) > 1: 58 | raise ValueError(f">1 entry for {hparams}") 59 | 60 | 61 | def 
get_output_dir(hparams: dict, base_output_dir: Optional[str] = None): 62 | if base_output_dir is None: 63 | base_output_dir = BASE_OUTPUT_DIR 64 | output_dir = os.path.join(base_output_dir, "_".join([f"{key}{val}" for key, val in sorted(hparams.items())])) 65 | return output_dir 66 | 67 | 68 | def save_checkpoint( 69 | model: EGNN, optimizer: torch.optim.Adam, epoch: int, savepath: str, datapoints: Optional[str] = None 70 | ): 71 | checkpoint = { 72 | "epoch": epoch, 73 | "model_state": model.state_dict(), 74 | "optim_state": optimizer.state_dict(), 75 | } 76 | if datapoints is not None: 77 | checkpoint.update(datapoints) 78 | os.makedirs(os.path.dirname(savepath), exist_ok=True) 79 | torch.save(checkpoint, savepath) 80 | 81 | 82 | def load_checkpoint(hparams: dict, checkpoint_savepath: str): 83 | model = EGNN_NCP if hparams["ncp_graph"] else EGNN 84 | model = model( 85 | n_kernels=hparams["n_kernels"], 86 | aggr=hparams["aggr"], 87 | pool=hparams["pool"], 88 | mlp_dim=hparams["mlp_dim"], 89 | kernel_dim=hparams["kernel_dim"], 90 | baseline_atom_ids=hparams["baseline_atom_ids"], 91 | properties=hparams["properties"], 92 | ).to(DEVICE) 93 | checkpoint = torch.load(checkpoint_savepath, map_location=DEVICE) 94 | model.load_state_dict(checkpoint["model_state"]) 95 | return model 96 | 97 | 98 | def generate_scatter_plot(y_train: np.array, yhat_train: np.array, y_eval: np.array, yhat_eval: np.array): 99 | global_min = np.min(np.concatenate([y_train, yhat_train, y_eval, yhat_eval])) 100 | global_max = np.max(np.concatenate([y_train, yhat_train, y_eval, yhat_eval])) 101 | 102 | fig = plt.figure(figsize=(10, 10)) 103 | ax = fig.add_subplot(111) 104 | ax.plot([global_min, global_max], [global_min, global_max], color="black", linestyle="--", zorder=10) 105 | ax.scatter(y_train, yhat_train, color="blue", zorder=3) 106 | ax.scatter(y_eval, yhat_eval, color="orange", zorder=5) 107 | 108 | return fig 109 | 110 | 111 | def get_data_loader( 112 | pickle_file: str, 113 | hparams: dict, 114 | scaler: QtaimScaler, 115 | idxs: List, 116 | shuffle: bool = True, 117 | pickle_data: Optional[dict] = None, 118 | ) -> DataLoader: 119 | data = QtaimDataNCP if hparams["ncp_graph"] else QtaimDataBCP 120 | data = data( 121 | pickle_file, 122 | scaler=scaler, 123 | idxs=idxs, 124 | cutoff=hparams["cutoff"], 125 | baseline_atom_ids=hparams["baseline_atom_ids"], 126 | properties=hparams["properties"], 127 | pickle_data=pickle_data, 128 | ) 129 | loader = DataLoader(data, batch_size=hparams["batch_size"], num_workers=0, shuffle=shuffle) 130 | return loader 131 | -------------------------------------------------------------------------------- /bcpaff/ml/net_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import pickle 6 | 7 | import networkx as nx 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from scipy.spatial.distance import pdist, squareform 12 | from sklearn.preprocessing import MinMaxScaler 13 | from torch_geometric.data import Dataset 14 | from torch_geometric.transforms import RadiusGraph 15 | from torch_geometric.utils.convert import from_networkx 16 | from tqdm import tqdm 17 | 18 | from bcpaff.ml.generate_pickle import ESP_NAMES 19 | from bcpaff.utils import ATOM_NEIGHBOR_IDS, DEFAULT_PROPS, DEFAULT_PROPS_CRITIC2, OTHER 20 | 21 | EPS = 0.001 # small epsilon to avoid undefined logarithms 22 | 23 | MIN_PERCENTILE = 1 24 | MAX_PERCENTILE = 99 25 | ELLIPTICITY_OFFSET = np.zeros(len(DEFAULT_PROPS)) 26 | 
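# (Illustrative note.) The next line shifts only the "ellipticity" entry of
# this offset vector to 1e-5. In QtaimScaler the offset is added after the
# "(1 + EPS) * |min|" shift, so an ellipticity column whose minimum is exactly
# zero (as can happen for nucleus critical points) does not end up at
# log10(0); all other property columns keep an offset of 0.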
ELLIPTICITY_OFFSET[DEFAULT_PROPS.index("ellipticity")] = 1e-5 # needed to scale ellipticity for NCPs 27 | ATOM_TYPE_ID_PLACEHOLDER = -9999 # placeholder 28 | 29 | 30 | class QtaimDataBCP(Dataset): 31 | def __init__( 32 | self, 33 | pickle_file, 34 | scaler, 35 | idxs=None, 36 | cutoff=6, 37 | baseline_atom_ids=False, 38 | properties="yyyyyyyyy", 39 | pickle_data=None, 40 | ): 41 | if pickle_data is None: 42 | with open(pickle_file, "rb") as handle: 43 | self.all_props = pickle.load(handle) 44 | else: 45 | self.all_props = pickle_data 46 | self.idxs = idxs 47 | if idxs is None: # use all 48 | self.pdb_ids = sorted(self.all_props.keys()) 49 | else: 50 | self.pdb_ids = [key for key in self.all_props.keys() if key in idxs] 51 | self.cutoff = cutoff 52 | self.radius_graph = RadiusGraph(r=cutoff) 53 | self.baseline_atom_ids = baseline_atom_ids # not used here, just keeping track (filtering in network) 54 | self.properties = [x == "y" for x in properties] # boolean mask for which properties to use (translate y/n) 55 | self.scaler = scaler 56 | graph_data = [] 57 | print("Preparing graph data...", flush=True) 58 | successfully_processed_pdb_ids = [] 59 | for pdb_id in tqdm(self.pdb_ids): # precompute graph data 60 | this_graph_data = self._prepare_graph_data(pdb_id) 61 | if this_graph_data is not None: 62 | graph_data.append(this_graph_data) 63 | successfully_processed_pdb_ids.append(pdb_id) 64 | print(f"Successfully processed {len(graph_data)}/{len(self.pdb_ids)} graphs.", flush=True) 65 | self.pdb_ids = successfully_processed_pdb_ids 66 | self.graph_data = graph_data 67 | 68 | def _prepare_graph_data(self, pdb_id): 69 | G = self.all_props[pdb_id] 70 | 71 | # add edges based on self.cutoff 72 | bcp_coords = list(nx.get_node_attributes(G, "node_coords").values()) 73 | if len(bcp_coords): 74 | bcp_coords = torch.stack(list(nx.get_node_attributes(G, "node_coords").values())) 75 | else: # no BCPs that fulfil the filters (e.g., only short interactions etc.) 76 | return None 77 | distance_matrix = squareform(pdist(bcp_coords)) 78 | edge_idxs = np.stack(np.where(distance_matrix < self.cutoff)) 79 | self_loops = edge_idxs[0, :] == edge_idxs[1, :] 80 | edge_idxs = edge_idxs[:, ~self_loops] 81 | G.add_edges_from( 82 | np.array(G.nodes)[edge_idxs].T 83 | ) # edge_idxs where only in terms of 0 ... 
n, but node ids in the graph are different 84 | 85 | untransformed_node_props = torch.stack(list(nx.get_node_attributes(G, "node_props").values())) 86 | 87 | transformed_node_props = self.scaler.transform(untransformed_node_props, point_name="bond_critical_point")[ 88 | :, self.properties 89 | ] 90 | nx.set_node_attributes( 91 | G, {key: val for key, val in zip(G.nodes, torch.FloatTensor(transformed_node_props))}, name="node_props" 92 | ) 93 | graph_data = from_networkx(G) 94 | if sum(self.properties) == 1: # single property 95 | graph_data["node_props"] = graph_data["node_props"].unsqueeze(dim=1) 96 | else: 97 | graph_data["node_props"] = torch.stack(graph_data["node_props"]) 98 | graph_data["node_coords"] = torch.stack(graph_data["node_coords"]) 99 | graph_data["target"] = G.graph["target"] 100 | 101 | return graph_data 102 | 103 | def __getitem__(self, idx): 104 | return self.graph_data[idx] 105 | 106 | def __len__(self): 107 | return len(self.pdb_ids) 108 | 109 | 110 | class QtaimDataNCP(Dataset): 111 | def __init__( 112 | self, 113 | pickle_file, 114 | scaler, 115 | idxs=None, 116 | cutoff=6, 117 | baseline_atom_ids=False, 118 | properties="yyyyyyyyy", 119 | pickle_data=None, 120 | ): 121 | if pickle_data is None: 122 | with open(pickle_file, "rb") as handle: 123 | self.all_props = pickle.load(handle) 124 | else: 125 | self.all_props = pickle_data 126 | self.idxs = idxs 127 | if idxs is None: # use all 128 | self.pdb_ids = sorted(self.all_props.keys()) 129 | else: 130 | self.pdb_ids = [key for key in self.all_props.keys() if key in idxs] 131 | self.cutoff = cutoff 132 | self.radius_graph = RadiusGraph(r=cutoff) 133 | self.baseline_atom_ids = baseline_atom_ids # not used here, just keeping track (filtering in network) 134 | self.properties = [x == "y" for x in properties] # boolean mask for which properties to use (translate y/n) 135 | self.include_esp = any(self.properties[-3:]) 136 | self.prop_list = DEFAULT_PROPS + ESP_NAMES if self.include_esp else DEFAULT_PROPS 137 | self.null_props = torch.FloatTensor([0.0] * len(self.properties)) 138 | 139 | self.scaler = scaler 140 | graph_data = [] 141 | print("Preparing graph data...") 142 | for pdb_id in tqdm(self.pdb_ids): # precompute graph data 143 | graph_data.append(self._prepare_graph_data(pdb_id)) 144 | self.graph_data = graph_data 145 | 146 | def _prepare_graph_data(self, pdb_id): 147 | G = self.all_props[pdb_id] 148 | 149 | # remove edges with NULL_PROPS & distance > self.cutoff 150 | edge_distances = torch.FloatTensor(list(nx.get_edge_attributes(G, "distance").values())) 151 | edge_props_for_checking = torch.stack(list(nx.get_edge_attributes(G, "edge_props").values())) 152 | null_props_mask = (edge_props_for_checking == self.null_props).all(axis=1) 153 | remove_edges = null_props_mask & (edge_distances > self.cutoff) 154 | G.remove_edges_from(np.asarray(list(G.edges))[remove_edges]) 155 | 156 | untransformed_edge_props = torch.stack(list(nx.get_edge_attributes(G, "edge_props").values())) 157 | untransformed_node_props = torch.stack(list(nx.get_node_attributes(G, "node_props").values())) 158 | 159 | transformed_node_props = self.scaler.transform(untransformed_node_props, point_name="nucleus_critical_point")[ 160 | :, self.properties 161 | ] 162 | transformed_edge_props = self.scaler.transform(untransformed_edge_props, point_name="bond_critical_point")[ 163 | :, self.properties 164 | ] 165 | nx.set_node_attributes( 166 | G, {key: val for key, val in zip(G.nodes, torch.FloatTensor(transformed_node_props))}, name="node_props" 167 | ) 
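# (Descriptive note.) Node and edge features are rescaled with separate
# statistics: the NCP features written back above use the scaler fitted on
# "nucleus_critical_point" rows, while the BCP features attached to the edges
# below use the scaler fitted on "bond_critical_point" rows, mirroring how
# QtaimScaler.fit() groups the training data by cp_name.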
168 | nx.set_edge_attributes( 169 | G, {key: val for key, val in zip(G.edges, torch.FloatTensor(transformed_edge_props))}, name="edge_props" 170 | ) 171 | graph_data = from_networkx(G) 172 | if sum(self.properties) == 1: # maintain same dimension if single QM property is used 173 | graph_data["edge_props"] = torch.stack(list(graph_data["edge_props"])).unsqueeze(dim=1) 174 | graph_data["node_props"] = torch.stack(list(graph_data["node_props"])).unsqueeze(dim=1) 175 | else: 176 | graph_data["edge_props"] = torch.stack(list(graph_data["edge_props"])) 177 | graph_data["node_props"] = torch.stack(list(graph_data["node_props"])) 178 | graph_data["node_coords"] = torch.stack(graph_data["node_coords"]) 179 | graph_data["edge_coords"] = torch.stack(graph_data["edge_coords"]) 180 | graph_data["target"] = G.graph["target"] 181 | 182 | return graph_data 183 | 184 | def __getitem__(self, idx): 185 | return self.graph_data[idx] 186 | 187 | def __len__(self): 188 | return len(self.pdb_ids) 189 | 190 | 191 | def pickle_to_df_ncp(pickle_file): 192 | with open(pickle_file, "rb") as handle: 193 | pickle_data = pickle.load(handle) 194 | all_ids, all_coords, all_props, all_targets, all_point_names = ([], [], [], [], []) 195 | all_atom_type_ids, all_is_ligand = [], [] 196 | for pdb_id, G in pickle_data.items(): 197 | # NCP data 198 | ncp_coords = list(nx.get_node_attributes(G, "node_coords").values()) 199 | ncp_features = nx.get_node_attributes(G, "node_props").values() 200 | ncp_atom_type_ids = nx.get_node_attributes(G, "atom_type_id").values() 201 | 202 | # BCP data 203 | bcp_coords = nx.get_edge_attributes(G, "edge_coords").values() 204 | bcp_features = nx.get_edge_attributes(G, "edge_props").values() 205 | bcp_atom_type_ids = [ATOM_TYPE_ID_PLACEHOLDER] * len(bcp_features) # dummy 206 | 207 | target = G.graph["target"] 208 | cp_names = len(ncp_coords) * ["nucleus_critical_point"] + len(bcp_coords) * ["bond_critical_point"] 209 | num_points = len(ncp_coords) + len(bcp_coords) 210 | all_ids.extend([pdb_id] * num_points) 211 | all_coords.extend(ncp_coords) 212 | all_coords.extend(bcp_coords) 213 | all_props.extend(ncp_features) 214 | all_props.extend(bcp_features) 215 | all_atom_type_ids.extend(ncp_atom_type_ids) 216 | all_atom_type_ids.extend(bcp_atom_type_ids) 217 | all_targets.extend([target] * num_points) 218 | all_point_names.extend(cp_names) 219 | all_is_ligand.extend(nx.get_node_attributes(G, "is_ligand").values()) 220 | all_is_ligand.extend([False] * len(bcp_coords)) # only labelling NCPs as belonging to ligand 221 | all_ids = np.array(all_ids) 222 | all_targets = np.array(all_targets) 223 | all_coords = np.stack(all_coords) 224 | all_props = np.stack(all_props) 225 | all_atom_type_ids = np.stack(all_atom_type_ids).astype(int) 226 | if all_props.shape[1] == len(DEFAULT_PROPS): 227 | prop_list = DEFAULT_PROPS 228 | elif all_props.shape[1] == len(DEFAULT_PROPS) + len(ESP_NAMES): 229 | prop_list = DEFAULT_PROPS + ESP_NAMES 230 | elif all_props.shape[1] == len(DEFAULT_PROPS_CRITIC2): 231 | prop_list = DEFAULT_PROPS_CRITIC2 232 | dict_for_df = { 233 | "pdb_id": all_ids, 234 | "cp_name": all_point_names, 235 | "atom_type_id": all_atom_type_ids, 236 | "is_ligand": all_is_ligand, 237 | "target": all_targets, 238 | "x": all_coords[:, 0], 239 | "y": all_coords[:, 1], 240 | "z": all_coords[:, 2], 241 | } 242 | update_dict = {prop_name: all_props[:, i] for i, prop_name in enumerate(prop_list)} 243 | dict_for_df.update(update_dict) 244 | if len(set([len(x) for x in dict_for_df.values()])) != 1: 245 | print("You 
probably supplied the BCP pickle to the NCP function...") # catching common error 246 | df = pd.DataFrame(dict_for_df) 247 | return df, pickle_data 248 | 249 | 250 | def pickle_to_df_bcp(pickle_file): 251 | with open(pickle_file, "rb") as handle: 252 | pickle_data = pickle.load(handle) 253 | all_ids, all_coords, all_props, all_targets, all_point_names = ([], [], [], [], []) 254 | all_atom_type_ids, all_is_ligand = [], [] 255 | for pdb_id, G in pickle_data.items(): 256 | # BCP data 257 | bcp_coords = list(nx.get_node_attributes(G, "node_coords").values()) 258 | bcp_features = nx.get_node_attributes(G, "node_props").values() 259 | bcp_atom_type_ids = nx.get_node_attributes(G, "atom_type_id").values() 260 | 261 | target = G.graph["target"] 262 | cp_names = len(bcp_coords) * ["bond_critical_point"] # only doing BCPs in this function 263 | num_points = len(bcp_coords) 264 | all_ids.extend([pdb_id] * num_points) 265 | all_coords.extend(bcp_coords) 266 | all_props.extend(bcp_features) 267 | all_atom_type_ids.extend(bcp_atom_type_ids) 268 | all_targets.extend([target] * num_points) 269 | all_point_names.extend(cp_names) 270 | all_ids = np.array(all_ids) 271 | all_targets = np.array(all_targets) 272 | all_coords = np.stack(all_coords) 273 | all_props = np.stack(all_props) 274 | dict_for_df = { 275 | "pdb_id": all_ids, 276 | "cp_name": all_point_names, 277 | "atom_type_id": all_atom_type_ids, 278 | "target": all_targets, 279 | "x": all_coords[:, 0], 280 | "y": all_coords[:, 1], 281 | "z": all_coords[:, 2], 282 | } 283 | if all_props.shape[1] == len(DEFAULT_PROPS): 284 | prop_list = DEFAULT_PROPS 285 | elif all_props.shape[1] == len(DEFAULT_PROPS) + len(ESP_NAMES): 286 | prop_list = DEFAULT_PROPS + ESP_NAMES 287 | elif all_props.shape[1] == len(DEFAULT_PROPS_CRITIC2): 288 | prop_list = DEFAULT_PROPS_CRITIC2 289 | update_dict = {prop_name: all_props[:, i] for i, prop_name in enumerate(prop_list)} 290 | dict_for_df.update(update_dict) 291 | df = pd.DataFrame(dict_for_df) 292 | return df, pickle_data 293 | 294 | 295 | def pickle_to_df(pickle_file, ncp_graph=False): 296 | if ncp_graph: 297 | return pickle_to_df_ncp(pickle_file) 298 | else: 299 | return pickle_to_df_bcp(pickle_file) 300 | 301 | 302 | class QtaimScaler(object): 303 | def __init__(self, pickle_file, train_idxs, ncp_graph=False): 304 | self.pickle_file = pickle_file 305 | self.train_idxs = train_idxs 306 | self.ncp_graph = ncp_graph 307 | df, self.pickle_data = pickle_to_df(pickle_file, ncp_graph=self.ncp_graph) 308 | self.df = df[df.pdb_id.isin(train_idxs)] # only fit scaler to training data 309 | self.prop_list = self._get_prop_list() 310 | self.ellipticity_offset = self._get_ellipticity_offset() 311 | self.min_percentiles = {} 312 | self.max_percentiles = {} 313 | self.scaled_abs_min_val = {} 314 | self.minmax_scaler = {} 315 | self.fit() 316 | 317 | def _get_ellipticity_offset(self): 318 | if self.prop_list == DEFAULT_PROPS_CRITIC2: 319 | return 0 320 | else: 321 | if "esp" in self.prop_list: 322 | return np.hstack([ELLIPTICITY_OFFSET, [0.0] * len(ESP_NAMES)]) 323 | else: 324 | return ELLIPTICITY_OFFSET 325 | 326 | def _get_prop_list(self): 327 | prop_list = [x for x in DEFAULT_PROPS + ESP_NAMES if x in self.df.keys()] 328 | return prop_list 329 | 330 | def transform(self, props_unscaled, point_name): 331 | # clip anything above/below 1st/99th percentile 332 | props_scaled = np.clip(props_unscaled, self.min_percentiles[point_name], self.max_percentiles[point_name]) 333 | # move minimum to zero and add small amount (EPS) so all positive 
(needed for log) 334 | props_scaled = ( 335 | props_scaled + ((1 + EPS) * self.scaled_abs_min_val[point_name]).to_numpy() + self.ellipticity_offset 336 | ) 337 | # take log10 338 | props_scaled = np.log10(props_scaled) 339 | # scale from 0 to 1 340 | props_scaled = self.minmax_scaler[point_name].transform(props_scaled) 341 | return props_scaled 342 | 343 | def fit(self): 344 | cp_names = ["bond_critical_point", "nucleus_critical_point"] 345 | for cp_name in cp_names: 346 | sub_df = self.df[self.df.cp_name == cp_name] # separate scaling for BCPs and NCPs 347 | if len(sub_df) == 0: # don't have any points of this type --> no need to scale 348 | pass 349 | else: 350 | if cp_name == "bond_critical_point": 351 | sub_df = sub_df[(sub_df[self.prop_list] != 0).all(axis=1)] # don't scale for NULL_PROPS 352 | # find 1st/99st percentiles for each property 353 | self.min_percentiles[cp_name] = np.percentile(sub_df[self.prop_list], MIN_PERCENTILE, axis=0) 354 | self.max_percentiles[cp_name] = np.percentile(sub_df[self.prop_list], MAX_PERCENTILE, axis=0) 355 | # clip anything above/below 1st/99th percentile 356 | props_scaled = np.clip( 357 | sub_df[self.prop_list], self.min_percentiles[cp_name], self.max_percentiles[cp_name] 358 | ) 359 | # move minimum to zero and add small amount (EPS) so all positive (needed for log) 360 | self.scaled_abs_min_val[cp_name] = abs(props_scaled.min()) 361 | 362 | props_scaled = props_scaled + (1 + EPS) * self.scaled_abs_min_val[cp_name] + self.ellipticity_offset 363 | # take log10 364 | props_scaled = np.log10(props_scaled) 365 | # scale to 0 mean and unit variance 366 | self.minmax_scaler[cp_name] = MinMaxScaler().fit(props_scaled.values) 367 | -------------------------------------------------------------------------------- /bcpaff/ml/run_all_ml_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import os 7 | import subprocess 8 | from typing import Optional 9 | 10 | from bcpaff.data_processing.data_processing import DATASETS 11 | from bcpaff.ml.hparam_screen import run_hparam_screen 12 | from bcpaff.utils import DATASETS_AND_SPLITS, PROCESSED_DATA_PATH, ROOT_PATH 13 | 14 | 15 | def submit_job_collect_results( 16 | base_output_dir: str, hparam_file: str, cluster_options: Optional[str] = None, last_epoch: bool = False 17 | ): 18 | cmd_str = f"""python -c 'from bcpaff.ml.hparam_screen import collect_results; collect_results(\\\"{base_output_dir}\\\", hparam_file=\\\"{hparam_file}\\\", last_epoch={last_epoch})' """ 19 | cluster_options = "" if cluster_options is None else cluster_options 20 | if cluster_options != "no_cluster": 21 | slurm_output_file = os.path.join(base_output_dir, "slurm_files", "out_files", "collect_results_out_%A.out") 22 | cmd_str = f"""sbatch --parsable {cluster_options} --output={slurm_output_file} --wrap "{cmd_str}" """ 23 | completed_process = subprocess.run( 24 | cmd_str, 25 | shell=True, 26 | universal_newlines=True, 27 | stdout=subprocess.PIPE, 28 | ) 29 | job_id = completed_process.stdout.rstrip("\n") 30 | return job_id 31 | 32 | 33 | def submit_job_run_test( 34 | pickle_file: str, 35 | base_output_dir: str, 36 | dataset: str, 37 | split_type: str, 38 | cluster_options: Optional[str] = None, 39 | last_epoch: bool = False, 40 | ): 41 | cmd_str = f"""python -c 'from bcpaff.ml.test import run_test; run_test(\\\"{pickle_file}\\\", \\\"{base_output_dir}\\\", \\\"{dataset}\\\", \\\"{split_type}\\\", last_epoch={last_epoch})' """ 42 | 
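# (Illustrative sketch; paths and partition are placeholders.) When submitted
# to Slurm, the command built above is wrapped by the sbatch call below and
# expands to something of the form:
#
#   sbatch -n 4 --mem-per-cpu=16000 --parsable --partition=<queue> \
#       --output=<base_output_dir>/slurm_files/out_files/test_out_%A.out \
#       --wrap "python -c 'from bcpaff.ml.test import run_test; run_test(...)'"
#
# With cluster_options == "no_cluster" the python -c command is instead run
# directly, without going through sbatch.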
cluster_options = "" if cluster_options is None else cluster_options 43 | if cluster_options != "no_cluster": 44 | slurm_output_file = os.path.join(base_output_dir, "slurm_files", "out_files", "test_out_%A.out") 45 | cmd_str = f"""sbatch -n 4 --mem-per-cpu=16000 --parsable {cluster_options} --output={slurm_output_file} --wrap "{cmd_str}" """ 46 | completed_process = subprocess.run( 47 | cmd_str, 48 | shell=True, 49 | universal_newlines=True, 50 | stdout=subprocess.PIPE, 51 | ) 52 | job_id = completed_process.stdout.rstrip("\n") 53 | return job_id 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--test_run", action="store_true", dest="test_run", default=False) 59 | parser.add_argument("--cluster_options", type=str, default=None) 60 | parser.add_argument("--dataset", type=str, default=None) 61 | parser.add_argument("--last_epoch", action="store_true", default=False) 62 | args = parser.parse_args() 63 | 64 | datasets_to_run = DATASETS if args.dataset is None else [args.dataset] 65 | t = 0 66 | for dataset in datasets_to_run: 67 | for split_type in DATASETS_AND_SPLITS[dataset]: 68 | for ncp_bcp in ["ncp", "bcp"]: 69 | for atom_ids in ["props", "atom_ids", "atom_ids_and_props"]: 70 | base_output_dir = os.path.join( 71 | PROCESSED_DATA_PATH, "model_runs_1000", ncp_bcp, dataset, split_type, atom_ids 72 | ) 73 | print(t) 74 | t += 1 75 | pickle_file = os.path.join( 76 | PROCESSED_DATA_PATH, 77 | "prepared_structures", 78 | dataset, 79 | f"qtaim_props_{ncp_bcp}.pkl", 80 | ) 81 | if args.test_run: 82 | hparam_file = os.path.join(ROOT_PATH, "hparam_files", f"hparams_{ncp_bcp}_{atom_ids}_mini.csv") 83 | num_epochs = 20 84 | else: 85 | hparam_file = os.path.join(ROOT_PATH, "hparam_files", f"hparams_{ncp_bcp}_{atom_ids}.csv") 86 | num_epochs = 1000 87 | job_id = run_hparam_screen( 88 | pickle_file=pickle_file, 89 | hparam_file=hparam_file, 90 | base_output_dir=base_output_dir, 91 | dataset=dataset, 92 | split_type=split_type, 93 | overwrite=True, 94 | cluster_options=args.cluster_options, 95 | num_epochs=num_epochs, 96 | ) 97 | cluster_options = ( 98 | f"--dependency=afterok:{job_id}" if args.cluster_options is None else args.cluster_options 99 | ) 100 | job_id = submit_job_collect_results( 101 | base_output_dir, hparam_file, cluster_options=cluster_options, last_epoch=args.last_epoch 102 | ) 103 | cluster_options = ( 104 | f"--dependency=afterok:{job_id}" if args.cluster_options is None else args.cluster_options 105 | ) 106 | submit_job_run_test( 107 | pickle_file, 108 | base_output_dir, 109 | dataset, 110 | split_type, 111 | cluster_options=cluster_options, 112 | last_epoch=args.last_epoch, 113 | ) 114 | -------------------------------------------------------------------------------- /bcpaff/ml/scrambling.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import os 7 | import pickle 8 | import random 9 | from collections import defaultdict 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | 16 | from bcpaff.ml import statsig 17 | from bcpaff.ml.ml_utils import get_data_loader, load_checkpoint 18 | from bcpaff.ml.net_utils import QtaimScaler 19 | from bcpaff.ml.test import DATASETS_AND_SPLITS, get_best_hparams, run_test 20 | from bcpaff.ml.train import eval_loop, get_split_idxs 21 | from bcpaff.utils import ROOT_PATH, SEED 22 | 23 | random.seed(SEED) 24 | torch.manual_seed(SEED) 25 | 
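# (Illustrative sketch.) "Y scrambling" below means permuting the target
# values before computing the loss (see the y_scrambling branch of eval_loop
# in bcpaff/ml/train.py), i.e. essentially:
#
#   import torch
#   y = torch.tensor([1.0, 2.0, 3.0, 4.0])
#   y_scrambled = y[torch.randperm(y.shape[0])]  # same values, random pairing
#
# Input scrambling instead shuffles the input features inside the model's
# forward pass (input_scrambling=True); both serve as sanity-check baselines
# that a properly trained model should clearly beat.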
26 | 27 | def get_rmse_bounds_for_one_exp(test_results_savepath): 28 | with open(test_results_savepath, "rb") as f: 29 | res = pickle.load(f) 30 | rmse, le, ue = statsig.rmse(res["y_test"], res["yhat_test"]) 31 | assert np.isclose(rmse, res["test_rmse"]) 32 | return rmse, le, ue 33 | 34 | 35 | def run_scrambling_experiments(pickle_file, dataset, base_output_dir): 36 | last_epoch = False 37 | 38 | for split_type in DATASETS_AND_SPLITS[dataset]: 39 | this_base_output_dir = os.path.join(base_output_dir, dataset, split_type) 40 | 41 | hparams, checkpoint_savepath = get_best_hparams(this_base_output_dir, last_epoch=last_epoch, quiet=True) 42 | try: 43 | model = load_checkpoint(hparams, checkpoint_savepath) 44 | except: 45 | print( 46 | f"scp -r euler:/cluster/project/schneider/cisert/bcpaff/processed_data/model_runs_esp/bcp/pde10a/{split_type}/{os.path.basename(os.path.dirname(checkpoint_savepath))} ./{split_type}" 47 | ) 48 | continue 49 | 50 | train_idxs, _, _, test_idxs = get_split_idxs(dataset, split_type) 51 | 52 | scaler = QtaimScaler( 53 | pickle_file, train_idxs, ncp_graph=hparams["ncp_graph"] 54 | ) # still use train set normalization values 55 | test_loader = get_data_loader(pickle_file, hparams, scaler, test_idxs, shuffle=False) 56 | 57 | # y_scrambling 58 | test_mae, test_rmse, test_loss, y_test, yhat_test = eval_loop( 59 | model, test_loader, torch.nn.MSELoss(), y_scrambling=True, input_scrambling=False 60 | ) 61 | results = { 62 | "test_mae": test_mae, 63 | "test_rmse": test_rmse, 64 | "test_loss": test_loss, 65 | "y_test": y_test, 66 | "yhat_test": yhat_test, 67 | "hparams": hparams, 68 | } 69 | results_savepath = os.path.join(this_base_output_dir, "test_results_y_scrambling.pkl") 70 | with open(results_savepath, "wb") as f: 71 | pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) 72 | print(f"Saved test set results to {results_savepath}") 73 | 74 | # input_scrambling 75 | test_mae, test_rmse, test_loss, y_test, yhat_test = eval_loop( 76 | model, test_loader, torch.nn.MSELoss(), y_scrambling=False, input_scrambling=True 77 | ) 78 | results = { 79 | "test_mae": test_mae, 80 | "test_rmse": test_rmse, 81 | "test_loss": test_loss, 82 | "y_test": y_test, 83 | "yhat_test": yhat_test, 84 | "hparams": hparams, 85 | } 86 | results_savepath = os.path.join(this_base_output_dir, "test_results_input_scrambling.pkl") 87 | with open(results_savepath, "wb") as f: 88 | pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) 89 | print(f"Saved test set results to {results_savepath}") 90 | 91 | 92 | def collect_scrambling_results(base_output_dir, dataset): 93 | df_res = defaultdict(dict) 94 | mad_benchmark = pd.read_csv(os.path.join(ROOT_PATH, "notebooks", "collect_results", "mad_benchmark.csv")) 95 | for split_type in DATASETS_AND_SPLITS[dataset]: 96 | experiment = f"{dataset}_{split_type}" 97 | rmse, le, ue = get_rmse_bounds_for_one_exp( 98 | os.path.join(base_output_dir, dataset, split_type, "test_results.pkl") 99 | ) 100 | rmse_y, le_y, ue_y = get_rmse_bounds_for_one_exp( 101 | os.path.join(base_output_dir, dataset, split_type, "test_results_y_scrambling.pkl") 102 | ) 103 | rmse_input, le_input, ue_input = get_rmse_bounds_for_one_exp( 104 | os.path.join(base_output_dir, dataset, split_type, "test_results_input_scrambling.pkl") 105 | ) 106 | 107 | mad_benchmark_rmse = mad_benchmark[ 108 | (mad_benchmark.dataset == dataset) & (mad_benchmark.split_type == split_type) 109 | ].rmse.values[0] 110 | mad_benchmark_le = mad_benchmark[ 111 | (mad_benchmark.dataset == dataset) & 
(mad_benchmark.split_type == split_type) 112 | ]["le"].values[0] 113 | mad_benchmark_ue = mad_benchmark[ 114 | (mad_benchmark.dataset == dataset) & (mad_benchmark.split_type == split_type) 115 | ].ue.values[0] 116 | 117 | df_res[experiment] = { 118 | "rmse": rmse, 119 | "le": le, 120 | "ue": ue, 121 | "rmse_y": rmse_y, 122 | "le_y": le_y, 123 | "ue_y": ue_y, 124 | "rmse_input": rmse_input, 125 | "le_input": le_input, 126 | "ue_input": ue_input, 127 | "rmse_mad": mad_benchmark_rmse, 128 | "le_mad": mad_benchmark_le, 129 | "ue_mad": mad_benchmark_ue, 130 | } 131 | df_res = pd.DataFrame(df_res).T 132 | return df_res 133 | 134 | 135 | def plot_scrambling_results(df_res, dataset, savepath): 136 | names = DATASETS_AND_SPLITS[dataset] 137 | names_replace = { 138 | "random": "Random", 139 | "temporal_2011": "Temp.\n2011", 140 | "temporal_2012": "Temp.\n2012", 141 | "temporal_2013": "Temp.\n2013", 142 | "aminohetaryl_c1_amide": "Binding\nmode 1", 143 | "c1_hetaryl_alkyl_c2_hetaryl": "Binding\nmode 2", 144 | "aryl_c1_amide_c2_hetaryl": "Binding\nmode 3", 145 | } 146 | names = [names_replace[n] for n in names] 147 | 148 | heights_bcpaff = df_res["rmse"] 149 | errorbars_bcpaff = np.array([df_res["le"].tolist(), df_res["ue"].tolist()], dtype="float") 150 | 151 | heights_y = df_res["rmse_y"] 152 | errorbars_y = np.array([df_res["le_y"].tolist(), df_res["ue_y"].tolist()], dtype="float") 153 | 154 | heights_input = df_res["rmse_input"] 155 | errorbars_input = np.array([df_res["le_input"].tolist(), df_res["ue_input"].tolist()], dtype="float") 156 | 157 | heights_mad = df_res["rmse_mad"] 158 | errorbars_mad = np.array([df_res["le_mad"].tolist(), df_res["ue_mad"].tolist()], dtype="float") 159 | 160 | fig = plt.figure(figsize=(10, 4)) 161 | ax = fig.add_subplot(111) 162 | w = 0.15 163 | x = np.array(range(len(names))) 164 | ax.bar(x - 1.5 * w, heights_bcpaff, width=w, label="Normal network", color="orange", edgecolor="black") 165 | ax.errorbar(x - 1.5 * w, heights_bcpaff, yerr=errorbars_bcpaff, linestyle="", color="black") 166 | 167 | ax.bar(x - 0.5 * w, heights_y, width=w, label="Y scrambling", color="green", edgecolor="black") 168 | ax.errorbar(x - 0.5 * w, heights_y, yerr=errorbars_y, linestyle="", color="black") 169 | 170 | ax.bar(x + 0.5 * w, heights_input, width=w, label="Input scrambling", color="gray", edgecolor="black") 171 | ax.errorbar(x + 0.5 * w, heights_input, yerr=errorbars_input, linestyle="", color="black") 172 | 173 | ax.bar(x + 1.5 * w, heights_mad, width=w, label="MAD", color="white", edgecolor="black") 174 | ax.errorbar(x + 1.5 * w, heights_mad, yerr=errorbars_mad, linestyle="", color="black") 175 | 176 | ax.set_ylabel("Test set RMSE", fontsize=12) 177 | ax.set_xticks(x) 178 | ax.tick_params(axis="y", which="major", labelsize=12) 179 | ax.set_xticklabels(names, rotation=0, ha="center", fontsize=12) 180 | ax.legend(fontsize=12, ncol=4) 181 | ax.set_ylim([0, 2.0]) 182 | os.makedirs(os.path.dirname(savepath), exist_ok=True) 183 | fig.savefig(savepath, dpi=300, bbox_inches="tight") 184 | print(f"Saved plot to {savepath}") 185 | 186 | 187 | if __name__ == "__main__": 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument("--pickle_file", type=str, required=True) 190 | parser.add_argument("--dataset", type=str, default="pdbbind") 191 | parser.add_argument("--base_output_dir", type=str, default=None) 192 | parser.add_argument("--plot_savepath", type=str, default=None) 193 | args = parser.parse_args() 194 | 195 | run_scrambling_experiments(args.pickle_file, args.dataset, 
args.base_output_dir) 196 | df_res = collect_scrambling_results(args.base_output_dir, args.dataset) 197 | if args.plot_savepath is None: 198 | plot_savepath = os.path.join(args.base_output_dir, args.dataset, "scrambling_results.pdf") 199 | else: 200 | plot_savepath = args.plot_savepath 201 | plot_scrambling_results(df_res, args.dataset, plot_savepath) 202 | -------------------------------------------------------------------------------- /bcpaff/ml/statsig.py: -------------------------------------------------------------------------------- 1 | """From https://github.com/jensengroup/statsig (Accessed 01.11.22) 2 | MIT License 3 | 4 | Copyright (c) 2016 Jensen Group 5 | """ 6 | 7 | 8 | import numpy as np 9 | 10 | 11 | def correl(X, Y): 12 | (N,) = X.shape 13 | 14 | if N < 9: 15 | print(f"not enough points. {N} datapoints given. at least 9 is required") 16 | return 17 | 18 | r = np.corrcoef(X, Y)[0][1] 19 | r_sig = 1.96 / np.sqrt(N - 2 + 1.96 ** 2) 20 | F_plus = 0.5 * np.log((1 + r) / (1 - r)) + r_sig 21 | F_minus = 0.5 * np.log((1 + r) / (1 - r)) - r_sig 22 | le = r - (np.exp(2 * F_minus) - 1) / (np.exp(2 * F_minus) + 1) 23 | ue = (np.exp(2 * F_plus) - 1) / (np.exp(2 * F_plus) + 1) - r 24 | 25 | return r, le, ue 26 | 27 | 28 | def rmse(X, Y): 29 | """ 30 | Root-Mean-Square Error 31 | 32 | Lower Error = RMSE \left( 1- \sqrt{ 1- \frac{1.96\sqrt{2}}{\sqrt{N-1}} } \right ) 33 | Upper Error = RMSE \left( \sqrt{ 1+ \frac{1.96\sqrt{2}}{\sqrt{N-1}} } - 1 \right ) 34 | 35 | This only works for N >= 8.6832, otherwise the lower error will be 36 | imaginary. 37 | 38 | Parameters: 39 | X -- One dimensional Numpy array of floats 40 | Y -- One dimensional Numpy array of floats 41 | 42 | Returns: 43 | rmse -- Root-mean-square error between X and Y 44 | le -- Lower error on the RMSE value 45 | ue -- Upper error on the RMSE value 46 | """ 47 | 48 | (N,) = X.shape 49 | 50 | if N < 9: 51 | print(f"not enough points. {N} datapoints given. at least 9 is required") 52 | return 53 | 54 | diff = X - Y 55 | diff = diff ** 2 56 | rmse = np.sqrt(diff.mean()) 57 | 58 | le = rmse * (1.0 - np.sqrt(1 - 1.96 * np.sqrt(2.0) / np.sqrt(N - 1))) 59 | ue = rmse * (np.sqrt(1 + 1.96 * np.sqrt(2.0) / np.sqrt(N - 1)) - 1) 60 | 61 | return rmse, le, ue 62 | 63 | 64 | def mae(X, Y): 65 | """ 66 | Mean Absolute Error (MAE) 67 | 68 | Lower Error = MAE_X \left( 1- \sqrt{ 1- \frac{1.96\sqrt{2}}{\sqrt{N-1}} } \right ) 69 | Upper Error = MAE_X \left( \sqrt{ 1+ \frac{1.96\sqrt{2}}{\sqrt{N-1}} }-1 \right ) 70 | 71 | Parameters: 72 | X -- One dimensional Numpy array of floats 73 | Y -- One dimensional Numpy array of floats 74 | 75 | Returns: 76 | mae -- Mean-absolute error between X and Y 77 | le -- Lower error on the MAE value 78 | ue -- Upper error on the MAE value 79 | """ 80 | 81 | (N,) = X.shape 82 | 83 | mae = np.abs(X - Y) 84 | mae = mae.mean() 85 | 86 | le = mae * (1 - np.sqrt(1 - 1.96 * np.sqrt(2) / np.sqrt(N - 1))) 87 | ue = mae * (np.sqrt(1 + 1.96 * np.sqrt(2) / np.sqrt(N - 1)) - 1) 88 | 89 | return mae, le, ue 90 | 91 | 92 | def me(X, Y): 93 | """ 94 | mean error (ME) 95 | 96 | L_X = U_X = \frac{1.96 s_N}{\sqrt{N}} 97 | where sN is the standard population deviation (e.g. STDEVP in Excel). 
98 | 99 | Parameters: 100 | X -- One dimensional Numpy array of floats 101 | Y -- One dimensional Numpy array of floats 102 | 103 | Returns: 104 | mae -- Mean error between X and Y 105 | e -- Upper and Lower error on the ME 106 | """ 107 | 108 | (N,) = X.shape 109 | 110 | error = X - Y 111 | me = error.mean() 112 | 113 | s_N = stdevp(error, me, N) 114 | e = 1.96 * s_N / np.sqrt(N) 115 | 116 | return me, e 117 | 118 | 119 | def stdevp(X, X_hat, N): 120 | """ 121 | Parameters: 122 | X -- One dimensional Numpy array of floats 123 | X_hat -- Float 124 | N -- Integer 125 | 126 | Returns: 127 | 128 | Calculates standard deviation based on the entire population given as 129 | arguments. The standard deviation is a measure of how widely values are 130 | dispersed from the average value (the mean). 131 | """ 132 | return np.sqrt(np.sum((X - X_hat) ** 2) / N) 133 | 134 | 135 | if __name__ == "__main__": 136 | 137 | import sys 138 | 139 | import matplotlib.pyplot as plt 140 | import numpy as np 141 | 142 | if len(sys.argv) < 2: 143 | exit("usage: python example.py example_input.csv") 144 | 145 | filename = sys.argv[1] 146 | f = open(filename, "r") 147 | data = np.genfromtxt(f, delimiter=",", names=True) 148 | f.close() 149 | 150 | try: 151 | ref = data["REF"] 152 | except: 153 | ref = data["\xef\xbb\xbfREF"] 154 | n = len(ref) 155 | 156 | methods = data.dtype.names 157 | methods = methods[methods.index("REF") + 1 :] 158 | nm = len(methods) 159 | 160 | rmse_list = [] 161 | rmse_lower = [] 162 | rmse_upper = [] 163 | 164 | mae_list = [] 165 | mae_lower = [] 166 | mae_upper = [] 167 | 168 | me_list = [] 169 | me_lower = [] 170 | me_upper = [] 171 | 172 | r_list = [] 173 | r_lower = [] 174 | r_upper = [] 175 | 176 | for method in methods: 177 | mdata = data[method] 178 | 179 | # RMSE 180 | mrmse, mle, mue = rmse(mdata, ref) 181 | rmse_list.append(mrmse) 182 | rmse_lower.append(mle) 183 | rmse_upper.append(mue) 184 | 185 | # MAD 186 | mmae, maele, maeue = mae(mdata, ref) 187 | mae_list.append(mmae) 188 | mae_lower.append(maele) 189 | mae_upper.append(maeue) 190 | 191 | # ME 192 | mme, mmee = me(mdata, ref) 193 | me_list.append(mme) 194 | me_lower.append(mmee) 195 | me_upper.append(mmee) 196 | 197 | # r 198 | r, rle, rue = correl(mdata, ref) 199 | r_list.append(r) 200 | r_lower.append(rle) 201 | r_upper.append(rue) 202 | 203 | print( 204 | f"{'Method_A':<31}{'Method_B':<35}{'RMSE_A':<7}{'RMSE_B':<8}{'RMSE_A-RMSE_B':<20}{'Comp Err':<8}{'same?':<15}" 205 | ) 206 | ps = "{:30s} " * 2 + "{:8.3f} " * 2 + "{:8.3f}" + "{:15.3f}" + " {:}" 207 | 208 | check = "rmse" 209 | 210 | if check == "pearson": 211 | measure = r_list 212 | upper_error = r_upper 213 | lower_error = r_lower 214 | else: 215 | measure = rmse_list 216 | upper_error = rmse_upper 217 | lower_error = rmse_lower 218 | # measure = mae_list 219 | # upper_error = mae_upper 220 | # lower_error = mae_lower 221 | 222 | for i in range(nm): 223 | for j in range(i + 1, nm): 224 | 225 | m_i = methods[i] 226 | m_j = methods[j] 227 | 228 | rmse_i = measure[i] 229 | rmse_j = measure[j] 230 | 231 | r_ij = np.corrcoef(data[m_i], data[m_j])[0][1] 232 | 233 | if rmse_i > rmse_j: 234 | lower = lower_error[i] 235 | upper = upper_error[j] 236 | else: 237 | lower = lower_error[j] 238 | upper = upper_error[i] 239 | 240 | comp_error = np.sqrt(upper ** 2 + lower ** 2 - 2.0 * r_ij * upper * lower) 241 | significance = abs(rmse_i - rmse_j) < comp_error 242 | 243 | print(ps.format(m_i, m_j, rmse_i, rmse_j, rmse_i - rmse_j, comp_error, significance)) 244 | 245 | 
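# (Illustrative note.) Outside of this CLI demo, the typical programmatic use
# in this repository is simply, e.g.:
#
#   import numpy as np
#   from bcpaff.ml import statsig
#   y_true, y_pred = np.random.rand(50), np.random.rand(50)
#   rmse_val, lower_err, upper_err = statsig.rmse(y_true, y_pred)
#
# where lower_err/upper_err are the asymmetric 95% confidence bounds described
# in the docstring above (defined for N >= 9 samples; smaller arrays return
# None).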
print("\\begin{table}[]") 246 | print("\centering") 247 | print("\caption{}") 248 | print("\label{}") 249 | print("\\begin{tabular}{l" + nm * "c" + "}") 250 | print("\midrule") 251 | print("& " + " & ".join(methods) + "\\\\") 252 | print("\midrule") 253 | # for i in xrange(nm-1): 254 | # print '%.1f $\pm$ %.1f/%.1f &'%(rmse_list[i],lower_error[i],rmse_upper[i]), 255 | # print '%.1f $\pm$ %.1f/%.1f'%(rmse_list[-1],lower_error[-1],rmse_upper[-1]) 256 | print("RMSE & " + " & ".join(format(x, "3.2f") for x in rmse_list) + "\\\\") 257 | 258 | temp_list = [ 259 | i + "/" + j for i, j in zip([format(x, "3.3f") for x in rmse_upper], [format(x, "3.3f") for x in rmse_lower]) 260 | ] 261 | print("95 \% conf & $\pm$ " + " & $\pm$ ".join(temp_list) + "\\\\") 262 | 263 | temp_list = [ 264 | i + " $\pm$ " + j for i, j in zip([format(x, "3.3f") for x in me_list], [format(x, "3.3f") for x in me_upper]) 265 | ] 266 | print("ME & " + " & ".join(temp_list) + "\\\\") 267 | 268 | print("$r$ & " + " & ".join(format(x, "3.3f") for x in r_list) + "\\\\") 269 | 270 | temp_list = [ 271 | i + "/" + j for i, j in zip([format(x, "3.3f") for x in r_upper], [format(x, "3.3f") for x in r_lower]) 272 | ] 273 | print("95 \% conf & $\pm$ " + " & $\pm$ ".join(temp_list) + "\\\\") 274 | 275 | print("\midrule") 276 | print("\end{tabular}") 277 | print("\end{table}") 278 | 279 | # Create x-axis 280 | x = range(len(methods)) 281 | 282 | # Errorbar (upper and lower) 283 | asymmetric_error = [rmse_lower, rmse_upper] 284 | 285 | # Add errorbar for RMSE 286 | plt.errorbar(x, rmse_list, yerr=asymmetric_error, fmt="o") 287 | 288 | # change x-axis to method names and rotate the ticks 30 degrees 289 | plt.xticks(x, methods, rotation=30, ha="right") 290 | 291 | # Pad margins so that markers don't get clipped by the axes 292 | plt.margins(0.2) 293 | 294 | # Tweak spacing to prevent clipping of tick-labels 295 | plt.subplots_adjust(bottom=0.15) 296 | 297 | # Add grid to plot 298 | plt.grid(True) 299 | 300 | # Set plot title 301 | plt.title("Root-mean-squared error") 302 | 303 | # Save plot to PNG format 304 | plt.savefig("example_rmse.png", bbox_inches="tight") 305 | 306 | # Clear figure 307 | plt.clf() 308 | 309 | # MAE plot 310 | asymmetric_error = [mae_lower, mae_upper] 311 | plt.errorbar(x, mae_list, yerr=asymmetric_error, fmt="o") 312 | plt.xticks(x, methods, rotation=30, ha="right") 313 | plt.margins(0.2) 314 | plt.subplots_adjust(bottom=0.15) 315 | plt.grid(True) 316 | plt.title("Mean Absolute Error") 317 | plt.savefig("example_mae.png", bbox_inches="tight") 318 | 319 | # Clear figure 320 | plt.clf() 321 | 322 | # ME plot 323 | asymmetric_error = [me_lower, me_upper] 324 | plt.errorbar(x, me_list, yerr=asymmetric_error, fmt="o") 325 | plt.xticks(x, methods, rotation=30, ha="right") 326 | plt.margins(0.2) 327 | plt.subplots_adjust(bottom=0.15) 328 | plt.grid(True) 329 | plt.title("Mean Error") 330 | plt.savefig("example_me.png", bbox_inches="tight") 331 | -------------------------------------------------------------------------------- /bcpaff/ml/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import os 7 | import pickle 8 | import random 9 | 10 | import pandas as pd 11 | import torch 12 | 13 | from bcpaff.ml.ml_utils import get_data_loader, load_checkpoint 14 | from bcpaff.ml.net_utils import QtaimScaler 15 | from bcpaff.ml.test import run_test 16 | from bcpaff.ml.train import eval_loop, get_split_idxs 17 | from bcpaff.utils 
import SEED 18 | 19 | random.seed(SEED) 20 | torch.manual_seed(SEED) 21 | HPARAM_KEYS = [ 22 | "batch_size", 23 | "kernel_dim", 24 | "mlp_dim", 25 | "cutoff", 26 | "baseline_atom_ids", 27 | "aggr", 28 | "pool", 29 | "properties", 30 | "n_kernels", 31 | "ncp_graph", 32 | ] 33 | 34 | 35 | def get_best_hparams(base_output_dir, last_epoch=False, quiet=False): 36 | if last_epoch: 37 | hparam_file = os.path.join(base_output_dir, "hparam_results_last_epoch.csv") 38 | else: 39 | hparam_file = os.path.join(base_output_dir, "hparam_results.csv") 40 | if not os.path.exists(hparam_file): 41 | raise ValueError(f"hparam_file not found: {hparam_file}") 42 | df = pd.read_csv(hparam_file).sort_values(by="eval_rmse") 43 | if not quiet: 44 | print(f"Found {len(df)} results from hparam optimization and picking the one with best eval_rmse") 45 | best_hparams = df.iloc[0][HPARAM_KEYS].to_dict() 46 | best_run_id = df.iloc[0].run_id 47 | dir_best_eval_rmse = os.path.join(base_output_dir, best_run_id) 48 | if last_epoch: 49 | checkpoint_savepath = os.path.join(dir_best_eval_rmse, "last_epoch_checkpoint.pt") 50 | else: 51 | checkpoint_savepath = os.path.join(dir_best_eval_rmse, "checkpoint.pt") 52 | if not quiet: 53 | print(f"Using model {checkpoint_savepath}") 54 | return best_hparams, checkpoint_savepath 55 | 56 | 57 | def run_test(pickle_file, base_output_dir, dataset, split_type, last_epoch=False): 58 | hparams, checkpoint_savepath = get_best_hparams(base_output_dir, last_epoch=last_epoch) 59 | model = load_checkpoint(hparams, checkpoint_savepath) 60 | train_idxs, _, core_idxs, test_idxs = get_split_idxs(dataset, split_type) 61 | 62 | scaler = QtaimScaler( 63 | pickle_file, train_idxs, ncp_graph=hparams["ncp_graph"] 64 | ) # still use train set normalization values 65 | 66 | test_loader = get_data_loader( 67 | pickle_file, hparams, scaler, test_idxs, shuffle=False, pickle_data=scaler.pickle_data 68 | ) 69 | 70 | loaders = {"test": test_loader} 71 | 72 | if dataset == "pdbbind": # core set only for pdbbind 73 | core_loader = get_data_loader( 74 | pickle_file, hparams, scaler, core_idxs, shuffle=False, pickle_data=scaler.pickle_data 75 | ) 76 | loaders["core"] = core_loader 77 | 78 | for key, loader in loaders.items(): 79 | mae, rmse, loss, y, yhat = eval_loop(model, loader, torch.nn.MSELoss()) 80 | results = { 81 | f"{key}_mae": mae, 82 | f"{key}_rmse": rmse, 83 | f"{key}_loss": loss, 84 | f"y_{key}": y, 85 | f"yhat_{key}": yhat, 86 | f"{key}_idxs": loader.dataset.pdb_ids, 87 | "hparams": hparams, 88 | } 89 | 90 | if last_epoch: 91 | results_savepath = os.path.join(base_output_dir, f"{key}_results_last_epoch.pkl") 92 | else: 93 | results_savepath = os.path.join(base_output_dir, f"{key}_results.pkl") 94 | with open(results_savepath, "wb") as f: 95 | pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) 96 | print(f"Saved {key} set results to {results_savepath}") 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument("--pickle_file", type=str, required=True) 102 | parser.add_argument("--base_output_dir", type=str, default=None) 103 | parser.add_argument("--dataset", type=str, default="pdbbind") 104 | parser.add_argument("--split_type", type=str, default="random") 105 | parser.add_argument("--last_epoch", action="store_true", default=False) 106 | args = parser.parse_args() 107 | 108 | run_test( 109 | args.pickle_file, 110 | args.base_output_dir, 111 | args.dataset, 112 | args.split_type, 113 | args.last_epoch, 114 | ) 115 | 
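# (Usage sketch; the paths below are placeholders.) run_test() selects the
# hyperparameter setting with the lowest validation RMSE from
# hparam_results.csv in base_output_dir, reloads its checkpoint, and writes
# test_results.pkl (plus core_results.pkl for the PDBbind core set), e.g.:
#
#   python -m bcpaff.ml.test \
#       --pickle_file processed_data/prepared_structures/pdbbind/qtaim_props_bcp.pkl \
#       --base_output_dir processed_data/model_runs_1000/bcp/pdbbind/random/props \
#       --dataset pdbbind --split_type random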
-------------------------------------------------------------------------------- /bcpaff/ml/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | import random 9 | from typing import List, Tuple, Union 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from sklearn.metrics import mean_squared_error 15 | from torch.utils.tensorboard import SummaryWriter 16 | from torch_geometric.loader import DataLoader 17 | from tqdm import tqdm 18 | 19 | from bcpaff.ml.ml_utils import get_data_loader, hparams_to_run_id, save_checkpoint 20 | from bcpaff.ml.net import EGNN, EGNN_NCP, EGNNAtt 21 | from bcpaff.ml.net_utils import QtaimScaler 22 | from bcpaff.utils import DATA_PATH, SEED 23 | 24 | random.seed(SEED) 25 | torch.manual_seed(SEED) 26 | 27 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | 30 | def train_loop( 31 | model: Union[EGNN, EGNN_NCP, EGNNAtt], loader: DataLoader, optimizer: torch.optim.Adam, criterion: torch.nn.MSELoss 32 | ) -> Tuple[float, float, float, np.ndarray, np.ndarray]: 33 | model.train() 34 | training_loss = [] 35 | targets = [] 36 | predictions = [] 37 | 38 | # for g_batch in tqdm(loader, total=len(loader)): 39 | for g_batch in loader: 40 | g_batch = g_batch.to(DEVICE) 41 | optimizer.zero_grad() 42 | 43 | prediction = model(g_batch).squeeze(axis=1) 44 | target = g_batch.target 45 | 46 | loss = criterion(prediction, target.float()) 47 | loss.backward() 48 | optimizer.step() 49 | 50 | training_loss.append(loss.item()) 51 | predictions.append(prediction.detach().cpu().numpy()) 52 | 53 | targets.append(target.detach().cpu().numpy()) 54 | 55 | targets = np.concatenate(targets) 56 | predictions = np.concatenate(predictions) 57 | train_mae = np.mean(np.abs(targets - predictions)) 58 | train_rmse = mean_squared_error(targets, predictions, squared=False) 59 | 60 | return ( 61 | train_mae, 62 | train_rmse, 63 | np.mean(training_loss, axis=0), 64 | targets, 65 | predictions, 66 | ) 67 | 68 | 69 | def eval_loop( 70 | model: Union[EGNN, EGNN_NCP, EGNNAtt], 71 | loader: DataLoader, 72 | criterion: torch.nn.MSELoss, 73 | y_scrambling: bool = False, 74 | input_scrambling: bool = False, 75 | ) -> Tuple[float, float, float, np.ndarray, np.ndarray]: 76 | model.eval() 77 | eval_loss = [] 78 | targets = [] 79 | predictions = [] 80 | 81 | with torch.no_grad(): 82 | # for g_batch in tqdm(loader, total=len(loader)): 83 | for g_batch in loader: 84 | g_batch = g_batch.to(DEVICE) 85 | 86 | prediction = model(g_batch, input_scrambling=input_scrambling).squeeze(axis=1) 87 | target = g_batch.target 88 | if y_scrambling: 89 | target = target[torch.randperm(target.shape[0])] 90 | loss = criterion(prediction, target.float()) 91 | 92 | eval_loss.append(loss.item()) 93 | predictions.append(prediction.detach().cpu().numpy()) 94 | targets.append(target.detach().cpu().numpy()) 95 | 96 | targets = np.concatenate(targets) 97 | predictions = np.concatenate(predictions) 98 | eval_mae = np.mean(np.abs(targets - predictions)) 99 | eval_rmse = mean_squared_error(targets, predictions, squared=False) 100 | 101 | return ( 102 | eval_mae, 103 | eval_rmse, 104 | np.mean(eval_loss, axis=0), 105 | targets, 106 | predictions, 107 | ) 108 | 109 | 110 | def get_split_idxs(dataset: str, split_type: str) -> Tuple[List[str],]: 111 | split_col = f"{split_type}_split" 112 | if dataset.startswith("pdbbind"): # allowed split types: random, 
carbonic_anhydrase_2 (core_set not available) 113 | split_assignment_df = pd.read_csv(os.path.join(DATA_PATH, "pdbbind", "pdbbind2019_affinity.csv")) 114 | train_idxs = split_assignment_df[split_assignment_df[split_type] == "training_set"].pdb_id.tolist() 115 | eval_idxs = split_assignment_df[split_assignment_df[split_type] == "validation_set"].pdb_id.tolist() 116 | core_idxs = split_assignment_df[split_assignment_df[split_type] == "core_set"].pdb_id.tolist() 117 | test_idxs = split_assignment_df[split_assignment_df[split_type] == "hold_out_set"].pdb_id.tolist() 118 | elif dataset == "pde10a": 119 | split_assignment_df = pd.read_csv(os.path.join(DATA_PATH, "pde10a", "10822_2022_478_MOESM2_ESM.csv")) 120 | train_idxs = split_assignment_df[split_assignment_df[split_col] == "train"].docking_folder.tolist() 121 | eval_idxs = split_assignment_df[split_assignment_df[split_col] == "val"].docking_folder.tolist() 122 | test_idxs = split_assignment_df[split_assignment_df[split_col] == "test"].docking_folder.tolist() 123 | core_idxs = None # not applicable for this dataset 124 | else: 125 | raise ValueError(f"Unsupported dataset: {dataset}") 126 | return train_idxs, eval_idxs, core_idxs, test_idxs 127 | 128 | 129 | def run_training( 130 | hparams: dict, 131 | hparam_file: str, 132 | pickle_file: str, 133 | dataset: str, 134 | split_type: str, 135 | base_output_dir: str, 136 | overwrite: bool = False, 137 | no_lr_decay: bool = False, 138 | y_scrambling: bool = False, 139 | num_epochs: int = 300, 140 | ): 141 | if not os.path.exists(pickle_file): 142 | raise ValueError(f"pickle_file {pickle_file} missing") 143 | 144 | run_id = hparams_to_run_id(hparams, hparam_file=hparam_file) 145 | 146 | output_dir = os.path.join(base_output_dir, str(run_id)) 147 | 148 | checkpoint_savepath = os.path.join(output_dir, "checkpoint.pt") 149 | if os.path.exists(checkpoint_savepath) and not overwrite: 150 | raise ValueError(f"Checkpoint {checkpoint_savepath} already exists and overwrite = False.") 151 | os.makedirs(output_dir, exist_ok=True) 152 | writer = SummaryWriter(log_dir=os.path.join(output_dir, "tensorboard")) 153 | print(output_dir) 154 | print("=================================") 155 | print(f"DEVICE = {DEVICE}") 156 | print(f"tensorboard --logdir {output_dir}") 157 | print("=================================", flush=True) 158 | 159 | train_idxs, eval_idxs, _, _ = get_split_idxs(dataset, split_type) 160 | 161 | scaler = QtaimScaler(pickle_file, train_idxs, ncp_graph=hparams["ncp_graph"]) 162 | 163 | train_loader = get_data_loader( 164 | pickle_file, hparams, scaler, train_idxs, shuffle=True, pickle_data=scaler.pickle_data 165 | ) 166 | print("Got train_loader", flush=True) 167 | eval_loader = get_data_loader( 168 | pickle_file, hparams, scaler, eval_idxs, shuffle=False, pickle_data=scaler.pickle_data 169 | ) 170 | print("Got eval_loader", flush=True) 171 | 172 | if hparams["ncp_graph"]: 173 | model = EGNN_NCP 174 | else: 175 | if hparams["pool"].startswith("att"): 176 | model = EGNNAtt 177 | else: 178 | model = EGNN 179 | model = model( 180 | n_kernels=hparams["n_kernels"], 181 | aggr=hparams["aggr"], 182 | pool=hparams["pool"], 183 | mlp_dim=hparams["mlp_dim"], 184 | kernel_dim=hparams["kernel_dim"], 185 | baseline_atom_ids=hparams["baseline_atom_ids"], 186 | properties=hparams["properties"], 187 | ) 188 | print("Got model", flush=True) 189 | model = model.to(DEVICE) 190 | print("Put model in device", flush=True) 191 | 192 | print(model, flush=True) 193 | 194 | optimizer = torch.optim.Adam(model.parameters(), 
lr=1e-4, weight_decay=1e-10) 195 | 196 | if not no_lr_decay: 197 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 198 | optimizer, mode="min", factor=0.7, patience=20, verbose=True 199 | ) 200 | 201 | min_mae = float("inf") 202 | 203 | for epoch in range(num_epochs): 204 | print(f"Epoch {epoch}/{num_epochs}", flush=True) 205 | 206 | train_mae, train_rmse, train_loss, y_train, yhat_train = train_loop( 207 | model, train_loader, optimizer, torch.nn.MSELoss() 208 | ) 209 | writer.add_scalars("loss", {"train": train_loss}, global_step=epoch) 210 | writer.add_scalars("mae", {"train": train_mae}, global_step=epoch) 211 | writer.add_scalars("rmse", {"train": train_rmse}, global_step=epoch) 212 | if epoch % 1 == 0: 213 | eval_mae, eval_rmse, eval_loss, y_eval, yhat_eval = eval_loop( 214 | model, eval_loader, torch.nn.MSELoss(), y_scrambling=y_scrambling 215 | ) 216 | # fig = generate_scatter_plot(y_train, yhat_train, y_eval, yhat_eval) 217 | # writer.add_figure("scatter_plot", fig, global_step=epoch) 218 | writer.add_scalars("loss", {"eval": eval_loss}, global_step=epoch) 219 | writer.add_scalars("mae", {"eval": eval_mae}, global_step=epoch) 220 | writer.add_scalars("rmse", {"eval": eval_rmse}, global_step=epoch) 221 | datapoints = { 222 | "train_loss": train_loss, 223 | "eval_loss": eval_loss, 224 | "y_train": y_train, 225 | "yhat_train": yhat_train, 226 | "y_eval": y_eval, 227 | "yhat_eval": yhat_eval, 228 | "train_mae": train_mae, 229 | "train_rmse": train_rmse, 230 | "eval_mae": eval_mae, 231 | "eval_rmse": eval_rmse, 232 | } 233 | if not no_lr_decay: 234 | scheduler.step(eval_mae) 235 | 236 | if eval_mae < min_mae: 237 | min_mae = eval_mae 238 | print(f"New min eval_mae in epoch {epoch}: {eval_mae:.6f}", flush=True) 239 | save_checkpoint(model, optimizer, epoch, checkpoint_savepath, datapoints=datapoints) 240 | if epoch == num_epochs - 1: 241 | print(f"Final epoch {epoch} eval_mae: {eval_mae:.6f}", flush=True) 242 | checkpoint_savepath = os.path.join(output_dir, "last_epoch_checkpoint.pt") 243 | save_checkpoint(model, optimizer, epoch, checkpoint_savepath, datapoints=datapoints) 244 | print("Done.", flush=True) 245 | 246 | 247 | if __name__ == "__main__": 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument("--pickle_file", type=str, required=True) 250 | parser.add_argument("--hparam_file", type=str, default=None) 251 | parser.add_argument("--dataset", type=str, default="pdbbind") 252 | parser.add_argument("--split_type", type=str, default="random") 253 | parser.add_argument("--config_file", type=str, required=True) 254 | parser.add_argument("--base_output_dir", type=str, default=None) 255 | parser.add_argument("--overwrite", action="store_true", default=False) 256 | parser.add_argument("--no_lr_decay", action="store_true", default=False) 257 | parser.add_argument("--y_scrambling", action="store_true", default=False) 258 | parser.add_argument("--num_epochs", type=int, default=1000) 259 | args = parser.parse_args() 260 | 261 | with open(args.config_file, "r") as f: 262 | hparams = json.load(f) 263 | 264 | run_training( 265 | hparams, 266 | args.hparam_file, 267 | args.pickle_file, 268 | args.dataset, 269 | args.split_type, 270 | args.base_output_dir, 271 | overwrite=args.overwrite, 272 | no_lr_decay=args.no_lr_decay, 273 | y_scrambling=args.y_scrambling, 274 | num_epochs=args.num_epochs, 275 | ) 276 | -------------------------------------------------------------------------------- /bcpaff/qm/benchmark.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import glob 7 | import os 8 | import pickle 9 | import shutil 10 | import subprocess 11 | import time 12 | from typing import Optional 13 | 14 | import numpy as np 15 | import pandas as pd 16 | from rdkit import Chem 17 | from scipy.spatial.distance import cdist 18 | from tqdm import tqdm 19 | 20 | from bcpaff.qtaim.multiwfn_tools import run_multiwfn_analysis 21 | from bcpaff.qtaim.qtaim_reader import QtaimProps 22 | from bcpaff.utils import ANALYSIS_PATH, DATA_PATH, PROCESSED_DATA_PATH 23 | 24 | XTB_VALIDATION_BASEPATH = os.path.join(ANALYSIS_PATH, "xtb_validation") 25 | os.makedirs(XTB_VALIDATION_BASEPATH, exist_ok=True) 26 | 27 | NUM_BENCHMARK_STRUCTURES = 3 28 | THRESHOLD = 0.1 29 | NO_CORRES_PLACEHOLDER = -9999 30 | ABBREVIATIONS = { 31 | "nucleus_critical_point": "NCP", 32 | "bond_critical_point": "BCP", 33 | "ring_critical_point": "RCP", 34 | "cage_critical_point": "CCP", 35 | } 36 | 37 | 38 | def submit_benchmark_job(structure_basepath: str, level_of_theory: str, cluster_options: Optional[str] = None) -> str: 39 | conda_env = "bcpaff" if level_of_theory == "xtb" else "bcpaff_psi4" 40 | cmd_str = f"""source ~/.bashrc; source activate {conda_env}; python -c 'from bcpaff.qm.benchmark import run_benchmark; run_benchmark(\\\"{structure_basepath}\\\", \\\"{level_of_theory}\\\")' """ 41 | cluster_options = "" if cluster_options is None else cluster_options 42 | if cluster_options != "no_cluster": 43 | pdb_id = os.path.basename(structure_basepath) 44 | slurm_output_file = os.path.join( 45 | XTB_VALIDATION_BASEPATH, 46 | pdb_id, 47 | level_of_theory, 48 | "slurm_files", 49 | "out_files", 50 | f"{level_of_theory}_out_%A.out", 51 | ) 52 | os.makedirs(os.path.dirname(slurm_output_file), exist_ok=True) 53 | cmd_str = f"""sbatch -n 4 --mem-per-cpu=16000 --tmp=100000 --time=48:00:00 --parsable {cluster_options} --output={slurm_output_file} --wrap "{cmd_str}" """ 54 | completed_process = subprocess.run( 55 | cmd_str, 56 | shell=True, 57 | universal_newlines=True, 58 | stdout=subprocess.PIPE, 59 | ) 60 | job_id = completed_process.stdout.rstrip("\n") 61 | return job_id 62 | 63 | 64 | def run_benchmark(structure_basepath: str, level_of_theory: str): 65 | pdb_id = os.path.basename(structure_basepath) 66 | dest_basepath = os.path.join(XTB_VALIDATION_BASEPATH, pdb_id, level_of_theory, pdb_id) # naming_convention... 
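    # Layout of the benchmark working directory created above (derived from the path joins):
    #   <ANALYSIS_PATH>/xtb_validation/<pdb_id>/<level_of_theory>/<pdb_id>/
    # One subtree per structure and level of theory, so the xTB and DFT results for the
    # same PDB entry can be compared later; summarize_benchmark_results expects the same
    # <pdb_id>/<level_of_theory>/<pdb_id> nesting below its input directory.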
67 | os.makedirs(dest_basepath, exist_ok=True) 68 | 69 | filenames_to_copy = [ 70 | f"{pdb_id}_ligand_with_hydrogens.sdf", 71 | f"{pdb_id}_pocket_with_hydrogens.sdf", 72 | f"{pdb_id}_pocket_with_hydrogens.xyz", 73 | "pl_complex.sdf", 74 | "pl_complex.xyz", 75 | "psi4_input.pkl", 76 | "chrg_uhf.json", 77 | ] 78 | for filename in filenames_to_copy: 79 | src = os.path.join(structure_basepath, filename) 80 | dest = os.path.join(dest_basepath, filename) 81 | shutil.copy(src, dest) 82 | 83 | results = {} 84 | 85 | # compute wfn 86 | t0 = time.time() 87 | if level_of_theory == "dft": 88 | import psi4 # do the imports here since environments for psi4 and xtb are incompatible 89 | 90 | from bcpaff.qm.compute_wfn_psi4 import compute_wfn_psi4 91 | 92 | psi4.core.be_quiet() 93 | psi4_input_pickle = os.path.join(dest_basepath, "psi4_input.pkl") 94 | wfn_file = compute_wfn_psi4( 95 | psi4_input_pickle, 96 | memory=8, 97 | num_cores=1, 98 | level_of_theory="wb97x-d", 99 | basis_set="def2-qzvp", 100 | ) 101 | elif level_of_theory == "xtb": 102 | from bcpaff.qm.compute_wfn_xtb import ( 103 | compute_wfn_xtb, 104 | ) # do the imports here since environments for psi4 and xtb are incompatible 105 | 106 | wfn_file = compute_wfn_xtb(os.path.join(dest_basepath, "pl_complex.xyz")) 107 | else: 108 | raise ValueError("Unknown level of theory") 109 | t1 = time.time() 110 | time_needed = t1 - t0 111 | 112 | # run multiwfn analysis 113 | ligand_sdf = os.path.join(dest_basepath, f"{pdb_id}_ligand_with_hydrogens.sdf") 114 | num_ligand_atoms = next(Chem.SDMolSupplier(ligand_sdf, removeHs=False)).GetNumAtoms() 115 | cp_file, cpprop_file, paths_file = run_multiwfn_analysis( 116 | wfn_file, only_intermolecular=False, only_bcps=False, num_ligand_atoms=num_ligand_atoms, include_esp=False 117 | ) 118 | qtaim_props = QtaimProps(basepath=dest_basepath) 119 | 120 | results[f"{level_of_theory}"] = ( 121 | qtaim_props, 122 | time_needed, 123 | cp_file, 124 | cpprop_file, 125 | paths_file, 126 | wfn_file, 127 | ) 128 | 129 | results_savepath = os.path.join(dest_basepath, "results.pkl") 130 | with open(results_savepath, "wb") as f: 131 | pickle.dump(results, f) 132 | 133 | 134 | def get_equivalent_points(results): 135 | equivalent_points = {} 136 | qtaim_props_ref = results["dft"] 137 | other_methods = ["xtb"] 138 | other_qtaim_props = [results[method] for method in other_methods] 139 | 140 | for method, oqp in zip(other_methods, other_qtaim_props): 141 | other_points = oqp.cp_positions 142 | distance_matrix = cdist(qtaim_props_ref.cp_positions, other_points) 143 | corresponding_idx = np.argmin(distance_matrix, axis=1) 144 | no_corresponding = np.min(distance_matrix, axis=1) > THRESHOLD 145 | corresponding_idx[no_corresponding] = NO_CORRES_PLACEHOLDER 146 | equivalent_points[method] = corresponding_idx 147 | df_equiv = pd.DataFrame.from_dict(equivalent_points) 148 | df_equiv.loc[:, "point_name"] = [ABBREVIATIONS[cp.name] for cp in qtaim_props_ref.critical_points] 149 | # check that the point names (types) match 150 | for _, row in df_equiv.iterrows(): 151 | for method in other_methods: 152 | method_idx = row[method] 153 | if method_idx == NO_CORRES_PLACEHOLDER: 154 | continue 155 | assert row.point_name == ABBREVIATIONS[results[method].critical_points[method_idx].name] 156 | return df_equiv 157 | 158 | 159 | def summarize_benchmark_results(computed_results_benchmark, force_recompute=False): 160 | folders = glob.glob(os.path.join(computed_results_benchmark, "*")) 161 | pdb_ids = [folder for folder in folders if not 
os.path.isfile(folder)] 162 | pdb_ids = sorted([os.path.basename(pdb_id) for pdb_id in pdb_ids]) 163 | all_savepaths = [] 164 | for pdb_id in pdb_ids: 165 | results_savepath = os.path.join(computed_results_benchmark, pdb_id, "benchmark_results.pkl") 166 | all_savepaths.append(results_savepath) 167 | if os.path.exists(results_savepath) and not force_recompute: 168 | continue 169 | # load results from xTB and DFT 170 | results = {} 171 | for level_of_theory in ["dft", "xtb"]: 172 | qtaim_props = QtaimProps( 173 | basepath=os.path.join(computed_results_benchmark, pdb_id, level_of_theory, pdb_id) 174 | ) 175 | results[level_of_theory] = qtaim_props 176 | 177 | reference_name = "dft" 178 | df_equiv = get_equivalent_points(results) 179 | properties = list(results[reference_name].critical_points[0].props.keys()) 180 | iterables = [results.keys(), properties + ["x", "y", "z", "intermolecular"]] 181 | multiindex = pd.MultiIndex.from_product(iterables, names=["method", "property"]) 182 | df = pd.DataFrame(columns=multiindex) 183 | for method in results.keys(): 184 | qtaim_props = results[method] 185 | for prop in properties: 186 | for atom_idx_ref, row in df_equiv.iterrows(): 187 | atom_idx = row[method] if method != reference_name else atom_idx_ref 188 | if atom_idx == NO_CORRES_PLACEHOLDER: 189 | continue 190 | cp = qtaim_props.critical_points[atom_idx] 191 | val = cp.props[prop] 192 | df.loc[atom_idx_ref, (method, prop)] = val 193 | for atom_idx_ref, row in df_equiv.iterrows(): # add the coordinates (only needs one iteration) 194 | atom_idx = row[method] if method != reference_name else atom_idx_ref 195 | if atom_idx == NO_CORRES_PLACEHOLDER: 196 | continue 197 | cp = qtaim_props.critical_points[atom_idx] 198 | pos = cp.position 199 | df.loc[atom_idx_ref, (method, "x")] = pos[0] 200 | df.loc[atom_idx_ref, (method, "y")] = pos[1] 201 | df.loc[atom_idx_ref, (method, "z")] = pos[2] 202 | df.loc[atom_idx_ref, (method, "intermolecular")] = cp.intermolecular 203 | df.loc[:, "point_name"] = pd.Series(df.index).apply(lambda x: df_equiv.loc[x, "point_name"]) 204 | with open(results_savepath, "wb") as f: 205 | pickle.dump((df, df_equiv), f) 206 | return all_savepaths 207 | 208 | 209 | def plot_benchmark_results(all_savepaths): 210 | dfs, dfs_equiv = [], [] 211 | for savepath in all_savepaths: 212 | with open(savepath, "rb") as f: 213 | df, df_equiv = pickle.load(f) 214 | dfs.append(df) 215 | dfs_equiv.append(df_equiv) 216 | df = pd.concat(dfs, axis=0) 217 | df = df.reset_index() 218 | 219 | 220 | def select_benchmark_structures() -> pd.DataFrame: 221 | structure_basepath = os.path.join(PROCESSED_DATA_PATH, "prepared_structures", "pdbbind") 222 | pl_complex_xyzs = glob.glob(os.path.join(structure_basepath, "*", "pl_complex.xyz")) 223 | pdb_ids = [os.path.basename(os.path.dirname(x)) for x in pl_complex_xyzs] 224 | num_atoms = [] 225 | for pl_complex_xyz in tqdm(pl_complex_xyzs): 226 | with open(pl_complex_xyz, "r") as f: 227 | num_atoms.append(int(f.readlines()[0])) 228 | df = pd.DataFrame({"pdb_id": pdb_ids, "num_atoms": num_atoms, "pl_complex_xyz": pl_complex_xyzs}) 229 | df_sample = df.sort_values(by="num_atoms").iloc[:NUM_BENCHMARK_STRUCTURES] 230 | 231 | csv_savepath = os.path.join(XTB_VALIDATION_BASEPATH, "df_sample.csv") 232 | df_sample.to_csv(csv_savepath, index=False) 233 | print(f"Wrote df_sample to {csv_savepath}", flush=True) 234 | return df_sample 235 | 236 | 237 | if __name__ == "__main__": 238 | parser = argparse.ArgumentParser() 239 | parser.add_argument("--cluster_options", type=str, 
default="no_cluster") 240 | args = parser.parse_args() 241 | 242 | computed_results_basepath = os.path.join(XTB_VALIDATION_BASEPATH, "computed_on_euler") 243 | all_savepaths = summarize_benchmark_results(computed_results_basepath, force_recompute=True) 244 | plot_benchmark_results(all_savepaths) 245 | # df_sample = select_benchmark_structures() 246 | # job_ids = [] 247 | # for _, row in df_sample.iterrows(): 248 | # structure_basepath = os.path.dirname(row.pl_complex_xyz) 249 | # for level_of_theory in ["dft", "xtb"]: 250 | # job_id = submit_benchmark_job(structure_basepath, level_of_theory) 251 | -------------------------------------------------------------------------------- /bcpaff/qm/compute_wfn_dftb.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | import subprocess 9 | from typing import List, Optional 10 | 11 | from openbabel.pybel import readfile 12 | from rdkit import Chem 13 | 14 | from bcpaff.utils import DFTBPLUS_DATA_PATH, SEED 15 | 16 | periodic_table = Chem.GetPeriodicTable() 17 | hubbard_derivs_dict = { 18 | 1: "H = -0.1857", 19 | 8: "O = -0.1575", 20 | 35: "Br = -0.0573", 21 | 6: "C = -0.1492", 22 | 17: "Cl = -0.0697", 23 | 9: "F = -0.1623", 24 | 53: "I = -0.0433", 25 | 7: "N = -0.1535", 26 | 15: "P = -0.14", 27 | 16: "S = -0.11", 28 | 20: "Ca = -0.0340", 29 | 19: "K= -0.0339", 30 | 12: "Mg = -0.02", 31 | 11: "Na = -0.0454", 32 | 30: "Zn = -0.03", 33 | } # https://dftb.org/parameters/download/3ob/3ob-3-1-cc (accessed 06.02.23) 34 | 35 | 36 | def get_max_angular_momentum(atomic_num: int) -> str: 37 | """Figure out maximum angular momentum we should take into account 38 | 39 | Parameters 40 | ---------- 41 | atomic_num : int 42 | atomic number 43 | 44 | Returns 45 | ------- 46 | str 47 | descriptions of maximum angular momentum for given atom type 48 | """ 49 | element_symbol = periodic_table.GetElementSymbol(atomic_num) 50 | if atomic_num <= 2: # first period 51 | max_angular_momentum = "s" 52 | elif atomic_num <= 10: # second period 53 | max_angular_momentum = "p" 54 | elif atomic_num <= 18: # third period 55 | max_angular_momentum = "d" 56 | elif atomic_num > 18: 57 | max_angular_momentum = "f" 58 | return f"{element_symbol} = {max_angular_momentum}" 59 | 60 | 61 | def get_hubbard_derivs(atomicnums: List[int]) -> str: 62 | """Get Hubbard values for set of atomic numbers""" 63 | hubbard_str = " " 64 | for a in atomicnums: 65 | if a in hubbard_derivs_dict: 66 | hubbard_str += f"\n {hubbard_derivs_dict[a]}" 67 | hubbard_str += "\n }" 68 | return hubbard_str 69 | 70 | 71 | def generate_dftb_input_file(xyz_path: str, qm_method: str, implicit_solvent: Optional[str] = None): 72 | """Generate DFTB+ instructions""" 73 | 74 | basepath = os.path.dirname(xyz_path) 75 | json_path = os.path.join(basepath, "chrg_uhf.json") 76 | 77 | with open(json_path, "r") as f: 78 | chrg_uhf = json.load(f) 79 | 80 | mol = next(readfile("xyz", xyz_path)) 81 | atomicnums = sorted(set([a.atomicnum for a in mol.atoms])) 82 | max_angular_momentum = " " + "\n ".join([get_max_angular_momentum(a) for a in atomicnums]) 83 | 84 | if implicit_solvent == "water": 85 | solvation = "Solvation = GeneralisedBorn {\n " 86 | solvation += f' ParamFile = "{os.path.join(DFTBPLUS_DATA_PATH, "param_gbsa_h2o.txt")}"' 87 | solvation += "\n }\n" 88 | elif implicit_solvent is None: 89 | solvation = "" 90 | 91 | if chrg_uhf["num_unpaired_electrons"] == 0: 92 | spin_polarisation = "" 93 | else: 94 | 
spin_polarisation = "SpinPolarisation = Colinear {\n " 95 | spin_polarisation += f' UnpairedElectrons = {chrg_uhf["num_unpaired_electrons"]}' 96 | spin_polarisation += "\n }\n" 97 | 98 | if qm_method == "dftb3": 99 | corrections = "ThirdOrderFull = Yes\n" 100 | corrections += " Filling = Fermi {\n Temperature [K] = 300\n }\n" 101 | corrections += " HubbardDerivs {" 102 | corrections += get_hubbard_derivs(atomicnums) 103 | corrections += "\n HCorrection = Damping { \n Exponent = 4.00 \n }" 104 | else: 105 | corrections = "" 106 | 107 | cmd_txt = f""" 108 | Geometry = xyzFormat {{ 109 | <<< "{xyz_path}" 110 | }} 111 | 112 | Driver = {{}} 113 | 114 | Hamiltonian = DFTB {{ 115 | Charge = {chrg_uhf["charge"]} 116 | Scc = Yes 117 | MaxSCCIterations = 1000 118 | {corrections} 119 | SlaterKosterFiles = Type2FileNames {{ 120 | Prefix = "{os.path.join(DFTBPLUS_DATA_PATH, "recipes/slakos/download/3ob/3ob-3-1/")}" 121 | Separator = "-" 122 | Suffix = ".skf" 123 | }} 124 | MaxAngularMomentum {{ 125 | {max_angular_momentum} 126 | }} 127 | {solvation} 128 | {spin_polarisation} 129 | }} 130 | 131 | Options {{ 132 | WriteDetailedXML = Yes 133 | RandomSeed = {SEED} 134 | }} 135 | 136 | Analysis {{ 137 | WriteEigenvectors = Yes 138 | }} 139 | 140 | ParserOptions {{ 141 | ParserVersion = 7 142 | }} 143 | """ 144 | 145 | dftb_in_hsd = os.path.join(basepath, "dftb_in.hsd") 146 | with open(dftb_in_hsd, "w") as f: 147 | f.write(cmd_txt) 148 | 149 | 150 | def compute_wfn_dftb(xyz_path: str, qm_method: str, implicit_solvent: Optional[str] = None) -> str: 151 | """Run DFTB+ compute to obtain wavefunction (detailed.xml) 152 | 153 | Parameters 154 | ---------- 155 | xyz_path : str 156 | path to XYZ file of protein-ligand complex 157 | qm_method : str 158 | which method to use 159 | implicit_solvent : Optional[str], optional 160 | which implicit solvent to use, by default None 161 | 162 | Returns 163 | ------- 164 | str 165 | wavefunction savepath 166 | 167 | Raises 168 | ------ 169 | ValueError 170 | if running the DFTB+ computation failed 171 | """ 172 | basepath = os.path.dirname(xyz_path) 173 | 174 | generate_dftb_input_file(xyz_path, qm_method, implicit_solvent=implicit_solvent) 175 | 176 | cmd_line_output = os.path.join(basepath, "dftb_cmd_out.log") 177 | f = open(cmd_line_output, "w+") # write a new file each time 178 | cmd = ["dftb+"] 179 | completed_process = subprocess.run( 180 | cmd, 181 | stdout=f, 182 | stderr=subprocess.STDOUT, 183 | cwd=basepath, 184 | ) 185 | return_code = completed_process.returncode 186 | 187 | if return_code != 0: 188 | raise ValueError(f"{xyz_path} failed with {return_code}") 189 | f.close() 190 | wfn_savepath = os.path.join(basepath, "detailed.xml") 191 | return wfn_savepath 192 | 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--xyz_path", type=str, required=True) 197 | parser.add_argument("--qm_method", type=str, default="dftb") 198 | parser.add_argument("--solvent", type=str, default=None) 199 | args = parser.parse_args() 200 | wfn_savepath = os.path.join(os.path.dirname(args.xyz_path), "detailed.xml") 201 | if os.path.exists(wfn_savepath): 202 | print(f"detailed.xml already exists: {wfn_savepath}", flush=True) 203 | else: 204 | wfn_savepath = compute_wfn_dftb(args.xyz_path, args.qm_method, implicit_solvent=args.solvent) 205 | -------------------------------------------------------------------------------- /bcpaff/qm/compute_wfn_psi4.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import os 7 | import pickle 8 | import tempfile 9 | from shutil import rmtree 10 | 11 | import psi4 12 | 13 | PSI4_OPTIONS = {"basis": "def2-svp"} 14 | LEVEL_OF_THEORY = "wb97x-d" 15 | 16 | 17 | def compute_wfn_psi4( 18 | psi4_input_pickle, memory=8, num_cores=1, level_of_theory=LEVEL_OF_THEORY, basis_set=PSI4_OPTIONS["basis"] 19 | ): 20 | # read input data & construct molecule 21 | with open(psi4_input_pickle, "rb") as f: 22 | psi4_input = pickle.load(f) 23 | if not all([m == 1 for m in psi4_input["fragment_multiplicities"]]): 24 | raise ValueError(f"Radical electrons: {psi4_input_pickle}") 25 | p4mol = psi4.core.Molecule.from_arrays( 26 | elez=psi4_input["elez"], 27 | fragment_separators=psi4_input["fragment_separators"], 28 | fix_com=True, 29 | fix_orientation=True, 30 | fix_symmetry="c1", 31 | fragment_charges=psi4_input["fragment_charges"], 32 | fragment_multiplicities=psi4_input["fragment_multiplicities"], 33 | molecular_charge=psi4_input["molecular_charge"], 34 | molecular_multiplicity=psi4_input["molecular_multiplicity"], 35 | geom=psi4_input["geom"], 36 | ) 37 | 38 | # set scratch directory 39 | psi4_io = psi4.core.IOManager.shared_object() 40 | local_scratch = os.environ.get("TMPDIR") # local scratch directory on Euler 41 | if local_scratch: 42 | os.environ["PSI_SCRATCH"] = local_scratch 43 | psi4_io.set_default_path(local_scratch) 44 | print(f"Using local scratch {local_scratch}.", flush=True) 45 | else: # if job didn't request a local scratch directory, use global scratch 46 | tmp_dir = tempfile.mkdtemp() 47 | os.environ["PSI_SCRATCH"] = tmp_dir 48 | psi4_io.set_default_path(tmp_dir) 49 | print(f"Using tempfile scratch {tmp_dir}.", flush=True) 50 | # define psi4 settings 51 | psi4.core.clean() 52 | psi4.set_num_threads(num_cores) 53 | psi4.set_memory(f"{int(memory)} GB") 54 | PSI4_OPTIONS["basis"] = basis_set 55 | psi4.set_options(PSI4_OPTIONS) 56 | 57 | # calculate wavefunction 58 | E, wfn = psi4.energy(level_of_theory, return_wfn=True, molecule=p4mol) 59 | psi4.core.clean() 60 | if not local_scratch: 61 | rmtree(tmp_dir) 62 | wfn_file = os.path.join(os.path.dirname(psi4_input_pickle), "wfn.npy") 63 | fchk_savepath = os.path.join(os.path.dirname(psi4_input_pickle), "wfn.fchk") 64 | wfn.to_file(wfn_file) 65 | psi4.fchk(wfn, fchk_savepath) 66 | return fchk_savepath 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--psi4_input_pickle", type=str) 72 | parser.add_argument("--memory", type=int, default=8) 73 | parser.add_argument("--num_cores", type=int, default=4) 74 | parser.add_argument("--level_of_theory", type=str, default=LEVEL_OF_THEORY) 75 | parser.add_argument("--basis_set", type=str, default=PSI4_OPTIONS["basis"]) 76 | args = parser.parse_args() 77 | fchk_savepath = os.path.join(os.path.dirname(args.psi4_input_pickle), "wfn.fchk") 78 | npy_savepath = os.path.join(os.path.dirname(args.psi4_input_pickle), "wfn.npy") 79 | if os.path.exists(fchk_savepath) or os.path.exists(npy_savepath): 80 | print(f"Wavefunction already exists: {fchk_savepath}", flush=True) 81 | else: 82 | fchk_savepath = compute_wfn_psi4( 83 | args.psi4_input_pickle, 84 | memory=int(args.memory), 85 | num_cores=int(args.num_cores), 86 | level_of_theory=args.level_of_theory, 87 | basis_set=args.basis_set, 88 | ) 89 | print(f"Saved wavefunction to {fchk_savepath}") 90 | -------------------------------------------------------------------------------- /bcpaff/qm/compute_wfn_xtb.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | import subprocess 9 | 10 | from bcpaff.utils import ROOT_PATH 11 | 12 | XTB_INPUT_FILE = os.path.join(ROOT_PATH, "bcpaff", "qm", "xtb.inp") 13 | XTB_ENV = { 14 | "OMP_STACKSIZE": "4G", 15 | "OMP_NUM_THREADS": "1", 16 | "OMP_MAX_ACTIVE_LEVELS": "1", 17 | "MKL_NUM_THREADS": "1", 18 | } 19 | XTB_BINARY = os.path.join(os.environ.get("CONDA_PREFIX"), "bin", "xtb") 20 | 21 | 22 | def check_xtb_uhf(cmd_line_output): 23 | with open(cmd_line_output, "r") as f: 24 | for line in f: 25 | if line.startswith(" spin :"): 26 | spin = float(line.rstrip("\n").split()[-1]) 27 | break 28 | if spin != 0.0: 29 | error_file = os.path.join(os.path.dirname(cmd_line_output), "uhf_error.txt") 30 | with open(error_file, "w") as f: 31 | f.write(f"spin = {spin}") 32 | 33 | 34 | def compute_wfn_xtb(xyz_path, implicit_solvent=None): 35 | basepath = os.path.dirname(xyz_path) 36 | 37 | json_path = os.path.join(basepath, "chrg_uhf.json") 38 | with open(json_path, "r") as f: 39 | chrg_uhf = json.load(f) 40 | 41 | cmd_line_output = os.path.join(basepath, "xtb_cmd_out.log") 42 | f = open(cmd_line_output, "w+") # write a new file each time 43 | cmd = [ 44 | XTB_BINARY, 45 | xyz_path, 46 | "--input", 47 | XTB_INPUT_FILE, 48 | "--chrg", 49 | str(chrg_uhf["charge"]), 50 | "--uhf", 51 | str(chrg_uhf["num_unpaired_electrons"]), 52 | "--molden", 53 | "--iterations", 54 | "10000", 55 | ] 56 | if implicit_solvent is not None: 57 | cmd += [f"--{implicit_solvent.split('_')[0]}", implicit_solvent.split("_")[1]] 58 | # e.g. turn "alpb_water" into "--alpb water" 59 | completed_process = subprocess.run( 60 | cmd, 61 | stdout=f, 62 | stderr=subprocess.STDOUT, 63 | cwd=basepath, 64 | env=XTB_ENV, 65 | ) 66 | return_code = completed_process.returncode 67 | check_xtb_uhf(cmd_line_output) 68 | 69 | if return_code != 0: 70 | raise ValueError(f"{xyz_path} failed with {return_code}") 71 | f.close() 72 | molden_file = os.path.join(basepath, "molden.input") 73 | return molden_file 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument("--xyz_path", type=str, required=True) 79 | args = parser.parse_args() 80 | molden_savepath = os.path.join(os.path.dirname(args.xyz_path), "molden.input") 81 | if os.path.exists(molden_savepath): 82 | print(f"molden.input already exists: {molden_savepath}", flush=True) 83 | else: 84 | molden_file = compute_wfn_xtb(args.xyz_path) 85 | -------------------------------------------------------------------------------- /bcpaff/qm/prepare_cluster_job.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import glob 7 | import os 8 | 9 | from bcpaff.qm.compute_wfn_psi4 import LEVEL_OF_THEORY, PSI4_OPTIONS 10 | 11 | 12 | def prepare_cluster_job_psi4(folders, cmd_file_out, args): 13 | python_file = os.path.join(os.path.dirname(__file__), "compute_wfn.py") 14 | cmd_strs = [] 15 | for folder in folders: 16 | pdb_id = os.path.basename(folder) 17 | pocket_sdf = os.path.join(folder, f"{pdb_id}_pocket_with_hydrogens.sdf") 18 | ligand_sdf = os.path.join(folder, f"{pdb_id}_ligand_with_hydrogens.sdf") 19 | cmd_str = f"python {python_file} --ligand_sdf {ligand_sdf} --pocket_sdf {pocket_sdf} --memory {args.memory} --num_cores {args.num_cores} --level_of_theory {args.level_of_theory} --basis_set {args.basis_set}\n" 20 | 
cmd_strs.append(cmd_str) 21 | with open(cmd_file_out, "w") as f: 22 | f.writelines(cmd_strs) 23 | print(f"Wrote output to {cmd_file_out}") 24 | memory_per_core_in_mb = int(args.memory / args.num_cores * 1024) 25 | bsub_cmd = ( 26 | """bsub -W 24:00 -R "rusage[mem=""" 27 | + str(memory_per_core_in_mb) 28 | + """,scratch=50000]" -n """ 29 | + str(args.num_cores) 30 | ) 31 | bsub_cmd += ( 32 | """ -J "qm_bcpaff[1-""" 33 | + str(len(cmd_strs)) 34 | + """]" "awk -v jindex=\$LSB_JOBINDEX 'NR==jindex' """ 35 | + cmd_file_out 36 | + """ | bash" """ 37 | ) 38 | return bsub_cmd 39 | 40 | 41 | def prepare_cluster_job_xtb(folders, cmd_file_out, args): 42 | python_file = os.path.join(os.path.dirname(__file__), "compute_wfn_xtb.py") 43 | cmd_strs = [] 44 | for folder in folders: 45 | pdb_id = os.path.basename(folder) 46 | pocket_sdf = os.path.join(folder, f"{pdb_id}_pocket_with_hydrogens.sdf") 47 | ligand_sdf = os.path.join(folder, f"{pdb_id}_ligand_with_hydrogens.sdf") 48 | cmd_str = f"python {python_file} --ligand_sdf {ligand_sdf} --pocket_sdf {pocket_sdf}\n" 49 | cmd_strs.append(cmd_str) 50 | with open(cmd_file_out, "w") as f: 51 | f.writelines(cmd_strs) 52 | print(f"Wrote output to {cmd_file_out}") 53 | bsub_cmd = """ 54 | bsub -W 4:00 -R "rusage[mem=10000]" 55 | """ 56 | bsub_cmd += ( 57 | """ -J "qm_bcpaff[1-""" 58 | + str(len(cmd_strs)) 59 | + """]" "awk -v jindex=\$LSB_JOBINDEX 'NR==jindex' """ 60 | + cmd_file_out 61 | + """ | bash" """ 62 | ) 63 | return bsub_cmd 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("search_path", type=str) 69 | parser.add_argument("method", type=str, default="xtb") 70 | parser.add_argument("--memory", type=int, default=8) 71 | parser.add_argument("--num_cores", type=int, default=4) 72 | parser.add_argument("--level_of_theory", type=str, default=LEVEL_OF_THEORY) 73 | parser.add_argument("--basis_set", type=str, default=PSI4_OPTIONS["basis"]) 74 | args = parser.parse_args() 75 | folders = sorted(glob.glob(os.path.join(args.search_path, "*" + os.path.sep))) 76 | cmd_file_out = os.path.join(os.getcwd(), "qm_commands.txt") 77 | if os.path.exists(cmd_file_out): 78 | user_input = "" 79 | while user_input not in ["y", "n"]: 80 | user_input = input("Output file exists. 
Overwrite?").lower() 81 | if user_input == "n": 82 | raise ValueError(f"Commands file already exists") 83 | 84 | if args.method == "xtb": 85 | bsub_cmd = prepare_cluster_job_xtb(folders, cmd_file_out, args) 86 | elif args.method == "psi4": 87 | bsub_cmd = prepare_cluster_job_psi4(folders, cmd_file_out, args) 88 | else: 89 | raise ValueError(f"args.method must be xtb or psi4, you chose {args.method}") 90 | 91 | print("Use this to run your job:\n\n") 92 | print(f"conda activate bcpaff_{args.method}\n") 93 | print(bsub_cmd) 94 | -------------------------------------------------------------------------------- /bcpaff/qm/xtb.inp: -------------------------------------------------------------------------------- 1 | $write 2 | json=true -------------------------------------------------------------------------------- /bcpaff/qtaim/critic2_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | import subprocess 7 | 8 | from bcpaff.qtaim.qtaim_reader import QtaimPropsCritic2 9 | from bcpaff.utils import DATA_PATH, ROOT_PATH 10 | 11 | CRITIC2_BINARY = os.path.join(ROOT_PATH, "critic2", "src", "critic2") 12 | CRITIC2_ENV = {"CRITIC_HOME": os.path.join(ROOT_PATH, "critic2")} 13 | 14 | 15 | def generate_critic2_instructions(basepath): 16 | """ 17 | Generate instructions for critic2 calculation (search for critical points in DFTB+ wavefunction) 18 | """ 19 | cmd_str = f""" 20 | molecule {os.path.join(basepath, "pl_complex.xyz")} 21 | zpsp h 1 c 4 n 5 o 6 f 7 p 5 s 6 cl 7 br 7 i 7 22 | load {os.path.join(basepath, "detailed.xml")} {os.path.join(basepath, "eigenvec.bin")} {os.path.join(DATA_PATH, "dftb+/recipes/slakos/download/3ob/3ob-3-1/wfc.3ob-3-1.hsd")} core 23 | auto 24 | cpreport {os.path.join(basepath, "cps.xyz")} 25 | """ 26 | instructions_file = os.path.join(basepath, "input.cri") 27 | with open(instructions_file, "w") as f: 28 | f.write(cmd_str) 29 | return instructions_file 30 | 31 | 32 | def run_critic2_analysis(wfn_file): 33 | """Run critic2 to find critical points (Multiwfn doesn't work for DFTB+ generated wavefunctions.) 
34 | No need to specify charges, this is already read from the DFTB+ wavefunction file (see email 03.02.23) 35 | 36 | Parameters 37 | ---------- 38 | wfn_file : str 39 | detailed.xml file from DFTB+ 40 | 41 | Returns 42 | ------- 43 | str 44 | path to output.cri file with CP information 45 | 46 | Raises 47 | ------ 48 | ValueError 49 | in case subprocess returns non-zero return code 50 | """ 51 | basepath = os.path.dirname(wfn_file) 52 | output_cri = os.path.join(basepath, "output.cri") 53 | f_out = open(output_cri, "w+") # write a new file each time 54 | instructions_file = generate_critic2_instructions( 55 | basepath, 56 | ) 57 | completed_process = subprocess.run( 58 | [CRITIC2_BINARY, instructions_file], stdout=f_out, stderr=subprocess.STDOUT, cwd=basepath, env=CRITIC2_ENV 59 | ) 60 | f_out.close() 61 | return_code = completed_process.returncode 62 | if return_code != 0: 63 | raise ValueError(f"{wfn_file} failed with return_code {return_code}") 64 | return output_cri 65 | -------------------------------------------------------------------------------- /bcpaff/qtaim/generate_critical_points.txt: -------------------------------------------------------------------------------- 1 | 2 # Topology analysis 2 | 2 # Search CPs from nuclear positions 3 | 3 # Search CPs from midpoint of atom pairs 4 | 8 # Generating the paths connecting (3,-3) and (3,-1) CPs 5 | -5 # Modify or print detail or export paths, or plot property along a path 6 | 8 # Only retain bond paths (and corresponding CPs) connecting two specific molecular fragments while remove all other bond paths and BCPs 7 | 1-32 # atoms belonging to first fragment 8 | 33-20000 # atoms belonging to second fragment 9 | y # remove corresponding BCPs 10 | 4 # Save points of all paths to paths.txt in current folder 11 | 0 # Return 12 | -4 # Modify or export CPs (critical points) 13 | 4 # Save CPs to CPs.txt in current folder 14 | 0 # Return 15 | 7 # Show real space function values at specific CP or all CPs 16 | 0 # All properties (-1 to skip ESP → makes it faster) 17 | -10 # Return to main menu 18 | q # Exit program gracefully 19 | 20 | # OUTPUTS 21 | # CPProp.txt → real-space values of CPs 22 | # CP.txt → coordinates of CPs (might already be contained in CPProp.txt) 23 | # paths.txt → path connectivity -------------------------------------------------------------------------------- /bcpaff/qtaim/multiwfn_commands.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | 6 | def find_cps_and_paths(): 7 | return """ 8 | 2 # Topology analysis 9 | 2 # Search CPs from nuclear positions 10 | 3 # Search CPs from midpoint of atom pairs 11 | 8 # Generating the paths connecting (3,-3) and (3,-1) CPs 12 | -10 # Return to main menu 13 | """ 14 | 15 | 16 | def keep_only_intermolecular(atom_idxs1, atom_idxs2): 17 | return f""" 18 | 2 # Topology analysis 19 | -5 # Modify or print detail or export paths, or plot property along a path 20 | 8 # Only retain bond paths (and corresponding CPs) connecting two specific molecular fragments while remove all other bond paths and BCPs 21 | {atom_idxs1} # atoms belonging to first fragment 22 | {atom_idxs2} # atoms belonging to second fragment 23 | y # remove corresponding BCPs 24 | 0 # Return 25 | -10 # Return to main menu 26 | """ 27 | 28 | 29 | def remove_all_but_bcps(): 30 | return """ 31 | 2 # Topology analysis 32 | -4 # Modify or export CPs (critical points) 33 | 2 # Delete some CPs 34 | 3 # Delete all (3,-3) CPs 35 | 5 # Delete all (3,+1) CPs 
36 | 6 # Delete all (3,+3) CPs 37 | 0 # Return 38 | 0 # Return 39 | -10 # Return to main menu 40 | """ 41 | 42 | 43 | def save_paths(): 44 | return """ 45 | 2 # Topology analysis 46 | -5 # Modify or print detail or export paths, or plot property along a path 47 | 4 # Save points of all paths to paths.txt in current folder 48 | 0 # Return 49 | -10 # Return to main menu 50 | """ 51 | 52 | 53 | def save_cps(include_esp=True): 54 | props = 0 if include_esp else -1 55 | return f""" 56 | 2 # Topology analysis 57 | -4 # Modify or export CPs (critical points) 58 | 4 # Save CPs to CPs.txt in current folder 59 | 0 # Return 60 | 7 # Show real space function values at specific CP or all CPs 61 | {props} # 0 for all properties, -1 to skip ESP --> makes it faster 62 | -10 # Return to main menu 63 | """ 64 | 65 | 66 | def save_cps_to_pdb(): 67 | return """ 68 | 2 # Topology analysis 69 | -4 # Modify or export CPs (critical points) 70 | 6 # Export CPs as CPs.pdb file in current folder 71 | 0 # Return 72 | -10 # Return to main menu 73 | """ 74 | 75 | 76 | def save_paths_to_pdb(): 77 | return """ 78 | 2 # Topology analysis 79 | -5 # Modify or print detail or export paths, or plot property along a path 80 | 6 # Export paths as paths.pdb file in current folder 81 | 0 # Return 82 | -10 # Return to main menu 83 | """ 84 | 85 | 86 | def exit_gracefully(): 87 | return "q # Exit program gracefully" 88 | -------------------------------------------------------------------------------- /bcpaff/qtaim/multiwfn_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import argparse 6 | import os 7 | import subprocess 8 | from sys import platform 9 | 10 | from rdkit import Chem 11 | 12 | from bcpaff.qtaim import multiwfn_commands as cmds 13 | from bcpaff.utils import ROOT_PATH 14 | 15 | if platform == "linux" or platform == "linux2": 16 | MULTIWFN_BINARY = os.path.join(ROOT_PATH, "multiwfn", "Multiwfn_noGUI") 17 | elif platform == "darwin": 18 | MULTIWFN_BINARY = os.path.join(ROOT_PATH, "multiwfn", "Multiwfn_noGUI") 19 | MULTIWFN_ENV = {"Multiwfnpath": os.path.dirname(MULTIWFN_BINARY), "OMP_STACKSIZE": "200M"} 20 | 21 | 22 | def generate_instructions(basepath, only_intermolecular, only_bcps, num_ligand_atoms, include_esp=True): 23 | outstring = cmds.find_cps_and_paths() 24 | if only_intermolecular: 25 | atom_idxs1 = f"1-{num_ligand_atoms}" # ligand needs to be first in wavefunction geometry 26 | atom_idxs2 = f"{num_ligand_atoms+1}-99999" # everything else 27 | outstring += cmds.keep_only_intermolecular(atom_idxs1, atom_idxs2) 28 | if only_bcps: 29 | outstring += cmds.remove_all_but_bcps() 30 | outstring += cmds.save_paths() 31 | outstring += cmds.save_cps(include_esp=include_esp) 32 | outstring += cmds.save_cps_to_pdb() 33 | outstring += cmds.save_paths_to_pdb() 34 | outstring += cmds.exit_gracefully() 35 | outfile = os.path.join(basepath, "multiwfn_instructions.txt") 36 | with open(outfile, "w") as f: 37 | f.write(outstring) 38 | return outfile 39 | 40 | 41 | def run_multiwfn_analysis(wfn_file, only_intermolecular=False, only_bcps=False, num_ligand_atoms=-1, include_esp=True): 42 | basepath = os.path.dirname(wfn_file) 43 | multiwfn_cmd_out = os.path.join(basepath, "multiwfn_cmd_out.log") 44 | f_out = open(multiwfn_cmd_out, "w+") # write a new file each time 45 | instructions = generate_instructions( 46 | basepath, 47 | only_intermolecular=only_intermolecular, 48 | only_bcps=only_bcps, 49 | num_ligand_atoms=num_ligand_atoms, 50 | 
include_esp=include_esp, 51 | ) 52 | f_in = open(instructions, "r") 53 | completed_process = subprocess.run( 54 | [MULTIWFN_BINARY, wfn_file], stdin=f_in, stdout=f_out, stderr=subprocess.STDOUT, cwd=basepath, env=MULTIWFN_ENV 55 | ) 56 | f_in.close() 57 | f_out.close() 58 | return_code = completed_process.returncode 59 | if return_code != 0: 60 | raise ValueError(f"{wfn_file} failed with return_code {return_code}") 61 | cp_file = os.path.join(basepath, "CPs.txt") 62 | paths_file = os.path.join(basepath, "paths.txt") 63 | cpprop_file = os.path.join(basepath, "CPprop.txt") 64 | return cp_file, cpprop_file, paths_file 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--wfn_file", type=str, required=True) 70 | parser.add_argument("--ligand_sdf", type=str, required=True) 71 | parser.add_argument("--only_intermolecular", action="store_true", default=False) 72 | parser.add_argument("--only_bcps", action="store_true", default=False) 73 | parser.add_argument("--no_esp", action="store_false", default=True, dest="include_esp") 74 | args = parser.parse_args() 75 | cp_file, cpprop_file, paths_file = run_multiwfn_analysis( 76 | args.wfn_file, 77 | only_intermolecular=args.only_intermolecular, 78 | only_bcps=args.only_bcps, 79 | num_ligand_atoms=next(Chem.SDMolSupplier(args.ligand_sdf, removeHs=False, sanitize=False)).GetNumAtoms(), 80 | include_esp=args.include_esp, 81 | ) 82 | -------------------------------------------------------------------------------- /bcpaff/qtaim/qtaim_reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from rdkit import Chem 10 | from scipy.spatial.distance import cdist 11 | 12 | from bcpaff.utils import ANGSTROM2BOHR, BOHR2ANGSTROM 13 | 14 | SEP = "----------------" 15 | 16 | ESP_OFFSET = -2 17 | SIGMA_TO_NAME = { 18 | -3: "nucleus_critical_point", 19 | -1: "bond_critical_point", 20 | 1: "ring_critical_point", 21 | 3: "cage_critical_point", 22 | } 23 | QTAIM_UNITS = { 24 | "density": "electrons/Angstrom^3", 25 | "laplacian": "electrons/Angstrom^5", 26 | "elf": "-", 27 | "lol": "-", 28 | "rdg": "-", 29 | "sign_lap_rho": "electrons/Angstrom^3", 30 | "gradient_norm": "electrons/Angstrom^4", 31 | "hessian_eigenvalues": "-", 32 | "ellipticity": "-", 33 | "eta_index": "-", 34 | "esp": "J/C = eV/e", 35 | "position": "Angstrom", 36 | } 37 | 38 | ABBREVIATIONS = { 39 | "nucleus_critical_point": "NCP", 40 | "bond_critical_point": "BCP", 41 | "ring_critical_point": "RCP", 42 | "cage_critical_point": "CCP", 43 | } 44 | 45 | 46 | class QtaimProps(object): 47 | def __init__( 48 | self, 49 | basepath=None, 50 | cp_file=None, 51 | cpprop_file=None, 52 | paths_file=None, 53 | ligand_sdf=None, 54 | pl_complex_xyz=None, 55 | identifier=None, 56 | ): 57 | self.basepath = basepath 58 | self.identifier = identifier 59 | self.cp_file = cp_file 60 | self.paths_file = paths_file 61 | self.cpprop_file = cpprop_file 62 | self.pl_complex_xyz = pl_complex_xyz 63 | self.ligand_sdf = ligand_sdf 64 | if self.cp_file is None: 65 | self.identifier = os.path.basename(self.basepath) 66 | self.cp_file = os.path.join(self.basepath, "CPs.txt") 67 | self.paths_file = os.path.join(self.basepath, "paths.txt") 68 | self.cpprop_file = os.path.join(self.basepath, "CPprop.txt") 69 | self.pl_complex_xyz = os.path.join(self.basepath, "pl_complex.xyz") 70 | self.ligand_sdf = os.path.join(self.basepath, 
f"{self.identifier}_ligand_with_hydrogens.sdf") 71 | self.atom_ids, self.pl_complex_coords = self._get_pl_complex_info(self.pl_complex_xyz) 72 | self.ligand = next(Chem.SDMolSupplier(self.ligand_sdf, removeHs=False, sanitize=False)) 73 | self.natoms_ligand = self.ligand.GetNumAtoms() 74 | self.critical_points = [] 75 | self._read_critical_points() 76 | self.cp_positions = np.vstack([cp.position for cp in self.critical_points]) 77 | self._read_paths() 78 | self.num_cps = len(self.critical_points) 79 | self._get_df() 80 | 81 | def _get_df(self): 82 | df = pd.DataFrame([cp.idx for cp in self.critical_points], columns=["idx"]) 83 | df = pd.concat( 84 | [df, pd.DataFrame([ABBREVIATIONS[cp.name] for cp in self.critical_points], columns=["point_name"])], axis=1 85 | ) 86 | df = pd.concat( 87 | [df, pd.DataFrame([cp.intermolecular for cp in self.critical_points], columns=["intermolecular"])], axis=1 88 | ) 89 | df = pd.concat([df, pd.DataFrame([cp.props for cp in self.critical_points])], axis=1) 90 | df = pd.concat([df, pd.DataFrame(self.cp_positions, columns=["x", "y", "z"])], axis=1) 91 | atom_neighbors = [] 92 | for cp in self.critical_points: 93 | if len(cp.atom_neighbors) == 0: 94 | atom_neighbors.append([float("NaN"), float("NaN")]) 95 | elif len(cp.atom_neighbors) == 1: 96 | atom_neighbors.append(cp.atom_neighbors + [float("NaN")]) 97 | elif len(cp.atom_neighbors) == 2: 98 | atom_neighbors.append(sorted(cp.atom_neighbors)) 99 | else: 100 | raise ValueError("Incorrect number of atom neighbors") 101 | df = pd.concat([df, pd.DataFrame(atom_neighbors, columns=["atom_neighbor_1", "atom_neighbor_2"])], axis=1) 102 | self.df = df 103 | 104 | def _get_pl_complex_info(self, pl_complex_xyz): 105 | with open(pl_complex_xyz, "r") as f: 106 | lines = [l.rstrip("\n").split() for l in f.readlines()[2:]] 107 | atom_ids, coords = zip(*[(l[0], [float(x) for x in l[1:]]) for l in lines]) 108 | return np.asarray(atom_ids), np.asarray(coords) 109 | 110 | def _read_critical_points(self): 111 | with open(self.cpprop_file, "r") as f: 112 | lines = [line.lstrip(" ").rstrip("\n") for line in f.readlines()] 113 | self.include_esp = True if any([line.startswith("Total ESP") for line in lines]) else False 114 | split_idx = [i for i, line in enumerate(lines) if (line.startswith(SEP) and line.endswith(SEP))] 115 | split_idx.remove(0) 116 | blocks = [lines[i:j] for i, j in zip([0] + split_idx, split_idx + [None])] 117 | self.num_paths = len(blocks) 118 | for block in blocks: 119 | cp = CriticalPoint( 120 | block, self.include_esp, pl_complex_coords=self.pl_complex_coords, atom_ids=self.atom_ids 121 | ) 122 | if cp.name == "nucleus_critical_point" and cp.corresponding_atom_symbol == "Unknown": 123 | continue # don't append NCPs with unknown atom 124 | self.critical_points.append(cp) 125 | 126 | def _read_paths(self): 127 | with open(self.paths_file, "r") as f: 128 | lines = [line.lstrip(" ").rstrip("\n") for line in f.readlines()] 129 | split_idx = [i for i, line in enumerate(lines) if line == ""] 130 | split_idx.remove(1) 131 | blocks = [lines[i:j] for i, j in zip([1] + split_idx, split_idx + [None])] 132 | for block in blocks: 133 | path_positions = np.array([[x for x in line.split()] for line in block[3:]]).astype(float) * BOHR2ANGSTROM 134 | 135 | # path always starts at critical point (path_positions[0, None]) and ends at atom (path_positions[-1, None]) 136 | atom_id = cdist(path_positions[-1, None], self.pl_complex_coords).argmin() 137 | cp_id = cdist(path_positions[0, None], self.cp_positions).argmin() 138 | 
self.critical_points[cp_id]._add_path(atom_id, self.atom_ids[atom_id], path_positions, self.natoms_ligand) 139 | 140 | 141 | class CriticalPoint(object): 142 | def __init__(self, block, include_esp, pl_complex_coords, atom_ids): 143 | self.include_esp = include_esp 144 | self.esp_offset = 0 if include_esp else ESP_OFFSET 145 | self.pl_complex_coords = pl_complex_coords 146 | self.atom_ids = atom_ids 147 | self.atom_neighbors = [] 148 | self.atom_neighbors_symbol = [] 149 | self.path_positions = None 150 | self._read_block(block) 151 | self.intermolecular = False # can be overwritten by _add_paths later 152 | 153 | def _read_block(self, block): 154 | for line in block: 155 | if line.startswith(SEP): 156 | self.idx = int(line.strip("- ").split("Type")[0].strip("CP, ")) 157 | omega, sigma = line.strip("- ").split("Type")[1].strip("() ").split(",") 158 | self.omega = int(omega) 159 | self.sigma = int(sigma) 160 | self.name = SIGMA_TO_NAME[self.sigma] 161 | elif line.startswith("Corresponding nucleus:"): # only for nucleus-critical points 162 | corresponding_nucleus = line.lstrip("Corresponding nucleus: ").split("(")[0] 163 | if corresponding_nucleus == "Unknown": # no corresponding nucleus found 164 | self.corresponding_atom_symbol = "Unknown" # figured this out later from position 165 | else: 166 | self.corresponding_atom_id = int(corresponding_nucleus) - 1 # zero indexing 167 | self.corresponding_atom_symbol = line.lstrip("Corresponding nucleus: ").split("(")[-1].rstrip(" )") 168 | elif line.startswith("Position (Bohr):"): 169 | self.position = np.array([float(x) * BOHR2ANGSTROM for x in line.lstrip("Position (Bohr): ").split()]) 170 | if ( 171 | self.name == "nucleus_critical_point" and self.corresponding_atom_symbol == "Unknown" 172 | ): # need to fix 173 | distance_matrix = cdist(np.expand_dims(self.position, axis=0), self.pl_complex_coords) 174 | assert distance_matrix.min() < 0.3 # sanity check, 0.3 Angstrom = manually inspected cutoff 175 | self.corresponding_atom_id = distance_matrix.argmin() 176 | self.corresponding_atom_symbol = self.atom_ids[self.corresponding_atom_id] 177 | elif line.startswith("Density of all electrons:"): 178 | self.density = float(line.lstrip("Density of all electrons: ")) * (ANGSTROM2BOHR ** 3) 179 | elif line.startswith("Laplacian of electron density:"): 180 | self.laplacian = float(line.lstrip("Laplacian of electron density: ")) * (ANGSTROM2BOHR ** 5) 181 | elif line.startswith("Electron localization function (ELF):"): 182 | self.elf = float(line.lstrip("Electron localization function (ELF): ")) 183 | elif line.startswith("Localized orbital locator (LOL):"): 184 | self.lol = float(line.lstrip("Localized orbital locator (LOL): ")) 185 | elif line.startswith("Reduced density gradient (RDG):"): 186 | self.rdg = float(line.lstrip("Reduced density gradient (RDG): ")) 187 | elif line.startswith("Sign(lambda2)*rho:"): 188 | self.sign_lap_rho = float(line.lstrip("Sign(lambda2)*rho: ")) * (ANGSTROM2BOHR ** 3) 189 | elif line.startswith("Total ESP:") and self.include_esp: 190 | self.esp = float(line.lstrip("Total ESP: ").split("(")[0].split("a.u.")[0].strip(" ")) 191 | elif line.startswith("ESP from nuclear charges: ") and self.include_esp: 192 | self.esp_nuc = float(line.lstrip("ESP from nuclear charges: ")) 193 | elif line.startswith("ESP from electrons:") and self.include_esp: 194 | self.esp_ele = float(line.lstrip("ESP from electrons: ")) 195 | elif line.startswith("Norm of gradient is:"): 196 | self.gradient_norm = float(line.lstrip("Norm of gradient is: ")) * 
(ANGSTROM2BOHR ** 4) 197 | elif line.startswith("Eigenvalues of Hessian:"): 198 | matrix = np.array([float(x) for x in line.lstrip("Eigenvalues of Hessian: ").split()]) 199 | self.hessian_eigenvalues = matrix 200 | elif line.startswith("Ellipticity of electron density:"): 201 | self.ellipticity = float(line.lstrip("Ellipticity of electron density: ")) 202 | elif line.startswith("eta index:"): 203 | self.eta_index = float(line.lstrip("eta index: ")) 204 | 205 | self.props = { 206 | "density": self.density, 207 | "laplacian": self.laplacian, 208 | "elf": self.elf, 209 | "lol": self.lol, 210 | "rdg": self.rdg, 211 | "sign_lap_rho": self.sign_lap_rho, 212 | "gradient_norm": self.gradient_norm, 213 | "hessian_eigenvalues": self.hessian_eigenvalues, 214 | "ellipticity": self.ellipticity, 215 | "eta_index": self.eta_index, 216 | } 217 | if self.include_esp: 218 | self.props.update({"esp": self.esp, "esp_nuc": self.esp_nuc, "esp_ele": self.esp_ele}) 219 | 220 | def _add_path(self, atom_id, atom_symbol, path_positions, natoms_ligand): 221 | self.atom_neighbors.append(atom_id) 222 | self.atom_neighbors_symbol.append(atom_symbol) 223 | if self.path_positions is None: 224 | self.path_positions = np.flip( 225 | path_positions, axis=0 226 | ) # so that complete array goes from atom1 --> BCP --> atom2 227 | # the bond path doesn't end perfectly at the atom, but just before (distance ca. 0.05 Angstrom) 228 | else: 229 | self.path_positions = np.vstack([self.path_positions, path_positions]) 230 | if sum([a <= natoms_ligand - 1 for a in self.atom_neighbors]) == 1: 231 | # one and only one atom neighbor belongs to ligand --> intermolecular 232 | self.intermolecular = True 233 | else: 234 | # either none or both atom neighbors belong to ligand --> not intermolecular 235 | self.intermolecular = False 236 | 237 | 238 | class CriticalPointCritic2(object): 239 | def __init__(self, i, row, pl_complex_coords, atom_ids): 240 | self.idx = i 241 | assert self.idx == row.cp_idx - 1 # zero indexing 242 | self.pl_complex_coords = pl_complex_coords 243 | self.atom_ids = atom_ids 244 | self.atom_neighbors = [] 245 | self.atom_neighbors_symbol = [] 246 | self.path_positions = None 247 | # self._read_block(block) 248 | self.intermolecular = False # can be overwritten by _add_paths later 249 | self.density = row.edens * (ANGSTROM2BOHR ** 3) # convert to electrons/Angstrom^3 etc. 
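        # The values parsed from output.cri are treated as atomic units (per Bohr^n);
        # multiplying by ANGSTROM2BOHR**n converts density, gradient norm, and Laplacian
        # to the per-Angstrom units listed in QTAIM_UNITS, consistent with the
        # Multiwfn-based CriticalPoint class above.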
250 | self.gradient_norm = row.grad * (ANGSTROM2BOHR ** 4) 251 | self.laplacian = row.lap * (ANGSTROM2BOHR ** 5) 252 | omega, sigma = row.type.strip("()").split(",") 253 | self.omega, self.sigma = int(omega), int(sigma) 254 | self.name = SIGMA_TO_NAME[self.sigma] 255 | if row.name == "nucleus_critical_point": 256 | self.corresponding_atom_id = self.idx # already zero-indexed above 257 | self.corresponding_atom_symbol = row.type_name 258 | assert self.atom_ids[self.corresponding_atom_id] == self.corresponding_atom_symbol # sanity check 259 | self.position = row[["x", "y", "z"]].to_numpy().astype(float) 260 | self.props = { 261 | "density": self.density, 262 | "laplacian": self.laplacian, 263 | "gradient_norm": self.gradient_norm, 264 | } 265 | 266 | def _add_path(self, atom_ids, atom_symbols, path_positions, natoms_ligand): 267 | self.atom_neighbors = atom_ids 268 | self.atom_neighbors_symbol = atom_symbols 269 | self.path_positions = path_positions 270 | 271 | if sum([a <= natoms_ligand - 1 for a in self.atom_neighbors]) == 1: 272 | # one and only one atom neighbor belongs to ligand --> intermolecular 273 | self.intermolecular = True 274 | else: 275 | # either none or both atom neighbors belong to ligand --> not intermolecular 276 | self.intermolecular = False 277 | 278 | 279 | class QtaimPropsCritic2(object): 280 | def __init__(self, basepath=None, output_cri=None, pl_complex_xyz=None, ligand_sdf=None): 281 | self.basepath = basepath 282 | self.output_cri = output_cri if output_cri is not None else os.path.join(basepath, "output.cri") 283 | self.pl_complex_xyz = ( 284 | pl_complex_xyz if pl_complex_xyz is not None else os.path.join(basepath, "pl_complex.xyz") 285 | ) 286 | if ligand_sdf is not None: 287 | self.ligand_sdf = ligand_sdf 288 | self.ligand = next(Chem.SDMolSupplier(self.ligand_sdf, removeHs=False, sanitize=False)) 289 | else: 290 | self.ligand_sdf = self.pl_complex_xyz # everything is ligand 291 | self.ligand = Chem.rdmolfiles.MolFromXYZFile(self.ligand_sdf) 292 | self.natoms_ligand = self.ligand.GetNumAtoms() 293 | self.atom_ids, self.pl_complex_coords = self._get_pl_complex_info(self.pl_complex_xyz) 294 | 295 | self.critical_points = [] 296 | self._read_critical_points() 297 | self.cp_positions = np.vstack([cp.position for cp in self.critical_points]) 298 | self._read_paths() 299 | self.num_cps = len(self.critical_points) 300 | 301 | def _get_pl_complex_info(self, pl_complex_xyz): 302 | with open(pl_complex_xyz, "r") as f: 303 | lines = [l.rstrip("\n").split() for l in f.readlines()[2:]] 304 | atom_ids, coords = zip(*[(l[0], [float(x) for x in l[1:]]) for l in lines]) 305 | return np.asarray(atom_ids), np.asarray(coords) 306 | 307 | def _read_critical_points(self): 308 | with open(self.output_cri, "r") as f: 309 | lines = [line.lstrip(" ").rstrip("\n") for line in f.readlines()] 310 | SEP_CRITIC2 = "Poincare-Hopf sum:" 311 | start_idx = [i for i, line in enumerate(lines) if line.startswith(SEP_CRITIC2)] 312 | assert len(start_idx) == 1 313 | start_idx = start_idx[0] 314 | self.poincare_hopf_sum = int(lines[start_idx].split(": ")[-1]) 315 | self.found_all_cps = self.poincare_hopf_sum == 1 316 | if self.poincare_hopf_sum != 1: 317 | with open(os.path.join(os.path.dirname(self.output_cri), "missing_points"), "w") as f: 318 | f.write(f"Poincare-Hopf sum = {self.poincare_hopf_sum}") 319 | end_idx = lines.index("* Analysis of system bonds") 320 | df_lines = lines[start_idx + 2 : end_idx - 1] 321 | df_lines = [l.split() for l in df_lines] 322 | # sometimes splitting gets messed up 
because of "(3,1 )" 323 | for i, l in enumerate(df_lines): 324 | if len(l) == 11: # instead of the usual 10 325 | l = [l[0]] + [l[1] + l[2]] + l[3:] # string concat for elems 1 & 2 326 | df_lines[i] = l 327 | df = pd.DataFrame( 328 | df_lines, columns=["cp_idx", "type", "name", "x", "y", "z", "type_name", "edens", "grad", "lap"] 329 | ) 330 | dtypes = { 331 | "cp_idx": int, 332 | "type": str, 333 | "name": str, 334 | "x": float, 335 | "y": float, 336 | "z": float, 337 | "type_name": str, 338 | "edens": float, 339 | "grad": float, 340 | "lap": float, 341 | } 342 | df = df.astype(dtypes) 343 | 344 | for i, row in df.iterrows(): 345 | cp = CriticalPointCritic2(i, row, self.pl_complex_coords, self.atom_ids) 346 | self.critical_points.append(cp) 347 | 348 | def _read_paths(self): 349 | with open(self.output_cri, "r") as f: 350 | lines = [line.lstrip(" ").rstrip("\n") for line in f.readlines()] 351 | start_idx = lines.index("# ncp End-1 End-2 r1(ang_) r2(ang_) r1/r2 r1-B-r2 (degree)") 352 | end_idx = lines.index("* Analysis of system rings") 353 | df_lines = lines[start_idx + 1 : end_idx - 1] 354 | df_lines = [l.split()[:5] for l in df_lines] 355 | df_lines = [[x.strip("()") for x in l] for l in df_lines] 356 | df = pd.DataFrame( 357 | df_lines, columns=["cp_idx", "neighbor_type_1", "neighbor_id_1", "neighbor_type_2", "neighbor_id_2"] 358 | ) 359 | dtypes = { 360 | "cp_idx": int, 361 | "neighbor_type_1": str, 362 | "neighbor_id_1": int, 363 | "neighbor_type_2": str, 364 | "neighbor_id_2": int, 365 | } 366 | df = df.astype(dtypes) 367 | df.loc[:, "cp_idx"] = df.cp_idx - 1 # zero indexing 368 | df.loc[:, "neighbor_id_1"] = df.neighbor_id_1 - 1 # zero indexing 369 | df.loc[:, "neighbor_id_2"] = df.neighbor_id_2 - 1 # zero indexing 370 | 371 | for _, row in df.iterrows(): 372 | path_positions = self.cp_positions[[row.neighbor_id_1, row.cp_idx, row.neighbor_id_2]] 373 | self.critical_points[row.cp_idx]._add_path( 374 | [row.neighbor_id_1, row.neighbor_id_2], 375 | [row.neighbor_type_1, row.neighbor_type_2], 376 | path_positions, 377 | self.natoms_ligand, 378 | ) 379 | 380 | -------------------------------------------------------------------------------- /bcpaff/qtaim/qtaim_viewer.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | import py3Dmol 9 | from rdkit import Chem 10 | 11 | 12 | def get_radii(x): 13 | MAX_RADIUS = 0.8 # very transparent, visually implies low importance 14 | MIN_RADIUS = 0.2 # not very transparent (quite solid color), visually implies high importance 15 | x = np.abs(x) 16 | x = x - np.min(x) 17 | x = x / np.max(x) 18 | x = x * (MAX_RADIUS - MIN_RADIUS) + MIN_RADIUS 19 | return x 20 | 21 | 22 | class QtaimViewer(object): 23 | def __init__( 24 | self, 25 | qtaim_props, 26 | only_intermolecular=True, 27 | detailed_paths=False, 28 | attributions_data=None, 29 | width=640, 30 | height=480, 31 | ): 32 | self.v = py3Dmol.view(width=width, height=height) 33 | with open(os.path.join(os.path.dirname(qtaim_props.ligand_sdf), "pl_complex.xyz"), "r") as f: 34 | xyz_str = f.read() 35 | self.v.addModel(xyz_str, "xyz") 36 | self.v.setStyle({"model": 0}, {"stick": {"colorscheme": "lightgrayCarbon", "radius": 0.1}}) 37 | self.v.addModel(Chem.MolToMolBlock(qtaim_props.ligand, kekulize=False), "mol") 38 | self.v.setStyle({"model": 1}, {"stick": {"colorscheme": "blackCarbon", "radius": 0.2}}) 39 | self.v.setBackgroundColor("white") 40 | self.v.zoomTo() 41 | for cp in 
qtaim_props.critical_points: 42 | if only_intermolecular: 43 | if not (cp.name == "bond_critical_point" and cp.intermolecular): 44 | continue 45 | if cp.name != "bond_critical_point": 46 | continue 47 | if detailed_paths: 48 | points = cp.path_positions 49 | else: 50 | points = cp.path_positions[np.round(np.linspace(0, len(cp.path_positions) - 1, 5)).astype(int)] 51 | # first, last, and some in between (need first & last to have proper attachment to atoms) 52 | points = [{key: val for (key, val) in zip(["x", "y", "z"], pos)} for pos in points] 53 | 54 | self.v.addCurve({"points": points, "radius": 0.05, "color": "yellow"}) 55 | self.v.addSphere( 56 | { 57 | "center": {key: val for (key, val) in zip(["x", "y", "z"], cp.position)}, 58 | "radius": 0.1, 59 | "color": "red", 60 | } 61 | ) 62 | if attributions_data is not None: 63 | # map attribution values to transparency values and colors 64 | radii = get_radii(attributions_data["attributions"]) 65 | colors = ["green" if x > 0 else "red" for x in attributions_data["attributions"]] 66 | 67 | # plot 68 | for xyz, c, r in zip(attributions_data["coords"], colors, radii): 69 | self.v.addSphere( 70 | { 71 | "center": {key: val for (key, val) in zip(["x", "y", "z"], xyz.tolist())}, 72 | "radius": float(r), 73 | "color": c, 74 | } 75 | ) 76 | 77 | def show(self): 78 | return self.v.show() 79 | -------------------------------------------------------------------------------- /bcpaff/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | © 2023, ETH Zurich 3 | """ 4 | 5 | import os 6 | 7 | import pandas as pd 8 | 9 | ROOT_PATH = os.path.dirname(os.path.dirname(__file__)) 10 | DATA_PATH = os.path.join(ROOT_PATH, "data") 11 | PROCESSED_DATA_PATH = os.path.join(ROOT_PATH, "processed_data") 12 | BASE_OUTPUT_DIR = os.path.join(PROCESSED_DATA_PATH, "model_runs") 13 | REPORT_PATH = os.path.join(PROCESSED_DATA_PATH, "reports") 14 | ANALYSIS_PATH = os.path.join(PROCESSED_DATA_PATH, "analysis") 15 | paths = [DATA_PATH, PROCESSED_DATA_PATH, BASE_OUTPUT_DIR, REPORT_PATH, ANALYSIS_PATH] 16 | for path in paths: 17 | os.makedirs(path, exist_ok=True) 18 | 19 | ELEMENTS = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"] 20 | 21 | ATOM_NEIGHBOR_IDS = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4, "P": 5, "S": 6, "Cl": 7, "Br": 8, "I": 9, "*": 10} 22 | OTHER = len(ATOM_NEIGHBOR_IDS) # no +1 needed since 0-indexed 23 | METALS = [ 24 | "Li", 25 | "Be", 26 | "Na", 27 | "Mg", 28 | "Al", 29 | "K", 30 | "Ca", 31 | "Sc", 32 | "Ti", 33 | "V", 34 | "Cr", 35 | "Mn", 36 | "Fe", 37 | "Co", 38 | "Ni", 39 | "Cu", 40 | "Zn", 41 | "Ga", 42 | "Rb", 43 | "Sr", 44 | "Y", 45 | "Zr", 46 | "Nb", 47 | "Mo", 48 | "Tc", 49 | "Ru", 50 | "Rh", 51 | "Pd", 52 | "Ag", 53 | "Cd", 54 | "In", 55 | "Sn", 56 | "Cs", 57 | "Ba", 58 | "La", 59 | "Ce", 60 | "Pr", 61 | "Nd", 62 | "Pm", 63 | "Sm", 64 | "Eu", 65 | "Gd", 66 | "Tb", 67 | "Dy", 68 | "Ho", 69 | "Er", 70 | "Tm", 71 | "Yb", 72 | "Lu", 73 | "Hf", 74 | "Ta", 75 | "W", 76 | "Re", 77 | "Os", 78 | "Ir", 79 | "Pt", 80 | "Au", 81 | "Hg", 82 | "Tl", 83 | "Pb", 84 | "Bi", 85 | "Po", 86 | "Fr", 87 | "Ra", 88 | "Ac", 89 | "Th", 90 | "Pa", 91 | "U", 92 | "Np", 93 | "Pu", 94 | "Am", 95 | "Cm", 96 | "Bk", 97 | "Cf", 98 | "Es", 99 | "Fm", 100 | "Md", 101 | "No", 102 | "Lr", 103 | "Rf", 104 | "Db", 105 | "Sg", 106 | "Bh", 107 | "Hs", 108 | "Mt", 109 | "Ds", 110 | "Rg", 111 | "Cn", 112 | "Nh", 113 | "Fl", 114 | ] 115 | ELEMENT_NUMS = [1, 6, 7, 8, 9, 15, 16, 17, 35, 53] 116 | ATOM_DICT = { 117 | 0: 0, # BCPs 118 | 1: 1, # 
hydrogen 119 | 6: 2, # carbon 120 | 7: 3, # nitrogen 121 | 8: 4, # oxygen 122 | 9: 5, # fluorine 123 | 15: 6, # phosphorus 124 | 16: 7, # sulphure 125 | 17: 8, # chlorine 126 | 35: 9, # bromine 127 | 53: 10, # iodine 128 | } 129 | 130 | SEED = 1234 131 | BOHR2ANGSTROM = 0.529177249 132 | ANGSTROM2BOHR = 1 / BOHR2ANGSTROM 133 | 134 | DEFAULT_PROPS = [ 135 | "density", 136 | "laplacian", 137 | "elf", 138 | "lol", 139 | "rdg", 140 | "sign_lap_rho", 141 | "gradient_norm", 142 | "ellipticity", 143 | "eta_index", 144 | ] 145 | DEFAULT_PROPS_CRITIC2 = ["density", "laplacian", "gradient_norm"] 146 | HPARAMS = pd.read_csv(os.path.join(ROOT_PATH, "hparam_files", "hparams_bcp_props.csv")) 147 | DFTBPLUS_DATA_PATH = os.path.join(DATA_PATH, "dftb+") 148 | 149 | 150 | DATASETS_AND_SPLITS = { 151 | "pdbbind": ["random"], 152 | "pde10a": [ 153 | "random", 154 | "temporal_2011", 155 | "temporal_2012", 156 | "temporal_2013", 157 | "aminohetaryl_c1_amide", 158 | "c1_hetaryl_alkyl_c2_hetaryl", 159 | "aryl_c1_amide_c2_hetaryl", 160 | ], 161 | } 162 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: bcpaff 2 | channels: 3 | - pyg 4 | - pytorch 5 | - conda-forge 6 | - anaconda 7 | - defaults 8 | dependencies: 9 | - babel=2.9.1 10 | - black=22.1.0 11 | - click=8.0.2 12 | - cloudpickle=2.0.0 13 | - cpuonly=1.0 14 | - cudatoolkit=10.2 15 | - einops=0.4.1 16 | - htop=3.1.2 17 | - ipykernel=6.4.1 18 | - ipython=7.29.0 19 | - ipython_genutils=0.2.0 20 | - ipywidgets=7.6.5 21 | - joblib=1.1.0 22 | - jupyter=1.0.0 23 | - matplotlib=3.5.1 24 | - mkl=2021.4.0 25 | - networkx=2.6.3 26 | - notebook=6.4.12 27 | - numpy=1.20.3 28 | - numpy-base=1.20.3 29 | - openbabel=3.1.1 30 | - pandas=1.3.5 31 | - pickleshare=0.7.5 32 | - pip=21.2.2 33 | - py3dmol=1.8.0 34 | - python=3.7.11 35 | - pytorch=1.9.1 36 | - pytorch-scatter=2.0.9 37 | - pytorch-sparse=0.6.12 38 | - pytorch-spline-conv=1.2.1 39 | - rdkit=2021.09.4 40 | - requests=2.27.1 41 | - scikit-learn=1.0.2 42 | - scipy=1.6.2 43 | - tensorboard=2.9.1 44 | - tqdm=4.62.3 45 | - typed-ast=1.4.3 46 | - typing-extensions=4.3.0 47 | - werkzeug=2.0.3 48 | - xtb=6.4.1 49 | - yaml=0.2.5 50 | - pip: 51 | - torch-geometric==2.0.3 52 | -------------------------------------------------------------------------------- /env_psi4.yml: -------------------------------------------------------------------------------- 1 | name: bcpaff_psi4_new 2 | channels: 3 | - psi4 4 | - pyg 5 | - pytorch 6 | - conda-forge 7 | - anaconda 8 | - defaults 9 | dependencies: 10 | - babel=2.9.1 11 | - psi4=1.7 12 | - black=22.1.0 13 | - click=8.0.2 14 | - cloudpickle=2.0.0 15 | - cpuonly=1.0 16 | - einops=0.4.1 17 | - htop=3.1.2 18 | - ipykernel=6.4.1 19 | - ipython=7.29.0 20 | - ipython_genutils=0.2.0 21 | - ipywidgets=7.6.5 22 | - joblib=1.1.0 23 | - jupyter=1.0.0 24 | - matplotlib=3.5.1 25 | - mkl=2021.4.0 26 | - networkx=2.6.3 27 | - notebook=6.4.12 28 | - numpy=1.20.3 29 | - numpy-base=1.20.3 30 | - openbabel=3.1.1 31 | - pandas=1.3.5 32 | - pickleshare=0.7.5 33 | - pip=21.2.2 34 | - py3dmol=1.8.0 35 | - python=3.9 36 | - pytorch=1.9.1 37 | - pytorch-scatter=2.0.9 38 | - pytorch-sparse=0.6.12 39 | - pytorch-spline-conv=1.2.1 40 | - rdkit=2021.09.4 41 | - requests=2.27.1 42 | - scikit-learn=1.0.2 43 | - scipy=1.6.2 44 | - tensorboard=2.9.1 45 | - tqdm=4.62.3 46 | - typed-ast=1.4.3 47 | - typing-extensions=4.3.0 48 | - werkzeug=2.0.3 49 | - yaml=0.2.5 50 | - pip: 51 | - 
torch-geometric==2.0.3 52 | -------------------------------------------------------------------------------- /hparam_files/hparams_bcp_atom_ids.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,True,mean,mean,nnnnnnnnn,5,False 3 | run_00001,16,16,256,6,True,mean,mean,nnnnnnnnn,5,False 4 | run_00002,16,16,512,6,True,mean,mean,nnnnnnnnn,5,False 5 | run_00003,16,32,128,6,True,mean,mean,nnnnnnnnn,5,False 6 | run_00004,16,32,256,6,True,mean,mean,nnnnnnnnn,5,False 7 | run_00005,16,32,512,6,True,mean,mean,nnnnnnnnn,5,False 8 | run_00006,16,64,128,6,True,mean,mean,nnnnnnnnn,5,False 9 | run_00007,16,64,256,6,True,mean,mean,nnnnnnnnn,5,False 10 | run_00008,16,64,512,6,True,mean,mean,nnnnnnnnn,5,False 11 | run_00009,16,128,128,6,True,mean,mean,nnnnnnnnn,5,False 12 | run_00010,16,128,256,6,True,mean,mean,nnnnnnnnn,5,False 13 | run_00011,16,128,512,6,True,mean,mean,nnnnnnnnn,5,False 14 | run_00012,32,16,128,6,True,mean,mean,nnnnnnnnn,5,False 15 | run_00013,32,16,256,6,True,mean,mean,nnnnnnnnn,5,False 16 | run_00014,32,16,512,6,True,mean,mean,nnnnnnnnn,5,False 17 | run_00015,32,32,128,6,True,mean,mean,nnnnnnnnn,5,False 18 | run_00016,32,32,256,6,True,mean,mean,nnnnnnnnn,5,False 19 | run_00017,32,32,512,6,True,mean,mean,nnnnnnnnn,5,False 20 | run_00018,32,64,128,6,True,mean,mean,nnnnnnnnn,5,False 21 | run_00019,32,64,256,6,True,mean,mean,nnnnnnnnn,5,False 22 | run_00020,32,64,512,6,True,mean,mean,nnnnnnnnn,5,False 23 | run_00021,32,128,128,6,True,mean,mean,nnnnnnnnn,5,False 24 | run_00022,32,128,256,6,True,mean,mean,nnnnnnnnn,5,False 25 | run_00023,32,128,512,6,True,mean,mean,nnnnnnnnn,5,False 26 | run_00024,64,16,128,6,True,mean,mean,nnnnnnnnn,5,False 27 | run_00025,64,16,256,6,True,mean,mean,nnnnnnnnn,5,False 28 | run_00026,64,16,512,6,True,mean,mean,nnnnnnnnn,5,False 29 | run_00027,64,32,128,6,True,mean,mean,nnnnnnnnn,5,False 30 | run_00028,64,32,256,6,True,mean,mean,nnnnnnnnn,5,False 31 | run_00029,64,32,512,6,True,mean,mean,nnnnnnnnn,5,False 32 | run_00030,64,64,128,6,True,mean,mean,nnnnnnnnn,5,False 33 | run_00031,64,64,256,6,True,mean,mean,nnnnnnnnn,5,False 34 | run_00032,64,64,512,6,True,mean,mean,nnnnnnnnn,5,False 35 | run_00033,64,128,128,6,True,mean,mean,nnnnnnnnn,5,False 36 | run_00034,64,128,256,6,True,mean,mean,nnnnnnnnn,5,False 37 | run_00035,64,128,512,6,True,mean,mean,nnnnnnnnn,5,False 38 | run_00036,128,16,128,6,True,mean,mean,nnnnnnnnn,5,False 39 | run_00037,128,16,256,6,True,mean,mean,nnnnnnnnn,5,False 40 | run_00038,128,16,512,6,True,mean,mean,nnnnnnnnn,5,False 41 | run_00039,128,32,128,6,True,mean,mean,nnnnnnnnn,5,False 42 | run_00040,128,32,256,6,True,mean,mean,nnnnnnnnn,5,False 43 | run_00041,128,32,512,6,True,mean,mean,nnnnnnnnn,5,False 44 | run_00042,128,64,128,6,True,mean,mean,nnnnnnnnn,5,False 45 | run_00043,128,64,256,6,True,mean,mean,nnnnnnnnn,5,False 46 | run_00044,128,64,512,6,True,mean,mean,nnnnnnnnn,5,False 47 | run_00045,128,128,128,6,True,mean,mean,nnnnnnnnn,5,False 48 | run_00046,128,128,256,6,True,mean,mean,nnnnnnnnn,5,False 49 | run_00047,128,128,512,6,True,mean,mean,nnnnnnnnn,5,False 50 | -------------------------------------------------------------------------------- /hparam_files/hparams_bcp_atom_ids_and_props.csv: -------------------------------------------------------------------------------- 1 | 
run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,True,mean,mean,yyyyyyyyy,5,False 3 | run_00001,16,16,256,6,True,mean,mean,yyyyyyyyy,5,False 4 | run_00002,16,16,512,6,True,mean,mean,yyyyyyyyy,5,False 5 | run_00003,16,32,128,6,True,mean,mean,yyyyyyyyy,5,False 6 | run_00004,16,32,256,6,True,mean,mean,yyyyyyyyy,5,False 7 | run_00005,16,32,512,6,True,mean,mean,yyyyyyyyy,5,False 8 | run_00006,16,64,128,6,True,mean,mean,yyyyyyyyy,5,False 9 | run_00007,16,64,256,6,True,mean,mean,yyyyyyyyy,5,False 10 | run_00008,16,64,512,6,True,mean,mean,yyyyyyyyy,5,False 11 | run_00009,16,128,128,6,True,mean,mean,yyyyyyyyy,5,False 12 | run_00010,16,128,256,6,True,mean,mean,yyyyyyyyy,5,False 13 | run_00011,16,128,512,6,True,mean,mean,yyyyyyyyy,5,False 14 | run_00012,32,16,128,6,True,mean,mean,yyyyyyyyy,5,False 15 | run_00013,32,16,256,6,True,mean,mean,yyyyyyyyy,5,False 16 | run_00014,32,16,512,6,True,mean,mean,yyyyyyyyy,5,False 17 | run_00015,32,32,128,6,True,mean,mean,yyyyyyyyy,5,False 18 | run_00016,32,32,256,6,True,mean,mean,yyyyyyyyy,5,False 19 | run_00017,32,32,512,6,True,mean,mean,yyyyyyyyy,5,False 20 | run_00018,32,64,128,6,True,mean,mean,yyyyyyyyy,5,False 21 | run_00019,32,64,256,6,True,mean,mean,yyyyyyyyy,5,False 22 | run_00020,32,64,512,6,True,mean,mean,yyyyyyyyy,5,False 23 | run_00021,32,128,128,6,True,mean,mean,yyyyyyyyy,5,False 24 | run_00022,32,128,256,6,True,mean,mean,yyyyyyyyy,5,False 25 | run_00023,32,128,512,6,True,mean,mean,yyyyyyyyy,5,False 26 | run_00024,64,16,128,6,True,mean,mean,yyyyyyyyy,5,False 27 | run_00025,64,16,256,6,True,mean,mean,yyyyyyyyy,5,False 28 | run_00026,64,16,512,6,True,mean,mean,yyyyyyyyy,5,False 29 | run_00027,64,32,128,6,True,mean,mean,yyyyyyyyy,5,False 30 | run_00028,64,32,256,6,True,mean,mean,yyyyyyyyy,5,False 31 | run_00029,64,32,512,6,True,mean,mean,yyyyyyyyy,5,False 32 | run_00030,64,64,128,6,True,mean,mean,yyyyyyyyy,5,False 33 | run_00031,64,64,256,6,True,mean,mean,yyyyyyyyy,5,False 34 | run_00032,64,64,512,6,True,mean,mean,yyyyyyyyy,5,False 35 | run_00033,64,128,128,6,True,mean,mean,yyyyyyyyy,5,False 36 | run_00034,64,128,256,6,True,mean,mean,yyyyyyyyy,5,False 37 | run_00035,64,128,512,6,True,mean,mean,yyyyyyyyy,5,False 38 | run_00036,128,16,128,6,True,mean,mean,yyyyyyyyy,5,False 39 | run_00037,128,16,256,6,True,mean,mean,yyyyyyyyy,5,False 40 | run_00038,128,16,512,6,True,mean,mean,yyyyyyyyy,5,False 41 | run_00039,128,32,128,6,True,mean,mean,yyyyyyyyy,5,False 42 | run_00040,128,32,256,6,True,mean,mean,yyyyyyyyy,5,False 43 | run_00041,128,32,512,6,True,mean,mean,yyyyyyyyy,5,False 44 | run_00042,128,64,128,6,True,mean,mean,yyyyyyyyy,5,False 45 | run_00043,128,64,256,6,True,mean,mean,yyyyyyyyy,5,False 46 | run_00044,128,64,512,6,True,mean,mean,yyyyyyyyy,5,False 47 | run_00045,128,128,128,6,True,mean,mean,yyyyyyyyy,5,False 48 | run_00046,128,128,256,6,True,mean,mean,yyyyyyyyy,5,False 49 | run_00047,128,128,512,6,True,mean,mean,yyyyyyyyy,5,False 50 | -------------------------------------------------------------------------------- /hparam_files/hparams_bcp_feature_ablation.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00146,128,128,256,6,False,mean,mean,yyyyyyyyy,5,False 3 | run_00246,128,128,256,6,False,mean,mean,ynnnnnnnn,5,False 4 | run_00346,128,128,256,6,False,mean,mean,nynnnnnnn,5,False 5 | 
run_00446,128,128,256,6,False,mean,mean,nnynnnnnn,5,False 6 | run_00546,128,128,256,6,False,mean,mean,nnnynnnnn,5,False 7 | run_00646,128,128,256,6,False,mean,mean,nnnnynnnn,5,False 8 | run_00746,128,128,256,6,False,mean,mean,nnnnnynnn,5,False 9 | run_00846,128,128,256,6,False,mean,mean,nnnnnnynn,5,False 10 | run_00946,128,128,256,6,False,mean,mean,nnnnnnnyn,5,False 11 | run_01046,128,128,256,6,False,mean,mean,nnnnnnnny,5,False 12 | run_01046,128,128,256,6,False,mean,mean,yynnnnnnn,5,False 13 | run_01046,128,128,256,6,False,mean,mean,yynnynnnn,5,False 14 | -------------------------------------------------------------------------------- /hparam_files/hparams_bcp_props.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,False,mean,mean,yyyyyyyyy,5,False 3 | run_00001,16,16,256,6,False,mean,mean,yyyyyyyyy,5,False 4 | run_00002,16,16,512,6,False,mean,mean,yyyyyyyyy,5,False 5 | run_00003,16,32,128,6,False,mean,mean,yyyyyyyyy,5,False 6 | run_00004,16,32,256,6,False,mean,mean,yyyyyyyyy,5,False 7 | run_00005,16,32,512,6,False,mean,mean,yyyyyyyyy,5,False 8 | run_00006,16,64,128,6,False,mean,mean,yyyyyyyyy,5,False 9 | run_00007,16,64,256,6,False,mean,mean,yyyyyyyyy,5,False 10 | run_00008,16,64,512,6,False,mean,mean,yyyyyyyyy,5,False 11 | run_00009,16,128,128,6,False,mean,mean,yyyyyyyyy,5,False 12 | run_00010,16,128,256,6,False,mean,mean,yyyyyyyyy,5,False 13 | run_00011,16,128,512,6,False,mean,mean,yyyyyyyyy,5,False 14 | run_00012,32,16,128,6,False,mean,mean,yyyyyyyyy,5,False 15 | run_00013,32,16,256,6,False,mean,mean,yyyyyyyyy,5,False 16 | run_00014,32,16,512,6,False,mean,mean,yyyyyyyyy,5,False 17 | run_00015,32,32,128,6,False,mean,mean,yyyyyyyyy,5,False 18 | run_00016,32,32,256,6,False,mean,mean,yyyyyyyyy,5,False 19 | run_00017,32,32,512,6,False,mean,mean,yyyyyyyyy,5,False 20 | run_00018,32,64,128,6,False,mean,mean,yyyyyyyyy,5,False 21 | run_00019,32,64,256,6,False,mean,mean,yyyyyyyyy,5,False 22 | run_00020,32,64,512,6,False,mean,mean,yyyyyyyyy,5,False 23 | run_00021,32,128,128,6,False,mean,mean,yyyyyyyyy,5,False 24 | run_00022,32,128,256,6,False,mean,mean,yyyyyyyyy,5,False 25 | run_00023,32,128,512,6,False,mean,mean,yyyyyyyyy,5,False 26 | run_00024,64,16,128,6,False,mean,mean,yyyyyyyyy,5,False 27 | run_00025,64,16,256,6,False,mean,mean,yyyyyyyyy,5,False 28 | run_00026,64,16,512,6,False,mean,mean,yyyyyyyyy,5,False 29 | run_00027,64,32,128,6,False,mean,mean,yyyyyyyyy,5,False 30 | run_00028,64,32,256,6,False,mean,mean,yyyyyyyyy,5,False 31 | run_00029,64,32,512,6,False,mean,mean,yyyyyyyyy,5,False 32 | run_00030,64,64,128,6,False,mean,mean,yyyyyyyyy,5,False 33 | run_00031,64,64,256,6,False,mean,mean,yyyyyyyyy,5,False 34 | run_00032,64,64,512,6,False,mean,mean,yyyyyyyyy,5,False 35 | run_00033,64,128,128,6,False,mean,mean,yyyyyyyyy,5,False 36 | run_00034,64,128,256,6,False,mean,mean,yyyyyyyyy,5,False 37 | run_00035,64,128,512,6,False,mean,mean,yyyyyyyyy,5,False 38 | run_00036,128,16,128,6,False,mean,mean,yyyyyyyyy,5,False 39 | run_00037,128,16,256,6,False,mean,mean,yyyyyyyyy,5,False 40 | run_00038,128,16,512,6,False,mean,mean,yyyyyyyyy,5,False 41 | run_00039,128,32,128,6,False,mean,mean,yyyyyyyyy,5,False 42 | run_00040,128,32,256,6,False,mean,mean,yyyyyyyyy,5,False 43 | run_00041,128,32,512,6,False,mean,mean,yyyyyyyyy,5,False 44 | run_00042,128,64,128,6,False,mean,mean,yyyyyyyyy,5,False 45 | 
run_00043,128,64,256,6,False,mean,mean,yyyyyyyyy,5,False 46 | run_00044,128,64,512,6,False,mean,mean,yyyyyyyyy,5,False 47 | run_00045,128,128,128,6,False,mean,mean,yyyyyyyyy,5,False 48 | run_00046,128,128,256,6,False,mean,mean,yyyyyyyyy,5,False 49 | run_00047,128,128,512,6,False,mean,mean,yyyyyyyyy,5,False 50 | -------------------------------------------------------------------------------- /hparam_files/hparams_bcp_props_mini.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,False,mean,mean,yyyyyyyyy,5,False 3 | run_00001,16,16,256,6,False,mean,mean,yyyyyyyyy,5,False 4 | -------------------------------------------------------------------------------- /hparam_files/hparams_ncp_atom_ids.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,True,mean,mean,nnnnnnnnn,5,True 3 | run_00001,16,16,256,6,True,mean,mean,nnnnnnnnn,5,True 4 | run_00002,16,16,512,6,True,mean,mean,nnnnnnnnn,5,True 5 | run_00003,16,32,128,6,True,mean,mean,nnnnnnnnn,5,True 6 | run_00004,16,32,256,6,True,mean,mean,nnnnnnnnn,5,True 7 | run_00005,16,32,512,6,True,mean,mean,nnnnnnnnn,5,True 8 | run_00006,16,64,128,6,True,mean,mean,nnnnnnnnn,5,True 9 | run_00007,16,64,256,6,True,mean,mean,nnnnnnnnn,5,True 10 | run_00008,16,64,512,6,True,mean,mean,nnnnnnnnn,5,True 11 | run_00009,16,128,128,6,True,mean,mean,nnnnnnnnn,5,True 12 | run_00010,16,128,256,6,True,mean,mean,nnnnnnnnn,5,True 13 | run_00011,16,128,512,6,True,mean,mean,nnnnnnnnn,5,True 14 | run_00012,32,16,128,6,True,mean,mean,nnnnnnnnn,5,True 15 | run_00013,32,16,256,6,True,mean,mean,nnnnnnnnn,5,True 16 | run_00014,32,16,512,6,True,mean,mean,nnnnnnnnn,5,True 17 | run_00015,32,32,128,6,True,mean,mean,nnnnnnnnn,5,True 18 | run_00016,32,32,256,6,True,mean,mean,nnnnnnnnn,5,True 19 | run_00017,32,32,512,6,True,mean,mean,nnnnnnnnn,5,True 20 | run_00018,32,64,128,6,True,mean,mean,nnnnnnnnn,5,True 21 | run_00019,32,64,256,6,True,mean,mean,nnnnnnnnn,5,True 22 | run_00020,32,64,512,6,True,mean,mean,nnnnnnnnn,5,True 23 | run_00021,32,128,128,6,True,mean,mean,nnnnnnnnn,5,True 24 | run_00022,32,128,256,6,True,mean,mean,nnnnnnnnn,5,True 25 | run_00023,32,128,512,6,True,mean,mean,nnnnnnnnn,5,True 26 | run_00024,64,16,128,6,True,mean,mean,nnnnnnnnn,5,True 27 | run_00025,64,16,256,6,True,mean,mean,nnnnnnnnn,5,True 28 | run_00026,64,16,512,6,True,mean,mean,nnnnnnnnn,5,True 29 | run_00027,64,32,128,6,True,mean,mean,nnnnnnnnn,5,True 30 | run_00028,64,32,256,6,True,mean,mean,nnnnnnnnn,5,True 31 | run_00029,64,32,512,6,True,mean,mean,nnnnnnnnn,5,True 32 | run_00030,64,64,128,6,True,mean,mean,nnnnnnnnn,5,True 33 | run_00031,64,64,256,6,True,mean,mean,nnnnnnnnn,5,True 34 | run_00032,64,64,512,6,True,mean,mean,nnnnnnnnn,5,True 35 | run_00033,64,128,128,6,True,mean,mean,nnnnnnnnn,5,True 36 | run_00034,64,128,256,6,True,mean,mean,nnnnnnnnn,5,True 37 | run_00035,64,128,512,6,True,mean,mean,nnnnnnnnn,5,True 38 | run_00036,128,16,128,6,True,mean,mean,nnnnnnnnn,5,True 39 | run_00037,128,16,256,6,True,mean,mean,nnnnnnnnn,5,True 40 | run_00038,128,16,512,6,True,mean,mean,nnnnnnnnn,5,True 41 | run_00039,128,32,128,6,True,mean,mean,nnnnnnnnn,5,True 42 | run_00040,128,32,256,6,True,mean,mean,nnnnnnnnn,5,True 43 | run_00041,128,32,512,6,True,mean,mean,nnnnnnnnn,5,True 44 | 
run_00042,128,64,128,6,True,mean,mean,nnnnnnnnn,5,True 45 | run_00043,128,64,256,6,True,mean,mean,nnnnnnnnn,5,True 46 | run_00044,128,64,512,6,True,mean,mean,nnnnnnnnn,5,True 47 | run_00045,128,128,128,6,True,mean,mean,nnnnnnnnn,5,True 48 | run_00046,128,128,256,6,True,mean,mean,nnnnnnnnn,5,True 49 | run_00047,128,128,512,6,True,mean,mean,nnnnnnnnn,5,True 50 | -------------------------------------------------------------------------------- /hparam_files/hparams_ncp_atom_ids_and_props.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,True,mean,mean,yyyyyyyyy,5,True 3 | run_00001,16,16,256,6,True,mean,mean,yyyyyyyyy,5,True 4 | run_00002,16,16,512,6,True,mean,mean,yyyyyyyyy,5,True 5 | run_00003,16,32,128,6,True,mean,mean,yyyyyyyyy,5,True 6 | run_00004,16,32,256,6,True,mean,mean,yyyyyyyyy,5,True 7 | run_00005,16,32,512,6,True,mean,mean,yyyyyyyyy,5,True 8 | run_00006,16,64,128,6,True,mean,mean,yyyyyyyyy,5,True 9 | run_00007,16,64,256,6,True,mean,mean,yyyyyyyyy,5,True 10 | run_00008,16,64,512,6,True,mean,mean,yyyyyyyyy,5,True 11 | run_00009,16,128,128,6,True,mean,mean,yyyyyyyyy,5,True 12 | run_00010,16,128,256,6,True,mean,mean,yyyyyyyyy,5,True 13 | run_00011,16,128,512,6,True,mean,mean,yyyyyyyyy,5,True 14 | run_00012,32,16,128,6,True,mean,mean,yyyyyyyyy,5,True 15 | run_00013,32,16,256,6,True,mean,mean,yyyyyyyyy,5,True 16 | run_00014,32,16,512,6,True,mean,mean,yyyyyyyyy,5,True 17 | run_00015,32,32,128,6,True,mean,mean,yyyyyyyyy,5,True 18 | run_00016,32,32,256,6,True,mean,mean,yyyyyyyyy,5,True 19 | run_00017,32,32,512,6,True,mean,mean,yyyyyyyyy,5,True 20 | run_00018,32,64,128,6,True,mean,mean,yyyyyyyyy,5,True 21 | run_00019,32,64,256,6,True,mean,mean,yyyyyyyyy,5,True 22 | run_00020,32,64,512,6,True,mean,mean,yyyyyyyyy,5,True 23 | run_00021,32,128,128,6,True,mean,mean,yyyyyyyyy,5,True 24 | run_00022,32,128,256,6,True,mean,mean,yyyyyyyyy,5,True 25 | run_00023,32,128,512,6,True,mean,mean,yyyyyyyyy,5,True 26 | run_00024,64,16,128,6,True,mean,mean,yyyyyyyyy,5,True 27 | run_00025,64,16,256,6,True,mean,mean,yyyyyyyyy,5,True 28 | run_00026,64,16,512,6,True,mean,mean,yyyyyyyyy,5,True 29 | run_00027,64,32,128,6,True,mean,mean,yyyyyyyyy,5,True 30 | run_00028,64,32,256,6,True,mean,mean,yyyyyyyyy,5,True 31 | run_00029,64,32,512,6,True,mean,mean,yyyyyyyyy,5,True 32 | run_00030,64,64,128,6,True,mean,mean,yyyyyyyyy,5,True 33 | run_00031,64,64,256,6,True,mean,mean,yyyyyyyyy,5,True 34 | run_00032,64,64,512,6,True,mean,mean,yyyyyyyyy,5,True 35 | run_00033,64,128,128,6,True,mean,mean,yyyyyyyyy,5,True 36 | run_00034,64,128,256,6,True,mean,mean,yyyyyyyyy,5,True 37 | run_00035,64,128,512,6,True,mean,mean,yyyyyyyyy,5,True 38 | run_00036,128,16,128,6,True,mean,mean,yyyyyyyyy,5,True 39 | run_00037,128,16,256,6,True,mean,mean,yyyyyyyyy,5,True 40 | run_00038,128,16,512,6,True,mean,mean,yyyyyyyyy,5,True 41 | run_00039,128,32,128,6,True,mean,mean,yyyyyyyyy,5,True 42 | run_00040,128,32,256,6,True,mean,mean,yyyyyyyyy,5,True 43 | run_00041,128,32,512,6,True,mean,mean,yyyyyyyyy,5,True 44 | run_00042,128,64,128,6,True,mean,mean,yyyyyyyyy,5,True 45 | run_00043,128,64,256,6,True,mean,mean,yyyyyyyyy,5,True 46 | run_00044,128,64,512,6,True,mean,mean,yyyyyyyyy,5,True 47 | run_00045,128,128,128,6,True,mean,mean,yyyyyyyyy,5,True 48 | run_00046,128,128,256,6,True,mean,mean,yyyyyyyyy,5,True 49 | run_00047,128,128,512,6,True,mean,mean,yyyyyyyyy,5,True 50 | 
-------------------------------------------------------------------------------- /hparam_files/hparams_ncp_props.csv: -------------------------------------------------------------------------------- 1 | run_id,batch_size,kernel_dim,mlp_dim,cutoff,baseline_atom_ids,aggr,pool,properties,n_kernels,ncp_graph 2 | run_00000,16,16,128,6,False,mean,mean,yyyyyyyyy,5,True 3 | run_00001,16,16,256,6,False,mean,mean,yyyyyyyyy,5,True 4 | run_00002,16,16,512,6,False,mean,mean,yyyyyyyyy,5,True 5 | run_00003,16,32,128,6,False,mean,mean,yyyyyyyyy,5,True 6 | run_00004,16,32,256,6,False,mean,mean,yyyyyyyyy,5,True 7 | run_00005,16,32,512,6,False,mean,mean,yyyyyyyyy,5,True 8 | run_00006,16,64,128,6,False,mean,mean,yyyyyyyyy,5,True 9 | run_00007,16,64,256,6,False,mean,mean,yyyyyyyyy,5,True 10 | run_00008,16,64,512,6,False,mean,mean,yyyyyyyyy,5,True 11 | run_00009,16,128,128,6,False,mean,mean,yyyyyyyyy,5,True 12 | run_00010,16,128,256,6,False,mean,mean,yyyyyyyyy,5,True 13 | run_00011,16,128,512,6,False,mean,mean,yyyyyyyyy,5,True 14 | run_00012,32,16,128,6,False,mean,mean,yyyyyyyyy,5,True 15 | run_00013,32,16,256,6,False,mean,mean,yyyyyyyyy,5,True 16 | run_00014,32,16,512,6,False,mean,mean,yyyyyyyyy,5,True 17 | run_00015,32,32,128,6,False,mean,mean,yyyyyyyyy,5,True 18 | run_00016,32,32,256,6,False,mean,mean,yyyyyyyyy,5,True 19 | run_00017,32,32,512,6,False,mean,mean,yyyyyyyyy,5,True 20 | run_00018,32,64,128,6,False,mean,mean,yyyyyyyyy,5,True 21 | run_00019,32,64,256,6,False,mean,mean,yyyyyyyyy,5,True 22 | run_00020,32,64,512,6,False,mean,mean,yyyyyyyyy,5,True 23 | run_00021,32,128,128,6,False,mean,mean,yyyyyyyyy,5,True 24 | run_00022,32,128,256,6,False,mean,mean,yyyyyyyyy,5,True 25 | run_00023,32,128,512,6,False,mean,mean,yyyyyyyyy,5,True 26 | run_00024,64,16,128,6,False,mean,mean,yyyyyyyyy,5,True 27 | run_00025,64,16,256,6,False,mean,mean,yyyyyyyyy,5,True 28 | run_00026,64,16,512,6,False,mean,mean,yyyyyyyyy,5,True 29 | run_00027,64,32,128,6,False,mean,mean,yyyyyyyyy,5,True 30 | run_00028,64,32,256,6,False,mean,mean,yyyyyyyyy,5,True 31 | run_00029,64,32,512,6,False,mean,mean,yyyyyyyyy,5,True 32 | run_00030,64,64,128,6,False,mean,mean,yyyyyyyyy,5,True 33 | run_00031,64,64,256,6,False,mean,mean,yyyyyyyyy,5,True 34 | run_00032,64,64,512,6,False,mean,mean,yyyyyyyyy,5,True 35 | run_00033,64,128,128,6,False,mean,mean,yyyyyyyyy,5,True 36 | run_00034,64,128,256,6,False,mean,mean,yyyyyyyyy,5,True 37 | run_00035,64,128,512,6,False,mean,mean,yyyyyyyyy,5,True 38 | run_00036,128,16,128,6,False,mean,mean,yyyyyyyyy,5,True 39 | run_00037,128,16,256,6,False,mean,mean,yyyyyyyyy,5,True 40 | run_00038,128,16,512,6,False,mean,mean,yyyyyyyyy,5,True 41 | run_00039,128,32,128,6,False,mean,mean,yyyyyyyyy,5,True 42 | run_00040,128,32,256,6,False,mean,mean,yyyyyyyyy,5,True 43 | run_00041,128,32,512,6,False,mean,mean,yyyyyyyyy,5,True 44 | run_00042,128,64,128,6,False,mean,mean,yyyyyyyyy,5,True 45 | run_00043,128,64,256,6,False,mean,mean,yyyyyyyyy,5,True 46 | run_00044,128,64,512,6,False,mean,mean,yyyyyyyyy,5,True 47 | run_00045,128,128,128,6,False,mean,mean,yyyyyyyyy,5,True 48 | run_00046,128,128,256,6,False,mean,mean,yyyyyyyyy,5,True 49 | run_00047,128,128,512,6,False,mean,mean,yyyyyyyyy,5,True 50 | -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cisert/bcpaff/bd32154d356ce01931ee0b6d8d8cb3c0c59849ec/img.png 
--------------------------------------------------------------------------------
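
A minimal usage sketch for the QTAIM classes above, assuming the module layout shown in the repository tree (`bcpaff/qtaim/qtaim_reader.py` and `bcpaff/qtaim/qtaim_viewer.py`) and hypothetical input paths; the directory and file names below are placeholders, not files shipped with the code:

```python
# Minimal sketch: parse a critic2 output and visualize intermolecular BCPs.
# All paths are hypothetical placeholders; module paths assumed from the repo layout.
import os

from bcpaff.qtaim.qtaim_reader import QtaimPropsCritic2  # module name assumed
from bcpaff.qtaim.qtaim_viewer import QtaimViewer

some_dir = "/path/to/critic2_results"  # hypothetical directory with output.cri + pl_complex.xyz
qtaim_props = QtaimPropsCritic2(
    basepath=some_dir,  # output.cri and pl_complex.xyz are picked up here by default
    ligand_sdf=os.path.join(some_dir, "ligand.sdf"),  # hypothetical ligand file
)

# Each critical point carries its position and the critic2-derived properties
# (density, laplacian, gradient_norm), already rescaled to Angstrom-based units.
for cp in qtaim_props.critical_points:
    if cp.name == "bond_critical_point" and cp.intermolecular:
        print(cp.idx, cp.position, cp.props["density"])

# In a Jupyter notebook, the intermolecular bond paths and BCPs can then be
# rendered on top of the protein-ligand complex with py3Dmol:
viewer = QtaimViewer(qtaim_props, only_intermolecular=True)
viewer.show()
```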
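
The readers convert critic2/Multiwfn values from atomic units to Angstrom-based units via the `BOHR2ANGSTROM`/`ANGSTROM2BOHR` constants in `bcpaff/utils.py`; a quick numerical check with example input values (the numbers are illustrative only):

```python
# Illustrative unit-conversion check using the constants from bcpaff/utils.py.
BOHR2ANGSTROM = 0.529177249
ANGSTROM2BOHR = 1 / BOHR2ANGSTROM

rho_au = 0.05                           # electron density in e/Bohr^3 (example value)
rho_ang = rho_au * ANGSTROM2BOHR ** 3   # -> ~0.337 e/Angstrom^3
lap_au = -0.02                          # Laplacian in e/Bohr^5 (example value)
lap_ang = lap_au * ANGSTROM2BOHR ** 5   # -> ~-0.482 e/Angstrom^5
print(rho_ang, lap_ang)
```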
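
The hyperparameter grids in `hparam_files/` are plain CSVs, and `bcpaff/utils.py` loads one of them (`hparams_bcp_props.csv`) as `HPARAMS` via pandas; a sketch of iterating over such a grid in the same way, assuming the package is installed and `ROOT_PATH` resolves as defined above:

```python
# Sketch: load and iterate one hyperparameter grid, mirroring HPARAMS in bcpaff/utils.py.
import os

import pandas as pd

from bcpaff.utils import ROOT_PATH  # defined in bcpaff/utils.py

hparams = pd.read_csv(os.path.join(ROOT_PATH, "hparam_files", "hparams_bcp_props.csv"))
for _, row in hparams.iterrows():
    # Each row is one configuration: batch size, kernel/MLP widths, cutoff,
    # whether baseline atom identities are used, the QTAIM property mask, etc.
    print(row.run_id, row.batch_size, row.kernel_dim, row.mlp_dim, row.properties)
```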