├── .DS_Store ├── GraphINVENT_Protac ├── graphinvent │ ├── Analyzer.py │ ├── BlockDatasetLoader.py │ ├── DataProcesser.py │ ├── GraphGenerator.py │ ├── GraphGeneratorRL.py │ ├── MolecularGraph.py │ ├── ProtacScoringFunction.py │ ├── ScoringFunction.py │ ├── Workflow.py │ ├── __init__.py │ ├── data │ │ ├── fine-tuning │ │ │ └── README.md │ │ ├── pre-training │ │ │ └── MOSES │ │ │ │ └── README.md │ │ └── protac │ │ │ └── README.md │ ├── gnn │ │ ├── aggregation_mpnn.py │ │ ├── edge_mpnn.py │ │ ├── modules.py │ │ ├── mpnn.py │ │ └── summation_mpnn.py │ ├── main.py │ ├── parameters │ │ ├── args.py │ │ ├── constants.py │ │ ├── defaults.py │ │ └── load.py │ ├── training_stats.py │ └── util.py └── tools │ ├── analyze_final_epoch.py │ ├── atom_types.py │ ├── combine_HDFs.py │ ├── combine_generation_batches.py │ ├── formal_charges.py │ ├── max_n_nodes.py │ ├── split_filter_protac.py │ ├── submit-split-preprocessing-supercloud.py │ ├── tdc-create-dataset.py │ ├── utils.py │ └── visualization.py ├── LICENSE.md ├── README.md ├── binary_label_metrics.py ├── features.pkl ├── molecule_metrics.ipynb ├── surrogate_model.ipynb └── surrogate_model.pkl /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/.DS_Store -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/BlockDatasetLoader.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `BlockDatasetLoader` defines custom `DataLoader`s and `Dataset`s used to 3 | efficiently load data from HDF files in this work 4 | """ 5 | # load general packages and functions 6 | from typing import Tuple 7 | import torch 8 | import h5py 9 | 10 | 11 | class BlockDataLoader(torch.utils.data.DataLoader): 12 | """ 13 | Main `DataLoader` class which has been modified so as to read training data 14 | from disk in blocks, as opposed to a single line at a time (as is done in 15 | the original `DataLoader` class). 
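    Each iteration, an outer `DataLoader` first pulls one block of
    `block_size` subgraphs from the underlying `HDFDataset`; an inner
    `DataLoader` (built over a `ShuffleBlockWrapper`) then shuffles and
    batches within that block, so only one block needs to sit in memory at a
    time.

    A minimal usage sketch (the HDF path is hypothetical; any file produced by
    `DataProcesser.preprocess()` containing "nodes", "edges", and "APDs"
    datasets should work):

        dataset    = HDFDataset(path="data/pre-training/MOSES/train.h5")
        dataloader = BlockDataLoader(dataset=dataset,
                                     batch_size=100,
                                     block_size=10000)
        for nodes, edges, apds in dataloader:
            ...  # one mini-batch of node features, edge features, and APDs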
16 | """ 17 | def __init__(self, dataset : torch.utils.data.Dataset, batch_size : int=100, 18 | block_size : int=10000, shuffle : bool=True, n_workers : int=0, 19 | pin_memory : bool=True) -> None: 20 | 21 | # define variables to be used throughout dataloading 22 | self.dataset = dataset # `HDFDataset` object 23 | self.batch_size = batch_size # `int` 24 | self.block_size = block_size # `int` 25 | self.shuffle = shuffle # `bool` 26 | self.n_workers = n_workers # `int` 27 | self.pin_memory = pin_memory # `bool` 28 | self.block_dataset = BlockDataset(self.dataset, 29 | batch_size=self.batch_size, 30 | block_size=self.block_size) 31 | 32 | def __iter__(self) -> torch.Tensor: 33 | 34 | # define a regular `DataLoader` using the `BlockDataset` 35 | block_loader = torch.utils.data.DataLoader(self.block_dataset, 36 | shuffle=self.shuffle, 37 | num_workers=self.n_workers) 38 | 39 | # define a condition for determining whether to drop the last block this 40 | # is done if the remainder block is very small (less than a tenth the 41 | # size of a normal block) 42 | condition = bool( 43 | int(self.block_dataset.__len__()/self.block_size) > 1 & 44 | self.block_dataset.__len__()%self.block_size < self.block_size/10 45 | ) 46 | 47 | # loop through and load BLOCKS of data every iteration 48 | for block in block_loader: 49 | block = [torch.squeeze(b) for b in block] 50 | 51 | # wrap each block in a `ShuffleBlock` so that data can be shuffled 52 | # within blocks 53 | batch_loader = torch.utils.data.DataLoader( 54 | dataset=ShuffleBlockWrapper(block), 55 | shuffle=self.shuffle, 56 | batch_size=self.batch_size, 57 | num_workers=self.n_workers, 58 | pin_memory=self.pin_memory, 59 | drop_last=condition 60 | ) 61 | 62 | for batch in batch_loader: 63 | yield batch 64 | 65 | def __len__(self) -> int: 66 | # returns the number of graphs in the DataLoader 67 | n_blocks = len(self.dataset) // self.block_size 68 | n_rem = len(self.dataset) % self.block_size 69 | n_batch_per_block = self.__ceil__(self.block_size, self.batch_size) 70 | n_last = self.__ceil__(n_rem, self.batch_size) 71 | return n_batch_per_block * n_blocks + n_last 72 | 73 | def __ceil__(self, i : int, j : int) -> int: 74 | return (i + j - 1) // j 75 | 76 | 77 | class BlockDataset(torch.utils.data.Dataset): 78 | """ 79 | Modified `Dataset` class which returns BLOCKS of data when `__getitem__()` 80 | is called. 81 | """ 82 | def __init__(self, dataset : torch.utils.data.Dataset, batch_size : int=100, 83 | block_size : int=10000) -> None: 84 | 85 | assert block_size >= batch_size, "Block size should be > batch size." 86 | 87 | self.block_size = block_size # `int` 88 | self.batch_size = batch_size # `int` 89 | self.dataset = dataset # `HDFDataset` 90 | 91 | def __getitem__(self, idx : int) -> torch.Tensor: 92 | # returns a block of data from the dataset 93 | start = idx * self.block_size 94 | end = min((idx + 1) * self.block_size, len(self.dataset)) 95 | return self.dataset[start:end] 96 | 97 | def __len__(self) -> int: 98 | # returns the number of blocks in the dataset 99 | return (len(self.dataset) + self.block_size - 1) // self.block_size 100 | 101 | 102 | class ShuffleBlockWrapper: 103 | """ 104 | Extra class used to wrap a block of data, enabling data to get shuffled 105 | *within* a block. 
106 | """ 107 | def __init__(self, data : torch.Tensor) -> None: 108 | self.data = data 109 | 110 | def __getitem__(self, idx : int) -> torch.Tensor: 111 | return [d[idx] for d in self.data] 112 | 113 | def __len__(self) -> int: 114 | return len(self.data[0]) 115 | 116 | 117 | class HDFDataset(torch.utils.data.Dataset): 118 | """ 119 | Reads and collects data from an HDF file with three datasets: "nodes", 120 | "edges", and "APDs". 121 | """ 122 | def __init__(self, path : str) -> None: 123 | 124 | self.path = path 125 | hdf_file = h5py.File(self.path, "r+", swmr=True) 126 | 127 | # load each HDF dataset 128 | self.nodes = hdf_file.get("nodes") 129 | self.edges = hdf_file.get("edges") 130 | self.apds = hdf_file.get("APDs") 131 | 132 | # get the number of elements in the dataset 133 | self.n_subgraphs = self.nodes.shape[0] 134 | 135 | def __getitem__(self, idx : int) -> \ 136 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 137 | 138 | # returns specific graph elements 139 | nodes_i = torch.from_numpy(self.nodes[idx]).type(torch.float32) 140 | edges_i = torch.from_numpy(self.edges[idx]).type(torch.float32) 141 | apd_i = torch.from_numpy(self.apds[idx]).type(torch.float32) 142 | 143 | return (nodes_i, edges_i, apd_i) 144 | 145 | def __len__(self) -> int: 146 | # returns the number of graphs in the dataset 147 | return self.n_subgraphs 148 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/DataProcesser.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `DataProcesser` class contains functions for pre-processing training data. 3 | """ 4 | # load general packages and functions 5 | import os 6 | import numpy as np 7 | import rdkit 8 | import h5py 9 | from tqdm import tqdm 10 | 11 | # load GraphINVENT-specific functions 12 | from Analyzer import Analyzer 13 | from parameters.constants import constants 14 | import parameters.load as load 15 | from MolecularGraph import PreprocessingGraph 16 | import util 17 | 18 | 19 | class DataProcesser: 20 | """ 21 | A class for preprocessing molecular sets and writing them to HDF files. 22 | """ 23 | def __init__(self, path : str, is_training_set : bool=False, molset = []) -> None: 24 | """ 25 | Args: 26 | ---- 27 | path (string) : Full path/filename to SMILES file containing 28 | molecules. 29 | is_training_set (bool) : Indicates if this is the training set, as we 30 | calculate a few additional things for the training 31 | set. 
32 | """ 33 | # define some variables for later use 34 | self.path = path 35 | print("path defined", flush = True) 36 | self.is_training_set = is_training_set 37 | print("training set defined", flush = True) 38 | self.dataset_names = ["nodes", "edges", "APDs"] 39 | print("dataset names defined", flush = True) 40 | self.get_dataset_dims() # creates `self.dims` 41 | print("get dims done", flush = True) 42 | 43 | # load the molecules 44 | if constants.compute_train_csv: 45 | self.molecule_set = molset 46 | else: 47 | self.molecule_set = load.molecules(self.path) 48 | print("molecules loaded", flush = True) 49 | print(f"len of loaded molecules {len(self.molecule_set)}", flush=True) 50 | #self.molecule_set = molset 51 | 52 | # placeholders 53 | self.molecule_subset = None 54 | self.dataset = None 55 | self.skip_collection = None 56 | self.resume_idx = None 57 | self.ts_properties = None 58 | self.restart_index_file = None 59 | self.hdf_file = None 60 | self.dataset_size = None 61 | 62 | print("placeholders set", flush = True) 63 | 64 | # get total number of molecules, and total number of subgraphs in their 65 | # decoding routes 66 | self.n_molecules = len(self.molecule_set) 67 | print("got len", flush = True) 68 | self.total_n_subgraphs = self.get_n_subgraphs() 69 | print("get subgraphs done", flush = True) 70 | print(f"-- {self.n_molecules} molecules in set.", flush=True) 71 | print(f"-- {self.total_n_subgraphs} total subgraphs in set.", 72 | flush=True) 73 | 74 | def preprocess(self) -> None: 75 | """ 76 | Prepares an HDF file to save three different datasets to it (`nodes`, 77 | `edges`, `APDs`), and slowly fills it in by looping over all the 78 | molecules in the data in groups (or "mini-batches"). 79 | """ 80 | with h5py.File(f"{self.path[:-3]}h5.chunked", "a") as self.hdf_file: 81 | 82 | self.restart_index_file = constants.dataset_dir + "index.restart" 83 | 84 | if constants.restart and os.path.exists(self.restart_index_file): 85 | self.restart_preprocessing_job() 86 | else: 87 | self.start_new_preprocessing_job() 88 | 89 | # keep track of the dataset size (to resize later) 90 | self.dataset_size = 0 91 | 92 | self.ts_properties = None 93 | 94 | # this is where we fill the datasets with actual data by looping 95 | # over subgraphs in blocks of size `constants.batch_size` 96 | for idx in tqdm(range(0, self.total_n_subgraphs, constants.batch_size)): 97 | 98 | if not self.skip_collection: 99 | 100 | # add `constants.batch_size` subgraphs from 101 | # `self.molecule_subset` to the dataset (and if training 102 | # set, calculate their properties and add these to 103 | # `self.ts_properties`) 104 | self.get_subgraphs(init_idx=idx) 105 | 106 | util.write_last_molecule_idx( 107 | last_molecule_idx=self.resume_idx, 108 | dataset_size=self.dataset_size, 109 | restart_file_path=constants.dataset_dir 110 | ) 111 | 112 | 113 | if self.resume_idx == self.n_molecules: 114 | # all molecules have been processed 115 | 116 | self.resize_datasets() # remove padding from initialization 117 | print("Datasets resized.", flush=True) 118 | 119 | print(f"mol set length is {len(self.molecule_set)}",flush=True) 120 | 121 | if self.is_training_set and not constants.restart: 122 | print("Writing training set properties.", flush=True) 123 | util.write_ts_properties( 124 | training_set_properties=self.ts_properties 125 | ) 126 | 127 | break 128 | 129 | print("* Resaving datasets in unchunked format.") 130 | self.resave_datasets_unchunked() 131 | 132 | def restart_preprocessing_job(self) -> None: 133 | """ 134 | 
Restarts a preprocessing job. Uses an index specified in the dataset 135 | directory to know where to resume preprocessing. 136 | """ 137 | try: 138 | self.resume_idx, self.dataset_size = util.read_last_molecule_idx( 139 | restart_file_path=constants.dataset_dir 140 | ) 141 | except: 142 | self.resume_idx, self.dataset_size = 0, 0 143 | self.skip_collection = bool( 144 | self.resume_idx == self.n_molecules and self.is_training_set 145 | ) 146 | 147 | # load dictionary of previously created datasets (`self.dataset`) 148 | self.load_datasets(hdf_file=self.hdf_file) 149 | 150 | def start_new_preprocessing_job(self) -> None: 151 | """ 152 | Starts a fresh preprocessing job. 153 | """ 154 | self.resume_idx = 0 155 | self.skip_collection = False 156 | 157 | # create a dictionary of empty HDF datasets (`self.dataset`) 158 | self.create_datasets(hdf_file=self.hdf_file) 159 | 160 | def resave_datasets_unchunked(self) -> None: 161 | """ 162 | Resaves the HDF datasets in an unchunked format to remove initial 163 | padding. 164 | """ 165 | with h5py.File(f"{self.path[:-3]}h5.chunked", "r", swmr=True) as chunked_file: 166 | keys = list(chunked_file.keys()) 167 | data = [chunked_file.get(key)[:] for key in keys] 168 | data_zipped = tuple(zip(data, keys)) 169 | 170 | with h5py.File(f"{self.path[:-3]}h5", "w") as unchunked_file: 171 | for d, k in tqdm(data_zipped): 172 | unchunked_file.create_dataset( 173 | k, chunks=None, data=d, dtype=np.dtype("int8") 174 | ) 175 | 176 | # remove the restart file and chunked file (don't need them anymore) 177 | os.remove(self.restart_index_file) 178 | os.remove(f"{self.path[:-3]}h5.chunked") 179 | 180 | def get_subgraphs(self, init_idx = 0) -> None: 181 | """ 182 | Adds `constants.batch_size` subgraphs from `self.molecule_subset` to the 183 | HDF dataset (and if currently processing the training set, also 184 | calculates the full graphs' properties and adds these to 185 | `self.ts_properties`). 186 | 187 | Args: 188 | ---- 189 | init_idx (int) : As analysis is done in blocks/slices, `init_idx` is 190 | the start index for the next block/slice to be taken 191 | from `self.molecule_subset`. 
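        Subgraphs that already appear in the current group are not stored
        twice; instead, the new APD is added to the existing subgraph's APD,
        so equivalent decoding states share a single row in the HDF datasets.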
192 | """ 193 | data_subgraphs, data_apds, molecular_graph_list = [], [], [] # initialize 194 | 195 | # convert all molecules in `self.molecules_subset` to `PreprocessingGraphs` 196 | molecular_graph_generator = map(self.get_graph, self.molecule_set) 197 | 198 | molecules_processed = 0 # keep track of the number of molecules processed 199 | 200 | # # loop over all the `PreprocessingGraph`s 201 | for graph in molecular_graph_generator: 202 | molecules_processed += 1 203 | 204 | # store `PreprocessingGraph` object 205 | molecular_graph_list.append(graph) 206 | 207 | # get the number of decoding graphs 208 | n_subgraphs = graph.get_decoding_route_length() 209 | 210 | for new_subgraph_idx in range(n_subgraphs): 211 | 212 | # `get_decoding_route_state() returns a list of [`subgraph`, `apd`], 213 | subgraph, apd = graph.get_decoding_route_state( 214 | subgraph_idx=new_subgraph_idx 215 | ) 216 | 217 | # "collect" all APDs corresponding to pre-existing subgraphs, 218 | # otherwise append both new subgraph and new APD 219 | count = 0 220 | for idx, existing_subgraph in enumerate(data_subgraphs): 221 | 222 | count += 1 223 | # check if subgraph `subgraph` is "already" in 224 | # `data_subgraphs` as `existing_subgraph`, and if so, add 225 | # the "new" APD to the "old" 226 | try: # first compare the node feature matrices 227 | nodes_equal = (subgraph[0] == existing_subgraph[0]).all() 228 | except AttributeError: 229 | nodes_equal = False 230 | try: # then compare the edge feature tensors 231 | edges_equal = (subgraph[1] == existing_subgraph[1]).all() 232 | except AttributeError: 233 | edges_equal = False 234 | 235 | # if both matrices have a match, then subgraphs are the same 236 | if nodes_equal and edges_equal: 237 | existing_apd = data_apds[idx] 238 | existing_apd += apd 239 | break 240 | 241 | # if subgraph is not already in `data_subgraphs`, append it 242 | if count == len(data_subgraphs) or count == 0: 243 | data_subgraphs.append(subgraph) 244 | data_apds.append(apd) 245 | 246 | # if `constants.batch_size` unique subgraphs have been 247 | # processed, save group to the HDF dataset 248 | len_data_subgraphs = len(data_subgraphs) 249 | if len_data_subgraphs == constants.batch_size: 250 | self.save_group(data_subgraphs=data_subgraphs, 251 | data_apds=data_apds, 252 | group_size=len_data_subgraphs, 253 | init_idx=init_idx) 254 | 255 | # get molecular properties for group iff it's the training set 256 | self.get_ts_properties(molecular_graphs=molecular_graph_list, 257 | group_size=constants.batch_size) 258 | 259 | # keep track of the last molecule to be processed in 260 | # `self.resume_idx` 261 | # number of molecules processed: 262 | self.resume_idx += molecules_processed 263 | # subgraphs processed: 264 | self.dataset_size += constants.batch_size 265 | 266 | return None 267 | 268 | n_processed_subgraphs = len(data_subgraphs) 269 | 270 | # save group with < `constants.batch_size` subgraphs (e.g. 
last block) 271 | self.save_group(data_subgraphs=data_subgraphs, 272 | data_apds=data_apds, 273 | group_size=n_processed_subgraphs, 274 | init_idx=init_idx) 275 | 276 | # get molecular properties for this group iff it's the training set 277 | self.get_ts_properties(molecular_graphs=molecular_graph_list, 278 | group_size=constants.batch_size) 279 | 280 | # # keep track of the last molecule to be processed in `self.resume_idx` 281 | # self.resume_idx += molecules_processed # number of molecules processed 282 | # self.dataset_size += molecules_processed # subgraphs processed 283 | 284 | return None 285 | 286 | def create_datasets(self, hdf_file : h5py._hl.files.File) -> None: 287 | """ 288 | Creates a dictionary of HDF5 datasets (`self.dataset`). 289 | 290 | Args: 291 | ---- 292 | hdf_file (h5py._hl.files.File) : HDF5 file which will contain the datasets. 293 | """ 294 | self.dataset = {} # initialize 295 | 296 | for ds_name in self.dataset_names: 297 | self.dataset[ds_name] = hdf_file.create_dataset( 298 | ds_name, 299 | (self.total_n_subgraphs, *self.dims[ds_name]), 300 | chunks=True, # must be True for resizing later 301 | dtype=np.dtype("int8") 302 | ) 303 | 304 | def resize_datasets(self) -> None: 305 | """ 306 | Resizes the HDF datasets, since much longer datasets are initialized 307 | when first creating the HDF datasets (it it is impossible to predict 308 | how many graphs will be equivalent beforehand). 309 | """ 310 | for dataset_name in self.dataset_names: 311 | try: 312 | self.dataset[dataset_name].resize( 313 | (self.dataset_size, *self.dims[dataset_name])) 314 | except KeyError: # `f_term` has no extra dims 315 | self.dataset[dataset_name].resize((self.dataset_size,)) 316 | 317 | def get_dataset_dims(self) -> None: 318 | """ 319 | Calculates the dimensions of the node features, edge features, and APDs, 320 | and stores them as lists in a dict (`self.dims`), where keys are the 321 | dataset name. 322 | 323 | Shapes: 324 | ------ 325 | dims["nodes"] : [max N nodes, N atom types + N formal charges] 326 | dims["edges"] : [max N nodes, max N nodes, N bond types] 327 | dims["APDs"] : [APD length = f_add length + f_conn length + f_term length] 328 | """ 329 | self.dims = {} 330 | self.dims["nodes"] = constants.dim_nodes 331 | self.dims["edges"] = constants.dim_edges 332 | self.dims["APDs"] = constants.dim_apd 333 | 334 | def get_graph(self, mol : rdkit.Chem.Mol) -> PreprocessingGraph: 335 | """ 336 | Converts an `rdkit.Chem.Mol` object to `PreprocessingGraph`. 337 | 338 | Args: 339 | ---- 340 | mol (rdkit.Chem.Mol) : Molecule to convert. 341 | 342 | Returns: 343 | ------- 344 | molecular_graph (PreprocessingGraph) : Molecule, now as a graph. 345 | """ 346 | if mol is not None: 347 | print(rdkit.Chem.MolToSmiles(mol),flush=True) 348 | if not constants.use_aromatic_bonds: 349 | rdkit.Chem.Kekulize(mol, clearAromaticFlags=True) 350 | molecular_graph = PreprocessingGraph(molecule=mol, 351 | constants=constants) 352 | return molecular_graph 353 | 354 | def get_molecule_subset(self) -> None: 355 | """ 356 | Slices `self.molecule_set` into a subset of molecules of size 357 | `constants.batch_size`, starting from `self.resume_idx`. 358 | `self.n_molecules` is the number of molecules in the full 359 | `self.molecule_set`. 
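        Molecules that failed to parse (i.e. `None` entries) are skipped and
        do not count towards the subset size.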
360 | """ 361 | init_idx = self.resume_idx 362 | subset_size = constants.batch_size 363 | self.molecule_subset = [] 364 | max_idx = min(init_idx + subset_size, self.n_molecules) 365 | 366 | count = -1 367 | for mol in self.molecule_set: 368 | if mol is not None: 369 | count += 1 370 | if count < init_idx: 371 | continue 372 | elif count >= max_idx: 373 | return self.molecule_subset 374 | else: 375 | self.molecule_subset.append(mol) 376 | 377 | def get_n_subgraphs(self) -> int: 378 | """ 379 | Calculates the total number of subgraphs in the decoding route of all 380 | molecules in `self.molecule_set`. Loads training, testing, or validation 381 | set. First, the `PreprocessingGraph` for each molecule is obtained, and 382 | then the length of the decoding route is trivially calculated for each. 383 | 384 | Returns: 385 | ------- 386 | n_subgraphs (int) : Sum of number of subgraphs in decoding routes for 387 | all molecules in `self.molecule_set`. 388 | """ 389 | n_subgraphs = 0 # start the count 390 | print("initialized n_subgraphs",flush=True) 391 | 392 | # convert molecules in `self.molecule_set` to `PreprocessingGraph`s 393 | molecular_graph_generator = map(self.get_graph, self.molecule_set) 394 | print("converted mol set to graph",flush=True) 395 | 396 | # loop over all the `PreprocessingGraph`s 397 | for molecular_graph in molecular_graph_generator: 398 | # get the number of decoding graphs (i.e. the decoding route length) 399 | # and add them to the running count 400 | n_subgraphs += molecular_graph.get_decoding_route_length() 401 | print("got route length",flush=True) 402 | 403 | print("done w graph loop",flush=True) 404 | 405 | return int(n_subgraphs) 406 | 407 | def get_ts_properties(self, molecular_graphs : list, group_size : int) -> \ 408 | None: 409 | """ 410 | Gets molecular properties for group of molecular graphs, only for the 411 | training set. 412 | 413 | Args: 414 | ---- 415 | molecular_graphs (list) : Contains `PreprocessingGraph`s. 416 | group_size (int) : Size of "group" (i.e. slice of graphs). 417 | """ 418 | if self.is_training_set: 419 | print("is training set so calculating training set props", flush=True) 420 | 421 | analyzer = Analyzer() 422 | ts_properties = analyzer.evaluate_training_set( 423 | preprocessing_graphs=molecular_graphs 424 | ) 425 | print("initialized ts analyzer",flush=True) 426 | 427 | # merge properties of current group with the previous group analyzed 428 | if self.ts_properties: # `self.ts_properties` is a dictionary 429 | print("in self.ts_properties conditional", flush=True) 430 | if not constants.compute_train_csv: 431 | self.ts_properties = analyzer.combine_ts_properties( 432 | prev_properties=self.ts_properties, 433 | next_properties=ts_properties, 434 | weight_next=group_size 435 | ) 436 | else: 437 | print("setting self.ts_properties", flush=True) 438 | self.ts_properties = ts_properties 439 | else: # `self.ts_properties` is None (has not been calculated yet) 440 | self.ts_properties = ts_properties 441 | else: 442 | self.ts_properties = None 443 | print(self.ts_properties) 444 | 445 | 446 | def load_datasets(self, hdf_file : h5py._hl.files.File) -> None: 447 | """ 448 | Creates a dictionary of HDF datasets (`self.dataset`) which have been 449 | previously created (for restart jobs only). 450 | 451 | Args: 452 | ---- 453 | hdf_file (h5py._hl.files.File) : HDF file containing all the datasets. 
454 | """ 455 | self.dataset = {} # initialize dictionary of datasets 456 | 457 | # use the names of the datasets as the keys in `self.dataset` 458 | for ds_name in self.dataset_names: 459 | self.dataset[ds_name] = hdf_file.get(ds_name) 460 | 461 | def save_group(self, data_subgraphs : list, data_apds : list, 462 | group_size : int, init_idx : int) -> None: 463 | """ 464 | Saves a group of padded subgraphs and their corresponding APDs to the HDF 465 | datasets as `numpy.ndarray`s. 466 | 467 | Args: 468 | ---- 469 | data_subgraphs (list) : Contains molecular subgraphs. 470 | data_apds (list) : Contains APDs. 471 | group_size (int) : Size of HDF "slice". 472 | init_idx (int) : Index to begin slicing. 473 | """ 474 | # convert to `np.ndarray`s 475 | nodes = np.array([graph_tuple[0] for graph_tuple in data_subgraphs]) 476 | edges = np.array([graph_tuple[1] for graph_tuple in data_subgraphs]) 477 | apds = np.array(data_apds) 478 | 479 | end_idx = init_idx + group_size # idx to end slicing 480 | 481 | # once data is padded, save it to dataset slice 482 | # Broadcasting errors happen in the final group, 483 | # this skips them so processing can proceed 484 | try: 485 | self.dataset["nodes"][init_idx:end_idx] = nodes 486 | self.dataset["edges"][init_idx:end_idx] = edges 487 | self.dataset["APDs"][init_idx:end_idx] = apds 488 | except: 489 | print("\nBroadcasting error, skipping.\n") -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/ProtacScoringFunction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains function to score a given SMILES string as a high quality, average quality, or low quality 3 | PROTAC based on protac scoring boosted tree model 4 | """ 5 | import rdkit 6 | from collections import namedtuple 7 | from rdkit import Chem 8 | from rdkit.Chem import AllChem, DataStructs 9 | from rdkit.Chem import rdchem 10 | import numpy as np 11 | import pickle 12 | from parameters import constants 13 | import requests as r 14 | 15 | def predictProteinDegradation(mol : rdkit.Chem.Mol, constants : namedtuple, cellType='SRD15', receptor='Q9Y616', e3Ligase='CRBN'): 16 | ''' 17 | Returns a 0 or 1 based on a protac's molecule protein degradation potential 18 | 1 -> degrades 19 | 0 -> does not degrade 20 | ''' 21 | try: 22 | scoring_model = constants.qsar_models["protac_qsar_model"] 23 | features = constants.activity_model_features 24 | Chem.SanitizeMol(mol) 25 | 26 | ngrams_array = np.zeros((1,7841), dtype=np.int8) 27 | # baseUrl="http://www.uniprot.org/uniprot/" 28 | # currentUrl=baseUrl+receptor+".fasta" 29 | # response = r.post(currentUrl) 30 | # cData=''.join(response.text) 31 | # i = cData.index('\n')+1 32 | # seq = cData[i:].strip().lower() 33 | seq = "MAGNCGARGALSAHTLLFDLPPALLGELCAVLDSCDGALGWRGLAERLSSSWLDVRHIEKYVDQGKSGTRELLWSWAQKNKTIGDLLQVLQEMGHRRAIHLITNYGAVLSPSEKSYQEGGFPNILFKETANVTVDNVLIPEHNEKGILLKSSISFQNIIEGTRNFHKDFLIGEGEIFEVYRVEIQNLTYAVKLFKQEKKMQCKKHWKRFLSELEVLLLFHHPNILELAAYFTETEKFCLIYPYMRNGTLFDRLQCVGDTAPLPWHIRIGILIGISKAIHYLHNVQPCSVICGSISSANILLDDQFQPKLTDFAMAHFRSHLEHQSCTINMTSSSSKHLWYMPEEYIRQGKLSIKTDVYSFGIVIMEVLTGCRVVLDDPKHIQLRDLLRELMEKRGLDSCLSFLDKKVPPCPRNFSAKLFCLAGRCAATRAKLRPSMDEVLNTLESTQASLYFAEDPPTSLKSFRCPSPLFLENVPSIPVEDDESQNNNLLPSDEGLRIDRMTQKTPFECSQSEVMFLSLDKKPESKRNEEACNMPSSSCEESWFPKYIVPSQDLRPYKVNIDPSSEAPGHSCRSRPVESSCSSKFSWDEYEQYKKE".lower() 34 | ngrams = features[1237:] 35 | for i in range(len(ngrams)): 36 | n = seq.count(ngrams[i]) 37 | ngrams_array[0][i] = n 38 | 39 | 
fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) 40 | fp_array = np.zeros((0,), dtype=np.int8) 41 | DataStructs.ConvertToNumpyArray(fingerprint, fp_array) 42 | 43 | ct_ind = features.index("ct_"+cellType) 44 | e3_ind = features.index("e3_"+e3Ligase) 45 | 46 | input = list(0 for i in range(5)) + list(fp_array) + list(0 for i in range(207))+ list(ngrams_array[0]) 47 | input[ct_ind] = 1 48 | input[e3_ind] = 1 49 | 50 | output = scoring_model.predict([input]) 51 | smi = Chem.MolToSmiles(mol) 52 | letters = smi.replace("[=()-]","") 53 | if output[0][0]-output[0][1] < 0 and len(letters)>=30: 54 | return 1 55 | else: 56 | return 0 57 | except Exception as e: 58 | print('EXCEPTION IN PREDICT PROTEIN DEGRADATION',flush=True) 59 | print(e, flush=True) 60 | return 0 -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/ScoringFunction.py: -------------------------------------------------------------------------------- 1 | """ 2 | This class is used for defining the scoring function(s) which can be used during 3 | fine-tuning. 4 | """ 5 | # load general packages and functions 6 | from collections import namedtuple 7 | import torch 8 | from rdkit import DataStructs 9 | from rdkit import Chem 10 | from rdkit.Chem import QED, AllChem 11 | import numpy as np 12 | import sklearn 13 | from sklearn import svm 14 | from ProtacScoringFunction import * 15 | 16 | class ScoringFunction: 17 | """ 18 | A class for defining the scoring function components. 19 | """ 20 | def __init__(self, constants : namedtuple) -> None: 21 | """ 22 | Args: 23 | ---- 24 | constants (namedtuple) : Contains job parameters as well as global 25 | constants. 26 | """ 27 | self.score_components = constants.score_components # list 28 | self.score_type = constants.score_type # list 29 | self.qsar_models = constants.qsar_models # dict 30 | self.device = constants.device 31 | self.max_n_nodes = constants.max_n_nodes 32 | self.score_thresholds = constants.score_thresholds 33 | 34 | self.n_graphs = None # placeholder 35 | self.constants = constants 36 | 37 | assert len(self.score_components) == len(self.score_thresholds), \ 38 | "`score_components` and `score_thresholds` do not match." 39 | 40 | def compute_score(self, graphs : list, termination : torch.Tensor, 41 | validity : torch.Tensor, uniqueness : torch.Tensor) -> \ 42 | torch.Tensor: 43 | """ 44 | Computes the overall score for the input molecular graphs. 45 | 46 | Args: 47 | ---- 48 | graphs (list) : Contains molecular graphs to evaluate. 49 | termination (torch.Tensor) : Termination status of input molecular 50 | graphs. 51 | validity (torch.Tensor) : Validity of input molecular graphs. 52 | uniqueness (torch.Tensor) : Uniqueness of input molecular graphs. 53 | 54 | Returns: 55 | ------- 56 | final_score (torch.Tensor) : The final scores for each input graph. 
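        With a single score component, that component is used directly as the
        score. Otherwise, for `score_type == "continuous"` the component
        scores are multiplied elementwise, while for `score_type == "binary"`
        each component is first thresholded against its entry in
        `score_thresholds` and the resulting 0/1 masks are multiplied. The
        result is then zeroed out for duplicate, invalid, and improperly
        terminated graphs. For example, with the (hypothetical) settings
        `score_components = ["protac_activity", "QED"]` and
        `score_type = "binary"`, a graph scores 1 only if both components
        exceed their thresholds and the graph is unique, valid, and properly
        terminated.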
57 | """ 58 | self.n_graphs = len(graphs) 59 | contributions_to_score = self.get_contributions_to_score(graphs=graphs) 60 | 61 | print(f"contributions_to_score {contributions_to_score}") 62 | 63 | if len(self.score_components) == 1: 64 | print('len is 1') 65 | final_score = contributions_to_score[0] 66 | 67 | elif self.score_type == "continuous": 68 | print('in continuous') 69 | final_score = contributions_to_score[0] 70 | print(f"structure score tensor {contributions_to_score[1]}") 71 | for component in contributions_to_score[1:]: 72 | final_score *= component 73 | 74 | elif self.score_type == "binary": 75 | print('in binary') 76 | component_masks = [] 77 | for idx, score_component in enumerate(contributions_to_score): 78 | component_mask = torch.where( 79 | score_component > self.score_thresholds[idx], 80 | torch.ones(self.n_graphs, device=self.device, dtype=torch.uint8), 81 | torch.zeros(self.n_graphs, device=self.device, dtype=torch.uint8) 82 | ) 83 | component_masks.append(component_mask) 84 | 85 | final_score = component_masks[0] 86 | for mask in component_masks[1:]: 87 | final_score *= mask 88 | final_score = final_score.float() 89 | 90 | else: 91 | raise NotImplementedError 92 | 93 | print(f"final score before {final_score}") 94 | 95 | # remove contribution of duplicate molecules to the score 96 | final_score *= uniqueness 97 | print(f"uniqueness score {uniqueness}") 98 | 99 | # remove contribution of invalid molecules to the score 100 | final_score *= validity 101 | print(f"validity score {validity}") 102 | 103 | # remove contribution of improperly-terminated molecules to the score 104 | final_score *= termination 105 | print(f"termination score {termination}") 106 | 107 | print(f"final score after {final_score}") 108 | 109 | return final_score 110 | 111 | def get_contributions_to_score(self, graphs : list) -> list: 112 | """ 113 | Returns the different elements of the score. 114 | 115 | Args: 116 | ---- 117 | graphs (list) : Contains molecular graphs to evaluate. 118 | 119 | Returns: 120 | ------- 121 | contributions_to_score (list) : Contains elements of the score due to 122 | each scoring function component. 123 | """ 124 | contributions_to_score = [] 125 | 126 | for score_component in self.score_components: 127 | if "target_size" in score_component: 128 | 129 | target_size = int(score_component[12:]) 130 | 131 | assert target_size <= self.max_n_nodes, \ 132 | "Target size > largest possible size (`max_n_nodes`)." 133 | assert 0 < target_size, "Target size must be greater than 0." 
134 | 135 | target_size *= torch.ones(self.n_graphs, device=self.device) 136 | n_nodes = torch.tensor([graph.n_nodes for graph in graphs], 137 | device=self.device) 138 | max_nodes = self.max_n_nodes 139 | score = ( 140 | torch.ones(self.n_graphs, device=self.device) 141 | - torch.abs(n_nodes - target_size) 142 | / (max_nodes - target_size) 143 | ) 144 | 145 | contributions_to_score.append(score) 146 | 147 | elif score_component == "QED": 148 | mols = [graph.molecule for graph in graphs] 149 | 150 | # compute the QED score for each molecule (if possible) 151 | qed = [] 152 | for mol in mols: 153 | try: 154 | qed.append(QED.qed(mol)) 155 | except: 156 | qed.append(0.0) 157 | score = torch.tensor(qed, device=self.device) 158 | 159 | contributions_to_score.append(score) 160 | 161 | 162 | elif "protac_activity" in score_component: 163 | mols = [graph.molecule for graph in graphs] 164 | 165 | # `score_component` has to be the key to the QSAR model in the 166 | # `self.qsar_models` dict 167 | #qsar_model = self.qsar_models[score_component] 168 | #score = self.compute_activity(mols, qsar_model) 169 | score = self.compute_protac_activity(mols) 170 | 171 | print(f"activity scores {score}") 172 | 173 | contributions_to_score.append(score) 174 | 175 | elif "structure" in score_component: 176 | mols = [graph.molecule for graph in graphs] 177 | score = self.compute_structure(mols) 178 | print(f"structure score {score}") 179 | contributions_to_score.append(score) 180 | 181 | # elif "solubility" in score_component: 182 | # mols = [graph.molecule for graph in graphs] 183 | # score = self.compute_solubility(mols) 184 | # print(f"structure score {score}") 185 | # contributions_to_score.append(score) 186 | 187 | elif "activity" in score_component: 188 | mols = [graph.molecule for graph in graphs] 189 | 190 | # `score_component` has to be the key to the QSAR model in the 191 | # `self.qsar_models` dict 192 | qsar_model = self.qsar_models[score_component] 193 | score = self.compute_activity(mols, qsar_model) 194 | 195 | print(f"activity scores {score}") 196 | 197 | contributions_to_score.append(score) 198 | 199 | else: 200 | raise NotImplementedError("The score component is not defined. " 201 | "You can define it in " 202 | "`ScoringFunction.py`.") 203 | 204 | return contributions_to_score 205 | 206 | def compute_protac_activity(self, mols : list) -> list: 207 | """ 208 | Args: 209 | ---- 210 | mols (list) : Contains `rdkit.Mol` objects corresponding to molecular 211 | graphs sampled. 212 | 213 | Returns: 214 | ------- 215 | activity (list) : Contains predicted protac activity for input molecules. 216 | """ 217 | 218 | n_mols = len(mols) 219 | activity = torch.zeros(n_mols, device=self.device) 220 | 221 | for idx, mol in enumerate(mols): 222 | score = predictProteinDegradation(mol, self.constants) 223 | activity[idx] = score 224 | 225 | print(f"shape of activity tensor is {activity.shape}") 226 | print(f"sum of activity tensor is {sum(activity.tolist())}") 227 | 228 | return activity 229 | 230 | def compute_structure(self, mols : list) -> list: 231 | """ 232 | Args: 233 | ---- 234 | mols (list) : Contains `rdkit.Mol` objects corresponding to molecular 235 | graphs sampled. 236 | 237 | Returns: 238 | ------- 239 | activity (list) : Contains predicted structures for input molecules (0 for non-protac, 1 for protac). 
240 | """ 241 | 242 | n_mols = len(mols) 243 | activity = torch.zeros(n_mols, device=self.device) 244 | 245 | for idx, mol in enumerate(mols): 246 | score = predictStructure(mol, self.constants) 247 | activity[idx] = score 248 | 249 | return activity 250 | 251 | 252 | 253 | def compute_activity(self, mols : list, 254 | activity_model : sklearn.svm.SVC) -> list: 255 | """ 256 | Note: this function may have to be tuned/replicated depending on how 257 | the activity model is saved. 258 | 259 | Args: 260 | ---- 261 | mols (list) : Contains `rdkit.Mol` objects corresponding to molecular 262 | graphs sampled. 263 | activity_model (sklearn.svm.classes.SVC) : Pre-trained QSAR model. 264 | 265 | Returns: 266 | ------- 267 | activity (list) : Contains predicted activities for input molecules. 268 | """ 269 | n_mols = len(mols) 270 | activity = torch.zeros(n_mols, device=self.device) 271 | 272 | for idx, mol in enumerate(mols): 273 | try: 274 | fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 275 | 2, 276 | nBits=2048) 277 | ecfp4 = np.zeros((2048,)) 278 | DataStructs.ConvertToNumpyArray(fingerprint, ecfp4) 279 | activity[idx] = activity_model.predict_proba([ecfp4])[0][1] 280 | except: 281 | pass # activity[idx] will remain 0.0 282 | 283 | return activity 284 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/GraphINVENT_Protac/graphinvent/__init__.py -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/data/fine-tuning/README.md: -------------------------------------------------------------------------------- 1 | 2 | Place the features.pkl and surrogate_model.pkl files here. 3 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/data/pre-training/MOSES/README.md: -------------------------------------------------------------------------------- 1 | 2 | Put the following files in this folder: 3 | 4 | - train.smi 5 | - valid.smi 6 | - test.smi 7 | 8 | Once you've generated train.csv using training_stats.py, put that here as well. 9 | 10 | After main.py has been run with (defaults.py -> job_type = 'preprocess') there should be the following files in this directory: 11 | 12 | - train.h5 13 | - valid.h5 14 | - test.h5 15 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/data/protac/README.md: -------------------------------------------------------------------------------- 1 | 2 | Download protac.csv from http://cadd.zju.edu.cn/protacdb/statics/binaryDownload/csv/protac/protac.csv and save here under the name data.csv 3 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/gnn/aggregation_mpnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the `AggregationMPNN` class. 3 | """ 4 | # load general packages and functions 5 | from collections import namedtuple 6 | import torch 7 | 8 | 9 | class AggregationMPNN(torch.nn.Module): 10 | """ 11 | Abstract `AggregationMPNN` class. Specific models using this class are 12 | defined in `mpnn.py`; these are the attention networks AttS2V and AttGGNN. 
13 | """ 14 | def __init__(self, constants : namedtuple) -> None: 15 | super().__init__() 16 | 17 | self.hidden_node_features = constants.hidden_node_features 18 | self.edge_features = constants.n_edge_features 19 | self.message_size = constants.message_size 20 | self.message_passes = constants.message_passes 21 | self.constants = constants 22 | 23 | def aggregate_message(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 24 | edges : torch.Tensor, mask : torch.Tensor) -> None: 25 | """ 26 | Message aggregation function, to be implemented in all `AggregationMPNN` subclasses. 27 | 28 | Args: 29 | ---- 30 | nodes (torch.Tensor) : Batch of node feature vectors. 31 | node_neighbours (torch.Tensor) : Batch of node feature vectors for neighbors. 32 | edges (torch.Tensor) : Batch of edge feature vectors. 33 | mask (torch.Tensor) : Mask for non-existing neighbors, where 34 | elements are 1 if corresponding element 35 | exists and 0 otherwise. 36 | 37 | Shapes: 38 | ------ 39 | nodes : (total N nodes in batch, N node features) 40 | node_neighbours : (total N nodes in batch, max node degree, N node features) 41 | edges : (total N nodes in batch, max node degree, N edge features) 42 | mask : (total N nodes in batch, max node degree) 43 | """ 44 | raise NotImplementedError 45 | 46 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> None: 47 | """ 48 | Message update function, to be implemented in all `AggregationMPNN` subclasses. 49 | 50 | Args: 51 | ---- 52 | nodes (torch.Tensor) : Batch of node feature vectors. 53 | messages (torch.Tensor) : Batch of incoming messages. 54 | 55 | Shapes: 56 | ------ 57 | nodes : (total N nodes in batch, N node features) 58 | messages : (total N nodes in batch, N node features) 59 | """ 60 | raise NotImplementedError 61 | 62 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 63 | node_mask : torch.Tensor) -> None: 64 | """ 65 | Local readout function, to be implemented in all `AggregationMPNN` subclasses. 66 | 67 | Args: 68 | ---- 69 | hidden_nodes (torch.Tensor) : Batch of node feature vectors. 70 | input_nodes (torch.Tensor) : Batch of node feature vectors. 71 | node_mask (torch.Tensor) : Mask for non-existing neighbors, where 72 | elements are 1 if corresponding element 73 | exists and 0 otherwise. 74 | 75 | Shapes: 76 | ------ 77 | hidden_nodes : (total N nodes in batch, N node features) 78 | input_nodes : (total N nodes in batch, N node features) 79 | node_mask : (total N nodes in batch, N features) 80 | """ 81 | raise NotImplementedError 82 | 83 | def forward(self, nodes : torch.Tensor, edges : torch.Tensor) -> torch.Tensor: 84 | """ 85 | Defines forward pass. 86 | 87 | Args: 88 | ---- 89 | nodes (torch.Tensor) : Batch of node feature matrices. 90 | edges (torch.Tensor) : Batch of edge feature tensors. 91 | 92 | Shapes: 93 | ------ 94 | nodes : (batch size, N nodes, N node features) 95 | edges : (batch size, N nodes, N nodes, N edge features) 96 | 97 | Returns: 98 | ------- 99 | output (torch.Tensor) : This would normally be the learned graph 100 | representation, but in all MPNN readout functions 101 | in this work, the last layer is used to predict 102 | the action probability distribution for a batch 103 | of graphs from the learned graph representation. 
104 | """ 105 | adjacency = torch.sum(edges, dim=3) 106 | 107 | # **note: "idc" == "indices", "nghb{s}" == "neighbour(s)" 108 | edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc = \ 109 | adjacency.nonzero(as_tuple=True) 110 | 111 | node_batch_batch_idc, node_batch_node_idc = adjacency.sum(-1).nonzero(as_tuple=True) 112 | node_batch_adj = adjacency[node_batch_batch_idc, node_batch_node_idc, :] 113 | node_batch_size = node_batch_batch_idc.shape[0] 114 | node_degrees = node_batch_adj.sum(-1).long() 115 | max_node_degree = node_degrees.max() 116 | 117 | node_batch_node_nghbs = torch.zeros(node_batch_size, 118 | max_node_degree, 119 | self.hidden_node_features, 120 | device=self.constants.device) 121 | node_batch_edges = torch.zeros(node_batch_size, 122 | max_node_degree, 123 | self.edge_features, 124 | device=self.constants.device) 125 | 126 | node_batch_nghb_nghb_idc = torch.cat( 127 | [torch.arange(i) for i in node_degrees] 128 | ).long() 129 | 130 | edge_batch_node_batch_idc = torch.cat( 131 | [i * torch.ones(degree) for i, degree in enumerate(node_degrees)] 132 | ).long() 133 | 134 | node_batch_node_nghb_mask = torch.zeros(node_batch_size, 135 | max_node_degree, 136 | device=self.constants.device) 137 | 138 | node_batch_node_nghb_mask[edge_batch_node_batch_idc, node_batch_nghb_nghb_idc] = 1 139 | 140 | node_batch_edges[edge_batch_node_batch_idc, node_batch_nghb_nghb_idc, :] = \ 141 | edges[edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc, :] 142 | 143 | # pad up the hidden nodes 144 | hidden_nodes = torch.zeros(nodes.shape[0], 145 | nodes.shape[1], 146 | self.hidden_node_features, 147 | device=self.constants.device) 148 | hidden_nodes[:nodes.shape[0], :nodes.shape[1], :nodes.shape[2]] = nodes.clone() 149 | 150 | for _ in range(self.message_passes): 151 | 152 | node_batch_nodes = hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] 153 | node_batch_node_nghbs[edge_batch_node_batch_idc, node_batch_nghb_nghb_idc, :] = \ 154 | hidden_nodes[edge_batch_batch_idc, edge_batch_nghb_idc, :] 155 | 156 | messages = self.aggregate_message(nodes=node_batch_nodes, 157 | node_neighbours=node_batch_node_nghbs.clone(), 158 | edges=node_batch_edges, 159 | mask=node_batch_node_nghb_mask) 160 | 161 | hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] = \ 162 | self.update(node_batch_nodes.clone(), messages) 163 | 164 | node_mask = (adjacency.sum(-1) != 0) 165 | 166 | output = self.readout(hidden_nodes, nodes, node_mask) 167 | 168 | return output 169 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/gnn/edge_mpnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the `EdgeMPNN` class. 3 | """# load general packages and functions 4 | from collections import namedtuple 5 | import torch 6 | 7 | 8 | class EdgeMPNN(torch.nn.Module): 9 | """ 10 | Abstract `EdgeMPNN` class. A specific model using this class is defined 11 | in `mpnn.py`; this is the EMN. 
12 | """ 13 | def __init__(self, constants : namedtuple) -> None: 14 | super().__init__() 15 | 16 | self.edge_features = constants.edge_features 17 | self.edge_embedding_size = constants.edge_embedding_size 18 | self.message_passes = constants.message_passes 19 | self.n_nodes_largest_graph = constants.max_n_nodes 20 | self.constants = constants 21 | 22 | def preprocess_edges(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 23 | edges : torch.Tensor) -> None: 24 | """ 25 | Edge preprocessing step, to be implemented in all `EdgeMPNN` subclasses. 26 | 27 | Args: 28 | ---- 29 | nodes (torch.Tensor) : Batch of node feature vectors. 30 | node_neighbours (torch.Tensor) : Batch of node feature vectors for neighbors. 31 | edges (torch.Tensor) : Batch of edge feature vectors. 32 | 33 | Shapes: 34 | ------ 35 | nodes : (total N nodes in batch, N node features) 36 | node_neighbours : (total N nodes in batch, max node degree, N node features) 37 | edges : (total N nodes in batch, max node degree, N edge features) 38 | """ 39 | raise NotImplementedError 40 | 41 | def propagate_edges(self, edges : torch.Tensor, ingoing_edge_memories : torch.Tensor, 42 | ingoing_edges_mask : torch.Tensor) -> None: 43 | """ 44 | Edge propagation rule, to be implemented in all `EdgeMPNN` subclasses. 45 | 46 | Args: 47 | ---- 48 | edges (torch.Tensor) : Batch of edge feature tensors. 49 | ingoing_edge_memories (torch.Tensor) : Batch of memories for all 50 | ingoing edges. 51 | ingoing_edges_mask (torch.Tensor) : Mask for ingoing edges. 52 | 53 | Shapes: 54 | ------ 55 | edges : (batch size, N nodes, N nodes, total N edge features) 56 | ingoing_edge_memories : (total N edges in batch, total N edge features) 57 | ingoing_edges_mask : (total N edges in batch, max node degree, total N edge features) 58 | """ 59 | raise NotImplementedError 60 | 61 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 62 | node_mask : torch.Tensor) -> None: 63 | """ 64 | Local readout function, to be implemented in all `EdgeMPNN` subclasses. 65 | 66 | Args: 67 | ---- 68 | hidden_nodes (torch.Tensor) : Batch of node feature vectors. 69 | input_nodes (torch.Tensor) : Batch of node feature vectors. 70 | node_mask (torch.Tensor) : Mask for non-existing neighbors, where 71 | elements are 1 if corresponding element 72 | exists and 0 otherwise. 73 | 74 | Shapes: 75 | ------ 76 | hidden_nodes : (total N nodes in batch, N node features) 77 | input_nodes : (total N nodes in batch, N node features) 78 | node_mask : (total N nodes in batch, N features) 79 | """ 80 | raise NotImplementedError 81 | 82 | def forward(self, nodes : torch.Tensor, edges : torch.Tensor) -> torch.Tensor: 83 | """ 84 | Defines forward pass. 85 | 86 | Args: 87 | ---- 88 | nodes (torch.Tensor) : Batch of node feature matrices. 89 | edges (torch.Tensor) : Batch of edge feature tensors. 90 | 91 | Shapes: 92 | ------ 93 | nodes : (batch size, N nodes, N node features) 94 | edges : (batch size, N nodes, N nodes, N edge features) 95 | 96 | Returns: 97 | ------- 98 | output (torch.Tensor) : This would normally be the learned graph representation, 99 | but in all MPNN readout functions in this work, 100 | the last layer is used to predict the action 101 | probability distribution for a batch of graphs from 102 | the learned graph representation. 
103 | """ 104 | adjacency = torch.sum(edges, dim=3) 105 | 106 | # indices for finding edges in batch; `edges_b_idx` is batch index, 107 | # `edges_n_idx` is the node index, and `edges_nghb_idx` is the index 108 | # that each node in `edges_n_idx` is bound to 109 | edges_b_idx, edges_n_idx, edges_nghb_idx = adjacency.nonzero(as_tuple=True) 110 | 111 | n_edges = edges_n_idx.shape[0] 112 | adj_of_edge_batch_idc = adjacency.clone().long() 113 | 114 | # +1 to distinguish idx 0 from empty elements, subtracted few lines down 115 | r = torch.arange(1, n_edges + 1, device=self.constants.device) 116 | 117 | adj_of_edge_batch_idc[edges_b_idx, edges_n_idx, edges_nghb_idx] = r 118 | 119 | ingoing_edges_eb_idx = ( 120 | torch.cat([row[row.nonzero()] for row in 121 | adj_of_edge_batch_idc[edges_b_idx, edges_nghb_idx, :]]) - 1 122 | ).squeeze() 123 | 124 | edge_degrees = adjacency[edges_b_idx, edges_nghb_idx, :].sum(-1).long() 125 | ingoing_edges_igeb_idx = torch.cat( 126 | [i * torch.ones(d) for i, d in enumerate(edge_degrees)] 127 | ).long() 128 | ingoing_edges_ige_idx = torch.cat([torch.arange(i) for i in edge_degrees]).long() 129 | 130 | 131 | batch_size = adjacency.shape[0] 132 | n_nodes = adjacency.shape[1] 133 | max_node_degree = adjacency.sum(-1).max().int() 134 | edge_memories = torch.zeros(n_edges, 135 | self.edge_embedding_size, 136 | device=self.constants.device) 137 | 138 | ingoing_edge_memories = torch.zeros(n_edges, max_node_degree, 139 | self.edge_embedding_size, 140 | device=self.constants.device) 141 | ingoing_edges_mask = torch.zeros(n_edges, 142 | max_node_degree, 143 | device=self.constants.device) 144 | 145 | edge_batch_nodes = nodes[edges_b_idx, edges_n_idx, :] 146 | # **note: "nghb{s}" == "neighbour(s)" 147 | edge_batch_nghbs = nodes[edges_b_idx, edges_nghb_idx, :] 148 | edge_batch_edges = edges[edges_b_idx, edges_n_idx, edges_nghb_idx, :] 149 | edge_batch_edges = self.preprocess_edges(nodes=edge_batch_nodes, 150 | node_neighbours=edge_batch_nghbs, 151 | edges=edge_batch_edges) 152 | 153 | # remove h_ji:s influence on h_ij 154 | ingoing_edges_nghb_idx = edges_nghb_idx[ingoing_edges_eb_idx] 155 | ingoing_edges_receiving_edge_n_idx = edges_n_idx[ingoing_edges_igeb_idx] 156 | diff_idx = (ingoing_edges_receiving_edge_n_idx != ingoing_edges_nghb_idx).nonzero() 157 | 158 | try: 159 | ingoing_edges_eb_idx = ingoing_edges_eb_idx[diff_idx].squeeze() 160 | ingoing_edges_ige_idx = ingoing_edges_ige_idx[diff_idx].squeeze() 161 | ingoing_edges_igeb_idx = ingoing_edges_igeb_idx[diff_idx].squeeze() 162 | except: 163 | pass 164 | 165 | ingoing_edges_mask[ingoing_edges_igeb_idx, ingoing_edges_ige_idx] = 1 166 | 167 | for _ in range(self.message_passes): 168 | ingoing_edge_memories[ingoing_edges_igeb_idx, ingoing_edges_ige_idx, :] = \ 169 | edge_memories[ingoing_edges_eb_idx, :] 170 | edge_memories = self.propagate_edges( 171 | edges=edge_batch_edges, 172 | ingoing_edge_memories=ingoing_edge_memories.clone(), 173 | ingoing_edges_mask=ingoing_edges_mask 174 | ) 175 | 176 | node_mask = (adjacency.sum(-1) != 0) 177 | 178 | node_sets = torch.zeros(batch_size, 179 | n_nodes, 180 | max_node_degree, 181 | self.edge_embedding_size, 182 | device=self.constants.device) 183 | 184 | edge_batch_edge_memory_idc = torch.cat( 185 | [torch.arange(row.sum()) for row in adjacency.view(-1, n_nodes)] 186 | ).long() 187 | 188 | node_sets[edges_b_idx, edges_n_idx, edge_batch_edge_memory_idc, :] = edge_memories 189 | graph_sets = node_sets.sum(2) 190 | 191 | output = self.readout(graph_sets, graph_sets, node_mask) 192 | 
return output 193 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/gnn/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines MPNN modules and readout functions, and APD readout functions. 3 | """ 4 | # load general packages and functions 5 | from collections import namedtuple 6 | import torch 7 | 8 | # load GraphINVENT-specific functions 9 | # (none) 10 | 11 | 12 | class GraphGather(torch.nn.Module): 13 | """ 14 | GGNN readout function. 15 | """ 16 | def __init__(self, node_features : int, hidden_node_features : int, 17 | out_features : int, att_depth : int, att_hidden_dim : int, 18 | att_dropout_p : float, emb_depth : int, emb_hidden_dim : int, 19 | emb_dropout_p : float, big_positive : float) -> None: 20 | 21 | super().__init__() 22 | 23 | self.big_positive = big_positive 24 | 25 | self.att_nn = MLP( 26 | in_features=node_features + hidden_node_features, 27 | hidden_layer_sizes=[att_hidden_dim] * att_depth, 28 | out_features=out_features, 29 | dropout_p=att_dropout_p 30 | ) 31 | 32 | self.emb_nn = MLP( 33 | in_features=hidden_node_features, 34 | hidden_layer_sizes=[emb_hidden_dim] * emb_depth, 35 | out_features=out_features, 36 | dropout_p=emb_dropout_p 37 | ) 38 | 39 | def forward(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 40 | node_mask : torch.Tensor) -> torch.Tensor: 41 | """ 42 | Defines forward pass. 43 | """ 44 | Softmax = torch.nn.Softmax(dim=1) 45 | 46 | cat = torch.cat((hidden_nodes, input_nodes), dim=2) 47 | energy_mask = (node_mask == 0).float() * self.big_positive 48 | energies = self.att_nn(cat) - energy_mask.unsqueeze(-1) 49 | attention = Softmax(energies) 50 | embedding = self.emb_nn(hidden_nodes) 51 | 52 | return torch.sum(attention * embedding, dim=1) 53 | 54 | 55 | class Set2Vec(torch.nn.Module): 56 | """ 57 | S2V readout function. 58 | """ 59 | def __init__(self, node_features : int, hidden_node_features : int, 60 | lstm_computations : int, memory_size : int, 61 | constants : namedtuple) -> None: 62 | 63 | super().__init__() 64 | 65 | self.constants = constants 66 | self.lstm_computations = lstm_computations 67 | self.memory_size = memory_size 68 | 69 | self.embedding_matrix = torch.nn.Linear( 70 | in_features=node_features + hidden_node_features, 71 | out_features=self.memory_size, 72 | bias=True 73 | ) 74 | 75 | self.lstm = torch.nn.LSTMCell( 76 | input_size=self.memory_size, 77 | hidden_size=self.memory_size, 78 | bias=True 79 | ) 80 | 81 | def forward(self, hidden_output_nodes : torch.Tensor, input_nodes : torch.Tensor, 82 | node_mask : torch.Tensor) -> torch.Tensor: 83 | """ 84 | Defines forward pass. 
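        The concatenated node states are embedded into memories; for
        `lstm_computations` steps, an LSTM cell produces a query that attends
        over these memories (non-existing nodes are masked out), and the
        attention-weighted read becomes the next LSTM input. The final query
        and read vectors are concatenated, giving an output of size
        2 * `memory_size`.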
85 | """ 86 | Softmax = torch.nn.Softmax(dim=1) 87 | 88 | batch_size = input_nodes.shape[0] 89 | energy_mask = torch.bitwise_not(node_mask).float() * self.C.big_negative 90 | lstm_input = torch.zeros(batch_size, self.memory_size, device=self.constants.device) 91 | cat = torch.cat((hidden_output_nodes, input_nodes), dim=2) 92 | memory = self.embedding_matrix(cat) 93 | hidden_state = torch.zeros(batch_size, self.memory_size, device=self.constants.device) 94 | cell_state = torch.zeros(batch_size, self.memory_size, device=self.constants.device) 95 | 96 | for _ in range(self.lstm_computations): 97 | query, cell_state = self.lstm(lstm_input, (hidden_state, cell_state)) 98 | 99 | # dot product query x memory 100 | energies = (query.view(batch_size, 1, self.memory_size) * memory).sum(dim=-1) 101 | attention = Softmax(energies + energy_mask) 102 | read = (attention.unsqueeze(-1) * memory).sum(dim=1) 103 | 104 | hidden_state = query 105 | lstm_input = read 106 | 107 | cat = torch.cat((query, read), dim=1) 108 | return cat 109 | 110 | 111 | class MLP(torch.nn.Module): 112 | """ 113 | Multi-layer perceptron. Applies SELU after every linear layer. 114 | 115 | Args: 116 | ---- 117 | in_features (int) : Size of each input sample. 118 | hidden_layer_sizes (list) : Hidden layer sizes. 119 | out_features (int) : Size of each output sample. 120 | dropout_p (float) : Probability of dropping a weight. 121 | """ 122 | 123 | def __init__(self, in_features : int, hidden_layer_sizes : list, out_features : int, 124 | dropout_p : float) -> None: 125 | super().__init__() 126 | 127 | activation_function = torch.nn.SELU 128 | 129 | # create list of all layer feature sizes 130 | fs = [in_features, *hidden_layer_sizes, out_features] 131 | 132 | # create list of linear_blocks 133 | layers = [self._linear_block(in_f, out_f, 134 | activation_function, 135 | dropout_p) 136 | for in_f, out_f in zip(fs, fs[1:])] 137 | 138 | # concatenate modules in all sequentials in layers list 139 | layers = [module for sq in layers for module in sq.children()] 140 | 141 | # add modules to sequential container 142 | self.seq = torch.nn.Sequential(*layers) 143 | 144 | def _linear_block(self, in_f : int, out_f : int, activation : torch.nn.Module, 145 | dropout_p : float) -> torch.nn.Sequential: 146 | """ 147 | Returns a linear block consisting of a linear layer, an activation function 148 | (SELU), and dropout (optional) stack. 149 | 150 | Args: 151 | ---- 152 | in_f (int) : Size of each input sample. 153 | out_f (int) : Size of each output sample. 154 | activation (torch.nn.Module) : Activation function. 155 | dropout_p (float) : Probability of dropping a weight. 156 | 157 | Returns: 158 | ------- 159 | torch.nn.Sequential : The linear block. 160 | """ 161 | # bias must be used in most MLPs in our models to learn from empty graphs 162 | linear = torch.nn.Linear(in_f, out_f, bias=True) 163 | torch.nn.init.xavier_uniform_(linear.weight) 164 | return torch.nn.Sequential(linear, activation(), torch.nn.AlphaDropout(dropout_p)) 165 | 166 | def forward(self, layers_input : torch.nn.Sequential) -> torch.nn.Sequential: 167 | """ 168 | Defines forward pass. 169 | """ 170 | return self.seq(layers_input) 171 | 172 | 173 | class GlobalReadout(torch.nn.Module): 174 | """ 175 | Global readout function class. Used to predict the action probability distributions 176 | (APDs) for molecular graphs. 177 | 178 | The first tier of two `MLP`s take as input, for each graph in the batch, the 179 | final transformed node feature vectors. 
These feed-forward networks correspond 180 | to the preliminary "f_add" and "f_conn" distributions. 181 | 182 | The second tier of three `MLP`s takes as input the output of the first tier 183 | of `MLP`s (the "preliminary" APDs) as well as the graph embeddings for all 184 | graphs in the batch. Output are the final APD components, which are then flattened 185 | and concatenated. No activation function is applied after the final layer, so 186 | that this can be done outside (e.g. in the loss function, and before sampling). 187 | """ 188 | def __init__(self, f_add_elems : int, f_conn_elems : int, f_term_elems : int, 189 | mlp1_depth : int, mlp1_dropout_p : float, mlp1_hidden_dim : int, 190 | mlp2_depth : int, mlp2_dropout_p : float, mlp2_hidden_dim : int, 191 | graph_emb_size : int, max_n_nodes : int, node_emb_size : int, 192 | device : str) -> None: 193 | super().__init__() 194 | 195 | self.device = device 196 | 197 | # preliminary f_add 198 | self.fAddNet1 = MLP( 199 | in_features=node_emb_size, 200 | hidden_layer_sizes=[mlp1_hidden_dim] * mlp1_depth, 201 | out_features=f_add_elems, 202 | dropout_p=mlp1_dropout_p 203 | ) 204 | 205 | # preliminary f_conn 206 | self.fConnNet1 = MLP( 207 | in_features=node_emb_size, 208 | hidden_layer_sizes=[mlp1_hidden_dim] * mlp1_depth, 209 | out_features=f_conn_elems, 210 | dropout_p=mlp1_dropout_p 211 | ) 212 | 213 | # final f_add 214 | self.fAddNet2 = MLP( 215 | in_features=(max_n_nodes * f_add_elems + graph_emb_size), 216 | hidden_layer_sizes=[mlp2_hidden_dim] * mlp2_depth, 217 | out_features=f_add_elems * max_n_nodes, 218 | dropout_p=mlp2_dropout_p 219 | ) 220 | 221 | # final f_conn 222 | self.fConnNet2 = MLP( 223 | in_features=(max_n_nodes * f_conn_elems + graph_emb_size), 224 | hidden_layer_sizes=[mlp2_hidden_dim] * mlp2_depth, 225 | out_features=f_conn_elems * max_n_nodes, 226 | dropout_p=mlp2_dropout_p 227 | ) 228 | 229 | # final f_term (only takes as input graph embeddings) 230 | self.fTermNet2 = MLP( 231 | in_features=graph_emb_size, 232 | hidden_layer_sizes=[mlp2_hidden_dim] * mlp2_depth, 233 | out_features=f_term_elems, 234 | dropout_p=mlp2_dropout_p 235 | ) 236 | 237 | def forward(self, node_level_output : torch.Tensor, 238 | graph_embedding_batch : torch.Tensor) -> torch.Tensor: 239 | """ 240 | Defines forward pass. 241 | """ 242 | if self.device == "cuda": 243 | self.fAddNet1 = self.fAddNet1.to("cuda", non_blocking=True) 244 | self.fConnNet1 = self.fConnNet1.to("cuda", non_blocking=True) 245 | self.fAddNet2 = self.fAddNet2.to("cuda", non_blocking=True) 246 | self.fConnNet2 = self.fConnNet2.to("cuda", non_blocking=True) 247 | self.fTermNet2 = self.fTermNet2.to("cuda", non_blocking=True) 248 | 249 | # get preliminary f_add and f_conn 250 | f_add_1 = self.fAddNet1(node_level_output) 251 | f_conn_1 = self.fConnNet1(node_level_output) 252 | 253 | if self.device == "cuda": 254 | f_add_1 = f_add_1.to("cuda", non_blocking=True) 255 | f_conn_1 = f_conn_1.to("cuda", non_blocking=True) 256 | 257 | # reshape preliminary APDs into flattenened vectors (e.g. 
one vector per 258 | # graph in batch) 259 | f_add_1_size = f_add_1.size() 260 | f_conn_1_size = f_conn_1.size() 261 | f_add_1 = f_add_1.view((f_add_1_size[0], f_add_1_size[1] * f_add_1_size[2])) 262 | f_conn_1 = f_conn_1.view((f_conn_1_size[0], f_conn_1_size[1] * f_conn_1_size[2])) 263 | 264 | # get final f_add, f_conn, and f_term 265 | f_add_2 = self.fAddNet2( 266 | torch.cat((f_add_1, graph_embedding_batch), dim=1).unsqueeze(dim=1) 267 | ) 268 | f_conn_2 = self.fConnNet2( 269 | torch.cat((f_conn_1, graph_embedding_batch), dim=1).unsqueeze(dim=1) 270 | ) 271 | f_term_2 = self.fTermNet2(graph_embedding_batch) 272 | 273 | if self.device == "cuda": 274 | f_add_2 = f_add_2.to("cuda", non_blocking=True) 275 | f_conn_2 = f_conn_2.to("cuda", non_blocking=True) 276 | f_term_2 = f_term_2.to("cuda", non_blocking=True) 277 | 278 | # flatten and concatenate 279 | cat = torch.cat((f_add_2.squeeze(dim=1), f_conn_2.squeeze(dim=1), f_term_2), dim=1) 280 | 281 | return cat # note: no activation function before returning 282 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/gnn/mpnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines specific MPNN implementations. 3 | """ 4 | # load general packages and functions 5 | from collections import namedtuple 6 | import math 7 | import torch 8 | 9 | # load GraphINVENT-specific functions 10 | import gnn.aggregation_mpnn 11 | import gnn.edge_mpnn 12 | import gnn.summation_mpnn 13 | import gnn.modules 14 | 15 | 16 | class MNN(gnn.summation_mpnn.SummationMPNN): 17 | """ 18 | The "message neural network" model. 19 | """ 20 | def __init__(self, constants : namedtuple) -> None: 21 | super().__init__(constants) 22 | 23 | self.constants = constants 24 | message_weights = torch.Tensor(self.constants.message_size, 25 | self.constants.hidden_node_features, 26 | self.constants.n_edge_features) 27 | if self.constants.device == "cuda": 28 | message_weights = message_weights.to("cuda", non_blocking=True) 29 | 30 | self.message_weights = torch.nn.Parameter(message_weights) 31 | 32 | self.gru = torch.nn.GRUCell( 33 | input_size=self.constants.message_size, 34 | hidden_size=self.constants.hidden_node_features, 35 | bias=True 36 | ) 37 | 38 | self.APDReadout = gnn.modules.GlobalReadout( 39 | node_emb_size=self.constants.hidden_node_features, 40 | graph_emb_size=self.constants.hidden_node_features, 41 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim, 42 | mlp1_depth=self.constants.mlp1_depth, 43 | mlp1_dropout_p=self.constants.mlp1_dropout_p, 44 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim, 45 | mlp2_depth=self.constants.mlp2_depth, 46 | mlp2_dropout_p=self.constants.mlp2_dropout_p, 47 | f_add_elems=self.constants.len_f_add_per_node, 48 | f_conn_elems=self.constants.len_f_conn_per_node, 49 | f_term_elems=1, 50 | max_n_nodes=self.constants.max_n_nodes, 51 | device=self.constants.device, 52 | ) 53 | 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self) -> None: 57 | stdev = 1.0 / math.sqrt(self.message_weights.size(1)) 58 | self.message_weights.data.uniform_(-stdev, stdev) 59 | 60 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 61 | edges : torch.Tensor) -> torch.Tensor: 62 | edges_view = edges.view(-1, 1, 1, self.constants.n_edge_features) 63 | weights_for_each_edge = (edges_view * self.message_weights.unsqueeze(0)).sum(3) 64 | return torch.matmul(weights_for_each_edge, 65 | node_neighbours.unsqueeze(-1)).squeeze() 66 | 
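    # For reference, an equivalent way to write the two tensor operations in
    # `message_terms()` above, assuming `edges` has shape (E, n_edge_features),
    # `node_neighbours` has shape (E, hidden_node_features), and
    # `self.message_weights` has shape (message_size, hidden_node_features, n_edge_features):
    #
    #     weights_per_edge = torch.einsum("ef,mhf->emh", edges, self.message_weights)
    #     messages         = torch.einsum("emh,eh->em", weights_per_edge, node_neighbours)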
67 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor: 68 | return self.gru(messages, nodes) 69 | 70 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 71 | node_mask : torch.Tensor) -> torch.Tensor: 72 | graph_embeddings = torch.sum(hidden_nodes, dim=1) 73 | output = self.APDReadout(hidden_nodes, graph_embeddings) 74 | return output 75 | 76 | 77 | class S2V(gnn.summation_mpnn.SummationMPNN): 78 | """ 79 | The "set2vec" model. 80 | """ 81 | def __init__(self, constants : namedtuple) -> None: 82 | super().__init__(constants) 83 | 84 | self.constants = constants 85 | 86 | self.enn = gnn.modules.MLP( 87 | in_features=self.constants.n_edge_features, 88 | hidden_layer_sizes=[self.constants.enn_hidden_dim] * self.constants.enn_depth, 89 | out_features=self.constants.hidden_node_features * self.constants.message_size, 90 | dropout_p=self.constants.enn_dropout_p 91 | ) 92 | 93 | self.gru = torch.nn.GRUCell( 94 | input_size=self.constants.message_size, 95 | hidden_size=self.constants.hidden_node_features, 96 | bias=True 97 | ) 98 | 99 | self.s2v = gnn.modules.Set2Vec( 100 | node_features=self.constants.n_node_features, 101 | hidden_node_features=self.constants.hidden_node_features, 102 | lstm_computations=self.constants.s2v_lstm_computations, 103 | memory_size=self.constants.s2v_memory_size 104 | ) 105 | 106 | self.APDReadout = gnn.modules.GlobalReadout( 107 | node_emb_size=self.constants.hidden_node_features, 108 | graph_emb_size=self.constants.s2v_memory_size * 2, 109 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim, 110 | mlp1_depth=self.constants.mlp1_depth, 111 | mlp1_dropout_p=self.constants.mlp1_dropout_p, 112 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim, 113 | mlp2_depth=self.constants.mlp2_depth, 114 | mlp2_dropout_p=self.constants.mlp2_dropout_p, 115 | f_add_elems=self.constants.len_f_add_per_node, 116 | f_conn_elems=self.constants.len_f_conn_per_node, 117 | f_term_elems=1, 118 | max_n_nodes=self.constants.max_n_nodes, 119 | device=self.constants.device, 120 | ) 121 | 122 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 123 | edges : torch.Tensor) -> torch.Tensor: 124 | enn_output = self.enn(edges) 125 | matrices = enn_output.view(-1, 126 | self.constants.message_size, 127 | self.constants.hidden_node_features) 128 | msg_terms = torch.matmul(matrices, 129 | node_neighbours.unsqueeze(-1)).squeeze(-1) 130 | return msg_terms 131 | 132 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor: 133 | return self.gru(messages, nodes) 134 | 135 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 136 | node_mask : torch.Tensor) -> torch.Tensor: 137 | graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask) 138 | output = self.APDReadout(hidden_nodes, graph_embeddings) 139 | return output 140 | 141 | 142 | class AttentionS2V(gnn.aggregation_mpnn.AggregationMPNN): 143 | """ 144 | The "set2vec with attention" model. 
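    Identical to `S2V` except that incoming messages are combined with a learned
    attention: per-edge message terms from `enn` are weighted by a masked softmax over
    energies produced by `att_enn` (see `aggregate_message()` below) before the GRU
    update and the Set2Vec readout.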
145 | """ 146 | def __init__(self, constants : namedtuple) -> None: 147 | 148 | super().__init__(constants) 149 | 150 | self.constants = constants 151 | 152 | self.enn = gnn.modules.MLP( 153 | in_features=self.constants.n_edge_features, 154 | hidden_layer_sizes=[self.constants.enn_hidden_dim] * self.constants.enn_depth, 155 | out_features=self.constants.hidden_node_features * self.constants.message_size, 156 | dropout_p=self.constants.enn_dropout_p 157 | ) 158 | 159 | self.att_enn = gnn.modules.MLP( 160 | in_features=self.constants.hidden_node_features + self.constants.n_edge_features, 161 | hidden_layer_sizes=[self.constants.att_hidden_dim] * self.constants.att_depth, 162 | out_features=self.constants.message_size, 163 | dropout_p=self.constants.att_dropout_p 164 | ) 165 | 166 | self.gru = torch.nn.GRUCell( 167 | input_size=self.constants.message_size, 168 | hidden_size=self.constants.hidden_node_features, 169 | bias=True 170 | ) 171 | 172 | self.s2v = gnn.modules.Set2Vec( 173 | node_features=self.constants.n_node_features, 174 | hidden_node_features=self.constants.hidden_node_features, 175 | lstm_computations=self.constants.s2v_lstm_computations, 176 | memory_size=self.constants.s2v_memory_size, 177 | ) 178 | 179 | self.APDReadout = gnn.modules.GlobalReadout( 180 | node_emb_size=self.constants.hidden_node_features, 181 | graph_emb_size=self.constants.s2v_memory_size * 2, 182 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim, 183 | mlp1_depth=self.constants.mlp1_depth, 184 | mlp1_dropout_p=self.constants.mlp1_dropout_p, 185 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim, 186 | mlp2_depth=self.constants.mlp2_depth, 187 | mlp2_dropout_p=self.constants.mlp2_dropout_p, 188 | f_add_elems=self.constants.len_f_add_per_node, 189 | f_conn_elems=self.constants.len_f_conn_per_node, 190 | f_term_elems=1, 191 | max_n_nodes=self.constants.max_n_nodes, 192 | device=self.constants.device, 193 | ) 194 | 195 | def aggregate_message(self, nodes : torch.Tensor, 196 | node_neighbours : torch.Tensor, 197 | edges : torch.Tensor, 198 | mask : torch.Tensor) -> torch.Tensor: 199 | Softmax = torch.nn.Softmax(dim=1) 200 | max_node_degree = node_neighbours.shape[1] 201 | 202 | enn_output = self.enn(edges) 203 | matrices = enn_output.view(-1, 204 | max_node_degree, 205 | self.constants.message_size, 206 | self.constants.hidden_node_features) 207 | message_terms = torch.matmul(matrices, node_neighbours.unsqueeze(-1)).squeeze() 208 | 209 | att_enn_output = self.att_enn(torch.cat((edges, node_neighbours), dim=2)) 210 | energies = att_enn_output.view(-1, max_node_degree, self.constants.message_size) 211 | energy_mask = (1 - mask).float() * self.constants.big_negative 212 | weights = Softmax(energies + energy_mask.unsqueeze(-1)) 213 | 214 | return (weights * message_terms).sum(1) 215 | 216 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor: 217 | if self.constants.device == "cuda": 218 | messages = messages + torch.zeros(self.constants.message_size, device="cuda") 219 | return self.gru(messages, nodes) 220 | 221 | def readout(self, hidden_nodes : torch.Tensor, 222 | input_nodes : torch.Tensor, 223 | node_mask : torch.Tensor) -> torch.Tensor: 224 | graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask) 225 | output = self.APDReadout(hidden_nodes, graph_embeddings) 226 | return output 227 | 228 | 229 | class GGNN(gnn.summation_mpnn.SummationMPNN): 230 | """ 231 | The "gated-graph neural network" model. 
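    Uses one MLP per edge type (`msg_nns`) to compute edge-type-specific messages, a
    GRU cell for the node update, and the attention-based `GraphGather` readout to
    build the graph embedding passed to `APDReadout`.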
232 | """ 233 | def __init__(self, constants : namedtuple) -> None: 234 | super().__init__(constants) 235 | 236 | self.constants = constants 237 | 238 | self.msg_nns = torch.nn.ModuleList() 239 | for _ in range(self.constants.n_edge_features): 240 | self.msg_nns.append( 241 | gnn.modules.MLP( 242 | in_features=self.constants.hidden_node_features, 243 | hidden_layer_sizes=[self.constants.enn_hidden_dim] * self.constants.enn_depth, 244 | out_features=self.constants.message_size, 245 | dropout_p=self.constants.enn_dropout_p, 246 | ) 247 | ) 248 | 249 | self.gru = torch.nn.GRUCell( 250 | input_size=self.constants.message_size, 251 | hidden_size=self.constants.hidden_node_features, 252 | bias=True 253 | ) 254 | 255 | self.gather = gnn.modules.GraphGather( 256 | node_features=self.constants.n_node_features, 257 | hidden_node_features=self.constants.hidden_node_features, 258 | out_features=self.constants.gather_width, 259 | att_depth=self.constants.gather_att_depth, 260 | att_hidden_dim=self.constants.gather_att_hidden_dim, 261 | att_dropout_p=self.constants.gather_att_dropout_p, 262 | emb_depth=self.constants.gather_emb_depth, 263 | emb_hidden_dim=self.constants.gather_emb_hidden_dim, 264 | emb_dropout_p=self.constants.gather_emb_dropout_p, 265 | big_positive=self.constants.big_positive 266 | ) 267 | 268 | self.APDReadout = gnn.modules.GlobalReadout( 269 | node_emb_size=self.constants.hidden_node_features, 270 | graph_emb_size=self.constants.gather_width, 271 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim, 272 | mlp1_depth=self.constants.mlp1_depth, 273 | mlp1_dropout_p=self.constants.mlp1_dropout_p, 274 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim, 275 | mlp2_depth=self.constants.mlp2_depth, 276 | mlp2_dropout_p=self.constants.mlp2_dropout_p, 277 | f_add_elems=self.constants.len_f_add_per_node, 278 | f_conn_elems=self.constants.len_f_conn_per_node, 279 | f_term_elems=1, 280 | max_n_nodes=self.constants.max_n_nodes, 281 | device=self.constants.device, 282 | ) 283 | 284 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 285 | edges : torch.Tensor) -> torch.Tensor: 286 | edges_v = edges.view(-1, self.constants.n_edge_features, 1) 287 | node_neighbours_v = edges_v * node_neighbours.view(-1, 288 | 1, 289 | self.constants.hidden_node_features) 290 | terms_masked_per_edge = [ 291 | edges_v[:, i, :] * self.msg_nns[i](node_neighbours_v[:, i, :]) 292 | for i in range(self.constants.n_edge_features) 293 | ] 294 | return sum(terms_masked_per_edge) 295 | 296 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor: 297 | return self.gru(messages, nodes) 298 | 299 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 300 | node_mask : torch.Tensor) -> torch.Tensor: 301 | graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask) 302 | output = self.APDReadout(hidden_nodes, graph_embeddings) 303 | return output 304 | 305 | 306 | class AttentionGGNN(gnn.aggregation_mpnn.AggregationMPNN): 307 | """ 308 | The "GGNN with attention" model. 
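    Like `GGNN`, but instead of summing the edge-type-specific messages directly, the
    neighbour embeddings from `msg_nns` are weighted by a masked softmax over energies
    produced by a parallel set of attention MLPs (`att_nns`); see `aggregate_message()`
    below.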
309 | """ 310 | def __init__(self, constants : namedtuple) -> None: 311 | super().__init__(constants) 312 | 313 | self.constants = constants 314 | self.msg_nns = torch.nn.ModuleList() 315 | self.att_nns = torch.nn.ModuleList() 316 | 317 | for _ in range(self.constants.n_edge_features): 318 | self.msg_nns.append( 319 | gnn.modules.MLP( 320 | in_features=self.constants.hidden_node_features, 321 | hidden_layer_sizes=[self.constants.msg_hidden_dim] * self.constants.msg_depth, 322 | out_features=self.constants.message_size, 323 | dropout_p=self.constants.msg_dropout_p, 324 | ) 325 | ) 326 | self.att_nns.append( 327 | gnn.modules.MLP( 328 | in_features=self.constants.hidden_node_features, 329 | hidden_layer_sizes=[self.constants.att_hidden_dim] * self.constants.att_depth, 330 | out_features=self.constants.message_size, 331 | dropout_p=self.constants.att_dropout_p, 332 | ) 333 | ) 334 | 335 | self.gru = torch.nn.GRUCell( 336 | input_size=self.constants.message_size, 337 | hidden_size=self.constants.hidden_node_features, 338 | bias=True 339 | ) 340 | 341 | self.gather = gnn.modules.GraphGather( 342 | node_features=self.constants.n_node_features, 343 | hidden_node_features=self.constants.hidden_node_features, 344 | out_features=self.constants.gather_width, 345 | att_depth=self.constants.gather_att_depth, 346 | att_hidden_dim=self.constants.gather_att_hidden_dim, 347 | att_dropout_p=self.constants.gather_att_dropout_p, 348 | emb_depth=self.constants.gather_emb_depth, 349 | emb_hidden_dim=self.constants.gather_emb_hidden_dim, 350 | emb_dropout_p=self.constants.gather_emb_dropout_p, 351 | big_positive=self.constants.big_positive 352 | ) 353 | 354 | self.APDReadout = gnn.modules.GlobalReadout( 355 | node_emb_size=self.constants.hidden_node_features, 356 | graph_emb_size=self.constants.gather_width, 357 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim, 358 | mlp1_depth=self.constants.mlp1_depth, 359 | mlp1_dropout_p=self.constants.mlp1_dropout_p, 360 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim, 361 | mlp2_depth=self.constants.mlp2_depth, 362 | mlp2_dropout_p=self.constants.mlp2_dropout_p, 363 | f_add_elems=self.constants.len_f_add_per_node, 364 | f_conn_elems=self.constants.len_f_conn_per_node, 365 | f_term_elems=1, 366 | max_n_nodes=self.constants.max_n_nodes, 367 | device=self.constants.device, 368 | ) 369 | 370 | def aggregate_message(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 371 | edges : torch.Tensor, mask : torch.Tensor) -> torch.Tensor: 372 | Softmax = torch.nn.Softmax(dim=1) 373 | 374 | energy_mask = (mask == 0).float() * self.constants.big_positive 375 | 376 | embeddings_masked_per_edge = [ 377 | edges[:, :, i].unsqueeze(-1) * self.msg_nns[i](node_neighbours) 378 | for i in range(self.constants.n_edge_features) 379 | ] 380 | energies_masked_per_edge = [ 381 | edges[:, :, i].unsqueeze(-1) * self.att_nns[i](node_neighbours) 382 | for i in range(self.constants.n_edge_features) 383 | ] 384 | 385 | embedding = sum(embeddings_masked_per_edge) 386 | energies = sum(energies_masked_per_edge) - energy_mask.unsqueeze(-1) 387 | attention = Softmax(energies) 388 | 389 | return torch.sum(attention * embedding, dim=1) 390 | 391 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor: 392 | return self.gru(messages, nodes) 393 | 394 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 395 | node_mask : torch.Tensor) -> torch.Tensor: 396 | graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask) 397 | output = 
self.APDReadout(hidden_nodes, graph_embeddings) 398 | return output 399 | 400 | 401 | class EMN(gnn.edge_mpnn.EdgeMPNN): 402 | """ 403 | The "edge memory network" model. 404 | """ 405 | def __init__(self, constants : namedtuple) -> None: 406 | super().__init__(constants) 407 | 408 | self.constants = constants 409 | 410 | self.embedding_nn = gnn.modules.MLP( 411 | in_features=self.constants.n_node_features * 2 + self.constants.n_edge_features, 412 | hidden_layer_sizes=[self.constants.edge_emb_hidden_dim] *self.constants.edge_emb_depth, 413 | out_features=self.constants.edge_emb_size, 414 | dropout_p=self.constants.edge_emb_dropout_p, 415 | ) 416 | 417 | self.emb_msg_nn = gnn.modules.MLP( 418 | in_features=self.constants.edge_emb_size, 419 | hidden_layer_sizes=[self.constants.msg_hidden_dim] * self.constants.msg_depth, 420 | out_features=self.constants.edge_emb_size, 421 | dropout_p=self.constants.msg_dropout_p, 422 | ) 423 | 424 | self.att_msg_nn = gnn.modules.MLP( 425 | in_features=self.constants.edge_emb_size, 426 | hidden_layer_sizes=[self.constants.att_hidden_dim] * self.constants.att_depth, 427 | out_features=self.constants.edge_emb_size, 428 | dropout_p=self.constants.att_dropout_p, 429 | ) 430 | 431 | self.gru = torch.nn.GRUCell( 432 | input_size=self.constants.edge_emb_size, 433 | hidden_size=self.constants.edge_emb_size, 434 | bias=True 435 | ) 436 | 437 | self.gather = gnn.modules.GraphGather( 438 | node_features=self.constants.edge_emb_size, 439 | hidden_node_features=self.constants.edge_emb_size, 440 | out_features=self.constants.gather_width, 441 | att_depth=self.constants.gather_att_depth, 442 | att_hidden_dim=self.constants.gather_att_hidden_dim, 443 | att_dropout_p=self.constants.gather_att_dropout_p, 444 | emb_depth=self.constants.gather_emb_depth, 445 | emb_hidden_dim=self.constants.gather_emb_hidden_dim, 446 | emb_dropout_p=self.constants.gather_emb_dropout_p, 447 | big_positive=self.constants.big_positive 448 | ) 449 | 450 | self.APDReadout = gnn.modules.GlobalReadout( 451 | node_emb_size=self.constants.edge_emb_size, 452 | graph_emb_size=self.constants.gather_width, 453 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim, 454 | mlp1_depth=self.constants.mlp1_depth, 455 | mlp1_dropout_p=self.constants.mlp1_dropout_p, 456 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim, 457 | mlp2_depth=self.constants.mlp2_depth, 458 | mlp2_dropout_p=self.constants.mlp2_dropout_p, 459 | f_add_elems=self.constants.len_f_add_per_node, 460 | f_conn_elems=self.constants.len_f_conn_per_node, 461 | f_term_elems=1, 462 | max_n_nodes=self.constants.max_n_nodes, 463 | device=self.constants.device, 464 | ) 465 | 466 | def preprocess_edges(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 467 | edges : torch.Tensor) -> torch.Tensor: 468 | cat = torch.cat((nodes, node_neighbours, edges), dim=1) 469 | return torch.tanh(self.embedding_nn(cat)) 470 | 471 | def propagate_edges(self, edges : torch.Tensor, ingoing_edge_memories : torch.Tensor, 472 | ingoing_edges_mask : torch.Tensor) -> torch.Tensor: 473 | Softmax = torch.nn.Softmax(dim=1) 474 | 475 | energy_mask = ( 476 | (1 - ingoing_edges_mask).float() * self.constants.big_negative 477 | ).unsqueeze(-1) 478 | cat = torch.cat((edges.unsqueeze(1), ingoing_edge_memories), dim=1) 479 | embeddings = self.emb_msg_nn(cat) 480 | edge_energy = self.att_msg_nn(edges) 481 | ing_memory_energies = self.att_msg_nn(ingoing_edge_memories) + energy_mask 482 | energies = torch.cat((edge_energy.unsqueeze(1), ing_memory_energies), dim=1) 483 | attention = 
Softmax(energies) 484 | 485 | # set aggregation of set of given edge feature and ingoing edge memories 486 | message = (attention * embeddings).sum(dim=1) 487 | 488 | return self.gru(message) # return hidden state 489 | 490 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 491 | node_mask : torch.Tensor) -> torch.Tensor: 492 | graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask) 493 | output = self.APDReadout(hidden_nodes, graph_embeddings) 494 | return output 495 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/gnn/summation_mpnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the `SummationMPNN` class. 3 | """ 4 | # load general packages and functions 5 | from collections import namedtuple 6 | import torch 7 | 8 | 9 | class SummationMPNN(torch.nn.Module): 10 | """ 11 | Abstract `SummationMPNN` class. Specific models using this class are 12 | defined in `mpnn.py`; these are MNN, S2V, and GGNN. 13 | """ 14 | def __init__(self, constants : namedtuple): 15 | 16 | super().__init__() 17 | 18 | self.hidden_node_features = constants.hidden_node_features 19 | self.edge_features = constants.n_edge_features 20 | self.message_size = constants.message_size 21 | self.message_passes = constants.message_passes 22 | self.constants = constants 23 | 24 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor, 25 | edges : torch.Tensor) -> None: 26 | """ 27 | Message passing function, to be implemented in all `SummationMPNN` subclasses. 28 | 29 | Args: 30 | ---- 31 | nodes (torch.Tensor) : Batch of node feature vectors. 32 | node_neighbours (torch.Tensor) : Batch of node feature vectors for neighbors. 33 | edges (torch.Tensor) : Batch of edge feature vectors. 34 | 35 | Shapes: 36 | ------ 37 | nodes : (total N nodes in batch, N node features) 38 | node_neighbours : (total N nodes in batch, max node degree, N node features) 39 | edges : (total N nodes in batch, max node degree, N edge features) 40 | """ 41 | raise NotImplementedError 42 | 43 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> None: 44 | """ 45 | Message update function, to be implemented in all `SummationMPNN` subclasses. 46 | 47 | Args: 48 | ---- 49 | nodes (torch.Tensor) : Batch of node feature vectors. 50 | messages (torch.Tensor) : Batch of incoming messages. 51 | 52 | Shapes: 53 | ------ 54 | nodes : (total N nodes in batch, N node features) 55 | messages : (total N nodes in batch, N node features) 56 | """ 57 | raise NotImplementedError 58 | 59 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor, 60 | node_mask : torch.Tensor) -> None: 61 | """ 62 | Local readout function, to be implemented in all `SummationMPNN` subclasses. 63 | 64 | Args: 65 | ---- 66 | hidden_nodes (torch.Tensor) : Batch of node feature vectors. 67 | input_nodes (torch.Tensor) : Batch of node feature vectors. 68 | node_mask (torch.Tensor) : Mask for non-existing neighbors, where elements 69 | are 1 if corresponding element exists and 0 70 | otherwise. 71 | 72 | Shapes: 73 | ------ 74 | hidden_nodes : (total N nodes in batch, N node features) 75 | input_nodes : (total N nodes in batch, N node features) 76 | node_mask : (total N nodes in batch, N features) 77 | """ 78 | raise NotImplementedError 79 | 80 | def forward(self, nodes : torch.Tensor, edges : torch.Tensor) -> None: 81 | """ 82 | Defines forward pass. 
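        Builds the adjacency matrix from the edge tensor, gathers per-edge node and
        neighbour feature vectors, then runs `message_passes` rounds of
        `message_terms()` -> summation -> `update()`, and finally applies `readout()`
        to the hidden node states.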
83 | 84 | Args: 85 | ---- 86 | nodes (torch.Tensor) : Batch of node feature matrices. 87 | edges (torch.Tensor) : Batch of edge feature tensors. 88 | 89 | Shapes: 90 | ------ 91 | nodes : (batch size, N nodes, N node features) 92 | edges : (batch size, N nodes, N nodes, N edge features) 93 | 94 | Returns: 95 | ------- 96 | output (torch.Tensor) : This would normally be the learned graph representation, 97 | but in all MPNN readout functions in this work, 98 | the last layer is used to predict the action 99 | probability distribution for a batch of graphs 100 | from the learned graph representation. 101 | """ 102 | adjacency = torch.sum(edges, dim=3) 103 | 104 | # **note: "idc" == "indices", "nghb{s}" == "neighbour(s)" 105 | (edge_batch_batch_idc, 106 | edge_batch_node_idc, 107 | edge_batch_nghb_idc) = adjacency.nonzero(as_tuple=True) 108 | 109 | (node_batch_batch_idc, node_batch_node_idc) = adjacency.sum(-1).nonzero(as_tuple=True) 110 | 111 | same_batch = node_batch_batch_idc.view(-1, 1) == edge_batch_batch_idc 112 | same_node = node_batch_node_idc.view(-1, 1) == edge_batch_node_idc 113 | 114 | # element ij of `message_summation_matrix` is 1 if `edge_batch_edges[j]` 115 | # is connected with `node_batch_nodes[i]`, else 0 116 | message_summation_matrix = (same_batch * same_node).float() 117 | 118 | edge_batch_edges = edges[edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc, :] 119 | 120 | # pad up the hidden nodes 121 | hidden_nodes = torch.zeros(nodes.shape[0], 122 | nodes.shape[1], 123 | self.hidden_node_features, 124 | device=self.constants.device) 125 | hidden_nodes[:nodes.shape[0], :nodes.shape[1], :nodes.shape[2]] = nodes.clone() 126 | node_batch_nodes = hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] 127 | 128 | for _ in range(self.message_passes): 129 | edge_batch_nodes = hidden_nodes[edge_batch_batch_idc, edge_batch_node_idc, :] 130 | 131 | edge_batch_nghbs = hidden_nodes[edge_batch_batch_idc, edge_batch_nghb_idc, :] 132 | 133 | message_terms = self.message_terms(edge_batch_nodes, 134 | edge_batch_nghbs, 135 | edge_batch_edges) 136 | 137 | if len(message_terms.size()) == 1: # if a single graph in batch 138 | message_terms = message_terms.unsqueeze(0) 139 | 140 | # the summation in eq. 1 of the NMPQC paper happens here 141 | messages = torch.matmul(message_summation_matrix, message_terms) 142 | 143 | node_batch_nodes = self.update(node_batch_nodes, messages) 144 | hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] = node_batch_nodes.clone() 145 | 146 | node_mask = adjacency.sum(-1) != 0 147 | output = self.readout(hidden_nodes, nodes, node_mask) 148 | 149 | return output 150 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main function for running GraphINVENT jobs. 
3 | 4 | Examples: 5 | -------- 6 | * If you define an "input.csv" with desired job parameters in job_dir/: 7 | (graphinvent) ~/GraphINVENT$ python main.py --job_dir path/to/job_dir/ 8 | * If you instead want to run your job using the submission scripts: 9 | (graphinvent) ~/GraphINVENT$ python submit-fine-tuning.py 10 | """ 11 | # load general packages and functions 12 | import datetime 13 | 14 | # load GraphINVENT-specific functions 15 | import util 16 | from parameters.constants import constants 17 | from Workflow import Workflow 18 | 19 | # suppress minor warnings 20 | util.suppress_warnings() 21 | 22 | 23 | def main(): 24 | """ 25 | Defines the type of job (preprocessing, training, generation, testing, or 26 | fine-tuning), writes the job parameters (for future reference), and runs 27 | the job. 28 | """ 29 | _ = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # fix date/time 30 | 31 | workflow = Workflow(constants=constants) 32 | 33 | job_type = constants.job_type 34 | print(f"* Run mode: '{job_type}'", flush=True) 35 | 36 | if job_type == "preprocess": 37 | # write preprocessing parameters 38 | util.write_preprocessing_parameters(params=constants) 39 | print("done writing preprocessing params to csv", flush=True) 40 | 41 | # preprocess all datasets 42 | workflow.preprocess_phase() 43 | 44 | elif job_type == "train": 45 | # write training parameters 46 | util.write_job_parameters(params=constants) 47 | 48 | # train model and generate graphs 49 | workflow.training_phase() 50 | 51 | elif job_type == "generate": 52 | # write generation parameters 53 | util.write_job_parameters(params=constants) 54 | 55 | # generate molecules only 56 | workflow.generation_phase() 57 | 58 | elif job_type == "test": 59 | # write testing parameters 60 | util.write_job_parameters(params=constants) 61 | 62 | # evaluate best model using the test set data 63 | workflow.testing_phase() 64 | 65 | elif job_type == "fine-tune": 66 | # write training parameters 67 | util.write_job_parameters(params=constants) 68 | 69 | # fine-tune the model and generate graphs 70 | workflow.learning_phase() 71 | 72 | else: 73 | raise NotImplementedError("Not a valid `job_type`.") 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/parameters/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines `ArgumentParser` for specifying job directory using command-line. 3 | """# load general packages and functions 4 | import argparse 5 | 6 | 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 8 | add_help=False) 9 | parser.add_argument("--job-dir", 10 | type=str, 11 | default="../output/", 12 | help="Directory in which to write all output.") 13 | 14 | 15 | args = parser.parse_args() 16 | 17 | args_dict = vars(args) 18 | job_dir = args_dict["job_dir"] 19 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/parameters/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loads input parameters from `defaults.py`, and defines other global constants 3 | that depend on the input features, creating a `namedtuple` from them; 4 | additionally, if there exists an `input.csv` in the job directory, loads those 5 | arguments and overrides default values in `defaults.py`. 
6 | """ 7 | # load general packages and functions 8 | from collections import namedtuple 9 | import pickle 10 | import csv 11 | import os 12 | import sys 13 | from typing import Tuple 14 | import numpy as np 15 | from rdkit.Chem.rdchem import BondType 16 | 17 | # load GraphINVENT-specific functions 18 | sys.path.insert(1, "./parameters/") # search "parameters/" directory 19 | import parameters.args as args 20 | import parameters.defaults as defaults 21 | 22 | 23 | def get_feature_dimensions(parameters : dict) -> Tuple[int, int, int, int]: 24 | """ 25 | Returns dimensions of all node features. 26 | """ 27 | n_atom_types = len(parameters["atom_types"]) 28 | n_formal_charge = len(parameters["formal_charge"]) 29 | n_numh = ( 30 | int(not parameters["use_explicit_H"] and not parameters["ignore_H"]) 31 | * len(parameters["imp_H"]) 32 | ) 33 | n_chirality = int(parameters["use_chirality"]) * len(parameters["chirality"]) 34 | 35 | return n_atom_types, n_formal_charge, n_numh, n_chirality 36 | 37 | 38 | def get_tensor_dimensions(n_atom_types : int, n_formal_charge : int, n_num_h : int, 39 | n_chirality : int, n_node_features : int, n_edge_features : int, 40 | parameters : dict) -> Tuple[list, list, list, list, int]: 41 | """ 42 | Returns dimensions for all tensors that describe molecular graphs. Tensor dimensions 43 | are `list`s, except for `dim_f_term` which is simply an `int`. Each element 44 | of the lists indicate the corresponding dimension of a particular subgraph matrix 45 | (i.e. `nodes`, `f_add`, etc). 46 | """ 47 | max_nodes = parameters["max_n_nodes"] 48 | 49 | # define the matrix dimensions as `list`s 50 | # first for the graph reps... 51 | dim_nodes = [max_nodes, n_node_features] 52 | 53 | dim_edges = [max_nodes, max_nodes, n_edge_features] 54 | 55 | # ... then for the APDs 56 | if parameters["use_chirality"]: 57 | if parameters["use_explicit_H"] or parameters["ignore_H"]: 58 | dim_f_add = [ 59 | parameters["max_n_nodes"], 60 | n_atom_types, 61 | n_formal_charge, 62 | n_chirality, 63 | n_edge_features, 64 | ] 65 | else: 66 | dim_f_add = [ 67 | parameters["max_n_nodes"], 68 | n_atom_types, 69 | n_formal_charge, 70 | n_num_h, 71 | n_chirality, 72 | n_edge_features, 73 | ] 74 | else: 75 | if parameters["use_explicit_H"] or parameters["ignore_H"]: 76 | dim_f_add = [ 77 | parameters["max_n_nodes"], 78 | n_atom_types, 79 | n_formal_charge, 80 | n_edge_features, 81 | ] 82 | else: 83 | dim_f_add = [ 84 | parameters["max_n_nodes"], 85 | n_atom_types, 86 | n_formal_charge, 87 | n_num_h, 88 | n_edge_features, 89 | ] 90 | 91 | dim_f_conn = [parameters["max_n_nodes"], n_edge_features] 92 | 93 | dim_f_term = 1 94 | 95 | return dim_nodes, dim_edges, dim_f_add, dim_f_conn, dim_f_term 96 | 97 | 98 | def load_params(input_csv_path : str) -> dict: 99 | """ 100 | Loads job parameters/hyperparameters from CSV (in `input_csv_path`). 101 | """ 102 | params_to_override = {} 103 | with open(input_csv_path, "r") as csv_file: 104 | 105 | params_reader = csv.reader(csv_file, delimiter=";") 106 | 107 | for key, value in params_reader: 108 | try: 109 | params_to_override[key] = eval(value) 110 | except NameError: # `value` is a `str` 111 | params_to_override[key] = value 112 | except SyntaxError: # to avoid "unexpected `EOF`" 113 | params_to_override[key] = value 114 | 115 | return params_to_override 116 | 117 | 118 | def override_params(all_params : dict) -> dict: 119 | """ 120 | If there exists an `input.csv` in the job directory, loads those arguments 121 | and overrides their default values from `features.py`. 
122 | """ 123 | input_csv_path = all_params["job_dir"] + "input.csv" 124 | 125 | 126 | # check if there exists and `input.csv` in working directory 127 | if os.path.exists(input_csv_path): 128 | # override default values for parameters in `input.csv` 129 | params_to_override_dict = load_params(input_csv_path) 130 | for key, value in params_to_override_dict.items(): 131 | all_params[key] = value 132 | 133 | return all_params 134 | 135 | 136 | def collect_global_constants(parameters : dict, job_dir : str) -> namedtuple: 137 | """ 138 | Collects constants defined in `features.py` with those defined by the 139 | ArgParser (`args.py`), and returns the bundle as a `namedtuple`. 140 | 141 | Args: 142 | ---- 143 | parameters (dict) : Dictionary of parameters defined in `features.py`. 144 | job_dir (str) : Current job directory, defined on the command line. 145 | 146 | Returns: 147 | ------- 148 | constants (namedtuple) : Collected constants. 149 | """ 150 | #first override any arguments from `input.csv`: 151 | parameters["job_dir"] = job_dir 152 | print("job directory is now", parameters["job_dir"]) 153 | parameters = override_params(all_params=parameters) 154 | print("compute_train_csv is now", parameters["compute_train_csv"]) 155 | 156 | # then calculate any global constants below: 157 | if parameters["use_explicit_H"] and parameters["ignore_H"]: 158 | raise ValueError("Cannot use explicit Hs and ignore Hs at " 159 | "the same time. Please fix flags.") 160 | 161 | # define edge feature (rdkit `GetBondType()` result -> `int`) constants 162 | bondtype_to_int = {BondType.SINGLE: 0, BondType.DOUBLE: 1, BondType.TRIPLE: 2} 163 | 164 | if parameters["use_aromatic_bonds"]: 165 | bondtype_to_int[BondType.AROMATIC] = 3 166 | 167 | int_to_bondtype = dict(map(reversed, bondtype_to_int.items())) 168 | 169 | n_edge_features = len(bondtype_to_int) 170 | 171 | # define node feature constants 172 | n_atom_types, n_formal_charge, n_imp_H, n_chirality = get_feature_dimensions(parameters) 173 | 174 | n_node_features = n_atom_types + n_formal_charge + n_imp_H + n_chirality 175 | 176 | # define matrix dimensions 177 | (dim_nodes, dim_edges, dim_f_add, 178 | dim_f_conn, dim_f_term) = get_tensor_dimensions(n_atom_types, 179 | n_formal_charge, 180 | n_imp_H, 181 | n_chirality, 182 | n_node_features, 183 | n_edge_features, 184 | parameters) 185 | 186 | len_f_add = np.prod(dim_f_add[:]) 187 | len_f_add_per_node = np.prod(dim_f_add[1:]) 188 | len_f_conn = np.prod(dim_f_conn[:]) 189 | len_f_conn_per_node = np.prod(dim_f_conn[1:]) 190 | 191 | # create a dictionary of global constants, and add `job_dir` to it; this 192 | # will ultimately be converted to a `namedtuple` 193 | constants_dict = { 194 | "big_negative" : -1e6, 195 | "big_positive" : 1e6, 196 | "bondtype_to_int" : bondtype_to_int, 197 | "int_to_bondtype" : int_to_bondtype, 198 | "n_edge_features" : n_edge_features, 199 | "n_atom_types" : n_atom_types, 200 | "n_formal_charge" : n_formal_charge, 201 | "n_imp_H" : n_imp_H, 202 | "n_chirality" : n_chirality, 203 | "n_node_features" : n_node_features, 204 | "dim_nodes" : dim_nodes, 205 | "dim_edges" : dim_edges, 206 | "dim_f_add" : dim_f_add, 207 | "dim_f_conn" : dim_f_conn, 208 | "dim_f_term" : dim_f_term, 209 | "dim_apd" : [np.prod(dim_f_add) + np.prod(dim_f_conn) + 1], 210 | "len_f_add" : len_f_add, 211 | "len_f_add_per_node" : len_f_add_per_node, 212 | "len_f_conn" : len_f_conn, 213 | "len_f_conn_per_node": len_f_conn_per_node, 214 | } 215 | 216 | # join with `features.args_dict` 217 | 
constants_dict.update(parameters) 218 | 219 | # define path to dataset splits 220 | constants_dict["test_set"] = parameters["dataset_dir"] + "test.smi" 221 | constants_dict["training_set"] = parameters["dataset_dir"] + "train.smi" 222 | constants_dict["validation_set"] = parameters["dataset_dir"] + "valid.smi" 223 | 224 | # check (if a job is not a preprocessing job) that parameters match those for 225 | # the original preprocessing job 226 | if constants_dict["job_type"] != "preprocess": 227 | print( 228 | "* Running job using HDF datasets located at " 229 | + parameters["dataset_dir"], 230 | flush=True, 231 | ) 232 | print( 233 | "* Checking that the relevant parameters match " 234 | "those used in preprocessing the dataset.", 235 | flush=True, 236 | ) 237 | 238 | try: 239 | #load preprocessing parameters for comparison (if they exist already) 240 | csv_file = parameters["dataset_dir"] + "preprocessing_params.csv" 241 | params_to_check = load_params(input_csv_path=csv_file) 242 | 243 | for key, value in params_to_check.items(): 244 | if key in constants_dict.keys() and value != constants_dict[key]: 245 | raise ValueError( 246 | f"Check that training job parameters match those used in " 247 | f"preprocessing. {key} does not match." 248 | ) 249 | 250 | # if above error never raised, then all relevant parameters match! :) 251 | print("-- Job parameters match preprocessing parameters.", flush=True) 252 | except: 253 | print("-- Preprocessing pa rameters file does not exist for comparison.", flush=True) 254 | 255 | # load QSAR models (sklearn activity model) 256 | if constants_dict["job_type"] == "fine-tune": 257 | # print("-- Loading pre-trained scikit-learn activity model.", flush=True) 258 | # for qsar_model_name, qsar_model_path in constants_dict["qsar_models"].items(): 259 | # with open(qsar_model_path, 'rb') as file: 260 | # model_dict = pickle.load(file) 261 | # activity_model = model_dict["classifier_sv"] 262 | # constants_dict["qsar_models"][qsar_model_name] = activity_model 263 | print("-- Loading pre-trained gbm activity model.", flush=True) 264 | scoring_model_activity = pickle.load(open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/Protac_Scoring_Model_1024_100nM.pkl', 'rb')) 265 | scoring_model_structure = pickle.load(open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/Protac_Scoring_Model_Structure.pkl', 'rb')) 266 | with open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/features_1024_100nM.pkl','rb') as fp: 267 | features_activity = pickle.load(fp) 268 | with open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/features_structure.pkl','rb') as fp: 269 | features_structure = pickle.load(fp) 270 | constants_dict["qsar_models"]["protac_qsar_model"] = scoring_model_activity 271 | constants_dict["protac_structure_model"] = scoring_model_structure 272 | constants_dict["activity_model_features"] = features_activity 273 | constants_dict["structure_model_features"] = features_structure 274 | 275 | # convert `CONSTANTS` dictionary into a namedtuple (immutable + cleaner) 276 | Constants = namedtuple("CONSTANTS", sorted(constants_dict)) 277 | constants = Constants(**constants_dict) 278 | 279 | return constants 280 | 281 | # collect the constants using the functions defined above 282 | #change job_dir back to args.job_dir - ask Rocio how to get from arg parser 283 | constants = collect_global_constants(parameters=defaults.parameters, job_dir=args.job_dir) -------------------------------------------------------------------------------- 
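A minimal sketch of the `input.csv` consumed by `load_params()` and `override_params()` above, placed as `input.csv` in the job directory (the specific keys below are illustrative only). Each row is `key;value`; values are passed through `eval()`, so numbers and lists come back typed, while plain strings such as paths hit the `NameError`/`SyntaxError` branches and are kept as strings:

    job_type;fine-tune
    dataset_dir;data/pre-training/MOSES/
    batch_size;64
    score_components;["QED", "drd2_activity"]
    sigma;20

With such a file present, `collect_global_constants()` overrides the matching entries of `defaults.parameters` before computing the derived tensor dimensions and building the final `CONSTANTS` namedtuple.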
/GraphINVENT_Protac/graphinvent/parameters/defaults.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines default model parameters, hyperparameters, and settings. 3 | Recommended not to modify the default settings here, but rather create an input 4 | file with the modified parameters in a new job directory (see README). **Used 5 | as an alternative to using argparser, as there are many variables.** 6 | """ 7 | # load general packages and functions 8 | import sys 9 | 10 | # load GraphINVENT-specific functions 11 | sys.path.insert(1, "./parameters/") # search "parameters/" directory 12 | import parameters.args as args 13 | import parameters.load as load 14 | 15 | 16 | # default parameters defined below 17 | """ 18 | General settings for the generative model: 19 | atom_types (list) : Contains atom types (str) to encode in node features. 20 | formal_charge (list) : Contains formal charges (int) to encode in node 21 | features. 22 | imp_H (list) : Contains number of implicit hydrogens (int) to encode 23 | in node features. 24 | chirality (list) : Contains chiral states (str) to encode in node features. 25 | device (str) : Specifies type of architecture to run on. Options: 26 | "cuda" or "cpu"). 27 | generation_epoch (int) : Epoch to sample during a 'generation' job. 28 | n_samples (int) : Number of molecules to generate during each sampling 29 | epoch. Note: if `n_samples` > 100000 molecules, these 30 | will be generated in batches of 100000. 31 | n_workers (int) : Number of subprocesses to use during data loading. 32 | restart (bool) : If specified, will restart training from previous saved 33 | state. Can only be used for preprocessing or training 34 | jobs. 35 | max_n_nodes (int) : Maximum number of allowed nodes in graph. Must be 36 | greater than or equal to the number of nodes in 37 | largest graph in training set. 38 | job_type (str) : Type of job to run; options: 'preprocess', 'train', 39 | 'generate', 'test', or 'fine-tune'. 40 | sample_every (int) : Specifies when to sample the model (i.e. epochs 41 | between sampling). 42 | dataset_dir (str) : Full path to directory containing testing ("test.smi"), 43 | training ("train.smi"), and validation ("valid.smi") 44 | sets. 45 | use_aromatic_bonds (bool) : If specified, aromatic bond types will be used. 46 | use_canon (bool) : If specified, uses canonical RDKit ordering in graph 47 | representations. 48 | use_chirality (bool) : If specified, includes chirality in the atomic 49 | representations. 50 | use_explicit_H (bool) : If specified, uses explicit Hs in molecular 51 | representations (not recommended for most applications). 52 | ignore_H (bool) : If specified, ignores H's completely in graph 53 | representations (treats them neither as explicit or 54 | implicit). When generating graphs, H's are added to 55 | graphs after generation is terminated. 56 | use_tensorboard (bool) : If specified, enables the use of tensorboard during 57 | training. 58 | tensorboard_dir (str) : Path to directory in which to write tensorboard 59 | things. 60 | batch_size (int) : Number of graphs in a mini-batch. When preprocessing 61 | graphs, this is the size of the preprocessing groups 62 | (e.g. how many subgraphs preprocessed at once). 63 | epochs (int) : Number of training epochs. 64 | init_lr (float) : Initial learning rate. 65 | max_rel_lr (float) : Maximum allowed learning rate relative to the initial 66 | (used for learning rate ramp-up). 
67 | model (str) : MPNN model to use ('MNN', 'S2V', 'AttS2V', 'GGNN', 68 | 'AttGGNN', or 'EMN'). 69 | decoding_route (str) : Breadth-first search ("bfs") or depth-first search 70 | ("dfs"). 71 | score_components (list) : A list of all the components to use in the RL scoring 72 | function. Can include "target_size={int}", "QED", 73 | "{name}_activity". 74 | score_thresholds (list) : Acceptable thresholds for the above score components. 75 | score_type (str) : If there are multiple components used in the scoring 76 | function, determines if the final score should be 77 | "continuous" (in which case, the above thresholds 78 | are ignored), or "binary" (in which case a generated 79 | molecule will receive a score of 1 iff all its score 80 | components are greater than the specified thresholds). 81 | qsar_models (dict) : A dictionary containing the path to each activity 82 | model specified in `score_components`. Note that 83 | the key in this dict must correspond to the name 84 | of the score component. 85 | sigma (float) : Can take any value. Tunes the contribution of the 86 | score in the augmented log-likelihood. See Atance 87 | et al (2021) https://doi.org/10.33774/chemrxiv-2021-9w3tc 88 | for suitable values. 89 | alpha (float) : Can take values between [0.0, 1.0]. Tunes the contribution 90 | from the best agent so far (BASF) in the loss. 91 | """ 92 | # general job parameters 93 | parameters = { 94 | "atom_types" : ["C", "N", "O", "F", "S", "Cl", "Br"], 95 | "formal_charge" : [-1, 0, 1], 96 | "imp_H" : [0, 1, 2, 3], 97 | "chirality" : ["None", "R", "S"], 98 | "device" : "cuda", 99 | "generation_epoch" : 30, 100 | "n_samples" : 2000, 101 | "n_workers" : 2, 102 | "restart" : True, 103 | "max_n_nodes" : 27, 104 | "job_type" : "preprocess", # "preprocess" -> "train" -> "fine-tune" -> "generate" 105 | "sample_every" : 10, 106 | "dataset_dir" : "data/pre-training/MOSES/", 107 | "use_aromatic_bonds" : False, 108 | "use_canon" : True, 109 | "use_chirality" : False, 110 | "use_explicit_H" : False, 111 | "ignore_H" : True, 112 | "compute_train_csv" : True, 113 | "tensorboard_dir" : "tensorboard/", 114 | "batch_size" : 1000, 115 | "block_size" : 100000, 116 | "epochs" : 100, 117 | "init_lr" : 1e-4, 118 | "max_rel_lr" : 1, 119 | "min_rel_lr" : 0.0001, 120 | "decoding_route" : "bfs", 121 | "activity_model_dir" : "data/fine-tuning/", 122 | "score_components" : ["QED", "drd2_activity", "target_size=13"], 123 | "score_thresholds" : [0.5, 0.5, 0.0], # 0.0 essentially means no threshold 124 | "score_type" : "binary", 125 | "qsar_models" : {"drd2_activity": "data/fine-tuning/qsar_model.pickle"}, 126 | "pretrained_model_dir": "output/", 127 | "sigma" : 20, 128 | "alpha" : 0.5, 129 | } 130 | 131 | # make sure job dir ends in "/" 132 | if args.job_dir[-1] != "/": 133 | print("* Adding '/' to end of `job_dir`.") 134 | args.job_dir += "/" 135 | 136 | # get the model before loading model-specific hyperparameters 137 | try: 138 | input_csv_path = args.job_dir + "input.csv" 139 | model = load.which_model(input_csv_path=input_csv_path) 140 | except: 141 | model = "GGNN" # default model 142 | parameters["model"] = model 143 | 144 | 145 | # model-specific hyperparameters (implementation-specific) 146 | if parameters["model"] == "MNN": 147 | """ 148 | MNN hyperparameters: 149 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`. 150 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`. 
151 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier MLP 152 | in `APDReadout`. 153 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`. 154 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`. 155 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier MLP 156 | in `APDReadout`. 157 | message_passes (int) : Number of message passing steps. 158 | message_size (int) : Size of message passed ('enn' MLP output size). 159 | """ 160 | hyperparameters = { 161 | "mlp1_depth" : 4, 162 | "mlp1_dropout_p" : 0.0, 163 | "mlp1_hidden_dim" : 500, 164 | "mlp2_depth" : 4, 165 | "mlp2_dropout_p" : 0.0, 166 | "mlp2_hidden_dim" : 500, 167 | "hidden_node_features": 100, 168 | "message_passes" : 3, 169 | "message_size" : 100, 170 | } 171 | elif parameters["model"] == "S2V": 172 | """ 173 | S2V hyperparameters: 174 | enn_depth (int) : Num layers in 'enn' MLP. 175 | enn_dropout_p (float) : Dropout probability in 'enn' MLP. 176 | enn_hidden_dim (int) : Number of weights (layer width) in 'enn' MLP. 177 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`. 178 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`. 179 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier 180 | MLP in `APDReadout`. 181 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`. 182 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`. 183 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier 184 | MLP in `APDReadout`. 185 | message_passes (int) : Number of message passing steps. 186 | message_size (int) : Size of message passed (input size to `GRU`). 187 | s2v_lstm_computations (int) : Number of LSTM computations (loop) in S2V readout. 188 | s2v_memory_size (int) : Number of input features and hidden state size 189 | in LSTM cell in S2V readout. 190 | """ 191 | hyperparameters = { 192 | "enn_depth" : 4, 193 | "enn_dropout_p" : 0.0, 194 | "enn_hidden_dim" : 250, 195 | "mlp1_depth" : 4, 196 | "mlp1_dropout_p" : 0.0, 197 | "mlp1_hidden_dim" : 500, 198 | "mlp2_depth" : 4, 199 | "mlp2_dropout_p" : 0.0, 200 | "mlp2_hidden_dim" : 500, 201 | "hidden_node_features" : 100, 202 | "message_passes" : 3, 203 | "message_size" : 100, 204 | "s2v_lstm_computations": 3, 205 | "s2v_memory_size" : 100, 206 | } 207 | elif parameters["model"] == "AttS2V": 208 | """ 209 | AttS2V hyperparameters: 210 | att_depth (int) : Num layers in 'att_enn' MLP. 211 | att_dropout_p (float) : Dropout probability in 'att_enn' MLP. 212 | att_hidden_dim (int) : Number of weights (layer width) in 'att_enn' 213 | MLP. 214 | enn_depth (int) : Num layers in 'enn' MLP. 215 | enn_dropout_p (float) : Dropout probability in 'enn' MLP. 216 | enn_hidden_dim (int) : Number of weights (layer width) in 'enn' MLP. 217 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`. 218 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`. 219 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier 220 | MLP in `APDReadout`. 221 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`. 222 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`. 223 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier 224 | MLP in `APDReadout`. 225 | message_passes (int) : Number of message passing steps. 
226 | message_size (int) : Size of message passed (output size of 'att_enn' 227 | MLP, input size to `GRU`). 228 | s2v_lstm_computations (int) : Number of LSTM computations (loop) in S2V readout. 229 | s2v_memory_size (int) : Number of input features and hidden state size 230 | in LSTM cell in S2V readout. 231 | """ 232 | hyperparameters = { 233 | "att_depth" : 4, 234 | "att_dropout_p" : 0.0, 235 | "att_hidden_dim" : 250, 236 | "enn_depth" : 4, 237 | "enn_dropout_p" : 0.0, 238 | "enn_hidden_dim" : 250, 239 | "mlp1_depth" : 4, 240 | "mlp1_dropout_p" : 0.0, 241 | "mlp1_hidden_dim" : 500, 242 | "mlp2_depth" : 4, 243 | "mlp2_dropout_p" : 0.0, 244 | "mlp2_hidden_dim" : 500, 245 | "hidden_node_features" : 100, 246 | "message_passes" : 3, 247 | "message_size" : 100, 248 | "s2v_lstm_computations": 3, 249 | "s2v_memory_size" : 100, 250 | } 251 | elif parameters["model"] == "GGNN": 252 | """ 253 | GGNN hyperparameters: 254 | enn_depth (int) : Num layers in 'enn' MLP. 255 | enn_dropout_p (float) : Dropout probability in 'enn' MLP. 256 | enn_hidden_dim (int) : Number of weights (layer width) in 'enn' MLP. 257 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`. 258 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`. 259 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier 260 | MLP in `APDReadout`. 261 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`. 262 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`. 263 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier 264 | MLP in `APDReadout`. 265 | gather_att_depth (int) : Num layers in 'gather_att' MLP in `GraphGather`. 266 | gather_att_dropout_p (float) : Dropout probability in 'gather_att' MLP in 267 | `GraphGather`. 268 | gather_att_hidden_dim (int) : Number of weights (layer width) in 'gather_att' 269 | MLP in `GraphGather`. 270 | gather_emb_depth (int) : Num layers in 'gather_emb' MLP in `GraphGather`. 271 | gather_emb_dropout_p (float) : Dropout probability in 'gather_emb' MLP in 272 | `GraphGather`. 273 | gather_emb_hidden_dim (int) : Number of weights (layer width) in 'gather_emb' 274 | MLP in `GraphGather`. 275 | gather_width (int) : Output size of `GraphGather` block. 276 | message_passes (int) : Number of message passing steps. 277 | message_size (int) : Size of message passed (output size of all 278 | MLPs in message aggregation step, input size 279 | to `GRU`). 280 | """ 281 | hyperparameters = { 282 | "enn_depth" : 4, 283 | "enn_dropout_p" : 0.0, 284 | "enn_hidden_dim" : 250, 285 | "mlp1_depth" : 4, 286 | "mlp1_dropout_p" : 0.0, 287 | "mlp1_hidden_dim" : 500, 288 | "mlp2_depth" : 4, 289 | "mlp2_dropout_p" : 0.0, 290 | "mlp2_hidden_dim" : 500, 291 | "gather_att_depth" : 4, 292 | "gather_att_dropout_p" : 0.0, 293 | "gather_att_hidden_dim": 250, 294 | "gather_emb_depth" : 4, 295 | "gather_emb_dropout_p" : 0.0, 296 | "gather_emb_hidden_dim": 250, 297 | "gather_width" : 100, 298 | "hidden_node_features" : 100, 299 | "message_passes" : 3, 300 | "message_size" : 100, 301 | } 302 | elif parameters["model"] == "AttGGNN": 303 | """ 304 | AttGGNN hyperparameters: 305 | att_depth (int) : Num layers in 'att_nns' MLP (message aggregation 306 | step). 307 | att_dropout_p (float) : Dropout probability in 'att_nns' MLP (message 308 | aggregation step). 309 | att_hidden_dim (int) : Number of weights (layer width) in 'att_nns' 310 | MLP (message aggregation step). 
311 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`. 312 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`. 313 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier 314 | MLP in `APDReadout`. 315 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`. 316 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`. 317 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier 318 | MLP in `APDReadout`. 319 | gather_att_depth (int) : Num layers in 'gather_att' MLP in `GraphGather`. 320 | gather_att_dropout_p (float) : Dropout probability in 'gather_att' MLP in 321 | `GraphGather`. 322 | gather_att_hidden_dim (int) : Number of weights (layer width) in 'gather_att' 323 | MLP in `GraphGather`. 324 | gather_emb_depth (int) : Num layers in 'gather_emb' MLP in `GraphGather`. 325 | gather_emb_dropout_p (float) : Dropout probability in 'gather_emb' MLP in 326 | `GraphGather`. 327 | gather_emb_hidden_dim (int) : Number of weights (layer width) in 'gather_emb' 328 | MLP in `GraphGather`. 329 | gather_width (int) : Output size of `GraphGather` block. 330 | message_passes (int) : Number of message passing steps. 331 | message_size (int) : Size of message passed (output size of all 332 | MLPs in message aggregation step, input size 333 | to `GRU`). 334 | msg_depth (int) : Num layers in 'msg_nns' MLP (message aggregation 335 | step). 336 | msg_dropout_p (float) : Dropout probability in 'msg_nns' MLP (message 337 | aggregation step). 338 | msg_hidden_dim (int) : Number of weights (layer width) in 'msg_nns' 339 | MLP (message aggregation step). 340 | """ 341 | hyperparameters = { 342 | "att_depth" : 4, 343 | "att_dropout_p" : 0.0, 344 | "att_hidden_dim" : 250, 345 | "mlp1_depth" : 4, 346 | "mlp1_dropout_p" : 0.0, 347 | "mlp1_hidden_dim" : 500, 348 | "mlp2_depth" : 4, 349 | "mlp2_dropout_p" : 0.0, 350 | "mlp2_hidden_dim" : 500, 351 | "gather_att_depth" : 4, 352 | "gather_att_dropout_p" : 0.0, 353 | "gather_att_hidden_dim": 250, 354 | "gather_emb_depth" : 4, 355 | "gather_emb_dropout_p" : 0.0, 356 | "gather_emb_hidden_dim": 250, 357 | "gather_width" : 100, 358 | "hidden_node_features" : 100, 359 | "message_passes" : 3, 360 | "message_size" : 100, 361 | "msg_depth" : 4, 362 | "msg_dropout_p" : 0.0, 363 | "msg_hidden_dim" : 250, 364 | } 365 | elif parameters["model"] == "EMN": 366 | """ 367 | EMN hyperparameters: 368 | att_depth (int) : Num layers in 'att_msg_nn' MLP (edge propagation 369 | step). 370 | att_dropout_p (float) : Dropout probability in 'att_msg_nn' MLP (edge 371 | propagation step). 372 | att_hidden_dim (int) : Number of weights (layer width) in 'att_msg_nn' 373 | MLP (edge propagation step). 374 | edge_emb_depth (int) : Num layers in 'embedding_nn' MLP (edge processing 375 | step). 376 | edge_emb_dropout_p (float) : Dropout probability in 'embedding_nn' MLP (edge 377 | processing step). 378 | edge_emb_hidden_dim (int) : Number of weights (layer width) in 'embedding_nn' 379 | MLP (edge processing step). 380 | edge_emb_size (int) : Output size of all MLPs in edge propagation 381 | and processing steps (input size to `GraphGather`). 382 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`. 383 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`. 384 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier 385 | MLP in `APDReadout`. 386 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`. 
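In the AttGGNN block above, 'msg_nns' proposes a candidate message for every node while 'att_nns' produces attention energies that are softmax-normalised over each neighbourhood before the weighted sum is passed to the `GRU`. A minimal sketch of that aggregation step for a single dense-adjacency graph (hypothetical helper; the actual logic presumably lives in `gnn/aggregation_mpnn.py`):

```python
import torch

def attention_aggregate(node_feats: torch.Tensor,  # (N, hidden_node_features)
                        adjacency: torch.Tensor,   # (N, N), 1 where a bond exists
                        msg_nn: torch.nn.Module,   # 'msg_nns': hidden -> message_size
                        att_nn: torch.nn.Module    # 'att_nns': hidden -> message_size
                        ) -> torch.Tensor:
    messages = msg_nn(node_feats)                             # (N, message_size)
    energies = att_nn(node_feats)                             # (N, message_size)
    # mask out non-neighbours, then softmax over the neighbour axis
    # (assumes every node has at least one neighbour, as in a molecular graph)
    mask = adjacency.unsqueeze(-1).bool()                     # (N, N, 1)
    energies = energies.unsqueeze(0).expand(adjacency.size(0), -1, -1)
    weights = torch.softmax(energies.masked_fill(~mask, float("-inf")), dim=1)
    return (weights * messages.unsqueeze(0)).sum(dim=1)       # (N, message_size)

# toy usage on a 5-membered ring with the default sizes above
adj = torch.eye(5).roll(1, dims=0) + torch.eye(5).roll(-1, dims=0)
out = attention_aggregate(torch.randn(5, 100), adj,
                          torch.nn.Linear(100, 100), torch.nn.Linear(100, 100))
```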
387 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`. 388 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier 389 | MLP in `APDReadout`. 390 | gather_att_depth (int) : Num layers in 'gather_att' MLP in `GraphGather`. 391 | gather_att_dropout_p (float) : Dropout probability in 'gather_att' MLP in 392 | `GraphGather`. 393 | gather_att_hidden_dim (int) : Number of weights (layer width) in 'gather_att' 394 | MLP in `GraphGather`. 395 | gather_emb_depth (int) : Num layers in 'gather_emb' MLP in `GraphGather`. 396 | gather_emb_dropout_p (float) : Dropout probability in 'gather_emb' MLP in 397 | `GraphGather`. 398 | gather_emb_hidden_dim (int) : Number of weights (layer width) in 'gather_emb' 399 | MLP in `GraphGather`. 400 | gather_width (int) : Output size of `GraphGather` block. 401 | message_passes (int) : Number of message passing steps. 402 | msg_depth (int) : Num layers in 'emb_msg_nn' MLP (edge propagation 403 | step). 404 | msg_dropout_p (float) : Dropout probability in 'emb_msg_n' MLP (edge 405 | propagation step). 406 | msg_hidden_dim (int) : Number of weights (layer width) in 'emb_msg_nn' 407 | MLP (edge propagation step). 408 | """ 409 | hyperparameters = { 410 | "att_depth" : 4, 411 | "att_dropout_p" : 0.0, 412 | "att_hidden_dim" : 250, 413 | "edge_emb_depth" : 4, 414 | "edge_emb_dropout_p" : 0.0, 415 | "edge_emb_hidden_dim" : 250, 416 | "edge_emb_size" : 100, 417 | "mlp1_depth" : 4, 418 | "mlp1_dropout_p" : 0.0, 419 | "mlp1_hidden_dim" : 500, 420 | "mlp2_depth" : 4, 421 | "mlp2_dropout_p" : 0.0, 422 | "mlp2_hidden_dim" : 500, 423 | "gather_att_depth" : 4, 424 | "gather_att_dropout_p" : 0.0, 425 | "gather_att_hidden_dim": 250, 426 | "gather_emb_depth" : 4, 427 | "gather_emb_dropout_p" : 0.0, 428 | "gather_emb_hidden_dim": 250, 429 | "gather_width" : 100, 430 | "message_passes" : 3, 431 | "msg_depth" : 4, 432 | "msg_dropout_p" : 0.0, 433 | "msg_hidden_dim" : 250, 434 | } 435 | 436 | # make sure dataset dir ends in "/" 437 | if parameters["dataset_dir"][-1] != "/": 438 | print("* Adding '/' to end of `dataset_dir`.") 439 | parameters["dataset_dir"] += "/" 440 | 441 | # join dictionaries 442 | parameters.update(hyperparameters) 443 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/parameters/load.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for loading molecules from SMILES, as well as loading the model type. 3 | """ 4 | # load general packages and functions 5 | import csv 6 | import rdkit 7 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier 8 | 9 | 10 | def molecules(path : str) -> rdkit.Chem.rdmolfiles.SmilesMolSupplier: 11 | """ 12 | Reads a SMILES file (full path/filename specified by `path`) and returns 13 | `rdkit.Mol` objects. 14 | """ 15 | # check first line of SMILES file to see if contains header 16 | with open(path) as smi_file: 17 | first_line = smi_file.readline() 18 | has_header = bool("SMILES" in first_line) 19 | smi_file.close() 20 | 21 | # read file 22 | molecule_set = SmilesMolSupplier(path, 23 | sanitize=True, 24 | nameColumn=-1, 25 | titleLine=has_header) 26 | return molecule_set 27 | 28 | def which_model(input_csv_path : str) -> str: 29 | """ 30 | Gets the type of model to use by reading it from CSV (in "input.csv"). 31 | 32 | Args: 33 | ---- 34 | input_csv_path (str) : The full path/filename to "input.csv" file 35 | containing parameters to overwrite from defaults. 
36 | 37 | Returns: 38 | ------- 39 | value (str) : Name of model to use. 40 | """ 41 | with open(input_csv_path, "r") as csv_file: 42 | 43 | params_reader = csv.reader(csv_file, delimiter=";") 44 | 45 | for key, value in params_reader: 46 | if key == "model": 47 | return value # string describing model e.g. "GGNN" 48 | 49 | raise ValueError("Model type not specified.") 50 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/graphinvent/training_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create train.csv containing summary stats on training set 3 | 4 | To use script, run: 5 | python graphinvent/training_stats.py 6 | """ 7 | import random 8 | import csv 9 | 10 | random.seed(10) 11 | 12 | # def write_input_csv(params_dict : dict, filename : str="params.csv") -> None: 13 | # """ 14 | # Writes job parameters/hyperparameters in `params_dict` to CSV using the specified 15 | # `filename`. 16 | # """ 17 | # print("in write input csv") 18 | # dict_path = params_dict["job_dir"] + filename 19 | 20 | # with open(dict_path, "w") as csv_file: 21 | 22 | # writer = csv.writer(csv_file, delimiter=";") 23 | # for key, value in params_dict.items(): 24 | # writer.writerow([key, value]) 25 | 26 | # params = {"compute_train_csv": True, "job_dir": "data/pre-training/MOSES/"} 27 | # write_input_csv(params_dict=params, filename="input.csv") 28 | 29 | from rdkit import Chem 30 | 31 | from DataProcesser import DataProcesser 32 | import util 33 | import torch 34 | from parameters.constants import constants 35 | 36 | from torch.utils.tensorboard import SummaryWriter 37 | 38 | #take a random subset of 10000 smiles 39 | orig_smi_file_path = str('data/pre-training/MOSES/train.smi') 40 | list_of_smiles = [] 41 | with open(orig_smi_file_path, 'r') as f: 42 | for line in f.readlines(): 43 | words = line.split() 44 | list_of_smiles.append(words[0]) 45 | 46 | # The first entry in list_of_smiles is 'SMILES', not a SMILES string 47 | list_of_smiles.pop(0) 48 | 49 | train_smiles = list_of_smiles[:92] 50 | test_smiles = list_of_smiles[92:109] 51 | valid_smiles = list_of_smiles[109:] 52 | 53 | with open('data/pre-training/protac_db_subset_50/train.smi', 'w+') as f: 54 | for smi in train_smiles: 55 | f.write("{}\n".format(smi)) 56 | with open('data/pre-training/protac_db_subset_50/test.smi', 'w+') as f: 57 | for smi in test_smiles: 58 | f.write("{}\n".format(smi)) 59 | with open('data/pre-training/protac_db_subset_50/valid.smi', 'w+') as f: 60 | for smi in valid_smiles: 61 | f.write("{}\n".format(smi)) 62 | 63 | 64 | #convert smiles to molecular graphs 65 | mols = [Chem.MolFromSmiles(smi) for smi in list_of_smiles] 66 | processor = DataProcesser(path = 'data/pre-training/MOSES/train.smi', is_training_set=True, molset = mols) 67 | graphs = [processor.get_graph(mol) for mol in mols] 68 | 69 | processor.get_ts_properties(molecular_graphs=graphs, group_size=10000) 70 | print(processor.ts_properties) 71 | 72 | #writecsv (from util.py) 73 | # Uncomment to write train.csv 74 | writer = SummaryWriter() 75 | util.properties_to_csv(processor.ts_properties, 'data/pre-training/MOSES/train.csv', 'Training set', writer) 76 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/analyze_final_epoch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import rdkit 4 | from rdkit import Chem 5 | from rdkit.Chem.Draw 
import MolsToGridImage 6 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier 7 | 8 | smi_file = "/home/gridsan/dnori/GraphINVENT/output_MOSES_subset/example-job-name/job_0/generation/aggregated_generation/epoch_100.smi" 9 | train_smi = "/home/gridsan/dnori/GraphINVENT/data/pre-training/MOSES_subset/train.smi" 10 | #smi_file = "/home/gridsan/dnori/GraphINVENT/output_protac_db_subset_60/example-job-name/job_0/generation/aggregated_generation/epoch_GEN80.smi" 11 | 12 | # # load molecules from file 13 | # mols = SmilesMolSupplier(smi_file, sanitize=True, nameColumn=-1,titleLine=True) 14 | 15 | # n_samples = 40 16 | # mols_list = [mol for mol in mols] 17 | # mols_sampled = random.sample(mols_list, n_samples) # sample `n_samples` random molecules to visualize 18 | 19 | # mols_per_row = int(math.sqrt(n_samples)) # make a square grid 20 | 21 | # png_filename=smi_file[:-3] + "png" # name of PNG file to create 22 | # labels=list(range(n_samples)) # label structures with a number 23 | 24 | # # draw the molecules (creates a PIL image) 25 | # img = MolsToGridImage(mols=mols_sampled, 26 | # molsPerRow=mols_per_row, 27 | # legends=[str(i) for i in labels]) 28 | 29 | # img.save(png_filename) 30 | 31 | #calculate regeneration percentage 32 | 33 | with open(smi_file, 'r') as f1: 34 | gen_smi = f1.readlines() 35 | 36 | with open (train_smi, 'r') as f2: 37 | tr_smi = f2.readlines() 38 | 39 | #canon all smi in tr_smi: 40 | for i in range(len(tr_smi)): 41 | mol = Chem.MolFromSmiles(tr_smi[i]) 42 | canon_smi = Chem.MolToSmiles(mol) 43 | if canon_smi!=tr_smi[i]: 44 | tr_smi[i] = canon_smi 45 | 46 | total = len(gen_smi) 47 | num = 0 48 | for s in gen_smi: 49 | mol = Chem.MolFromSmiles(s) 50 | canon_s = Chem.MolToSmiles(mol) 51 | if canon_s not in tr_smi: 52 | num+=1 53 | 54 | print(f"percentage of new molecules: {num/total}") 55 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/atom_types.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gets the atom types present in a set of molecules. 3 | 4 | To use script, run: 5 | python atom_types.py --smi path/to/file.smi 6 | """ 7 | import argparse 8 | import rdkit 9 | from utils import load_molecules 10 | 11 | 12 | # define the argument parser 13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 14 | add_help=False) 15 | 16 | # define two potential arguments to use when drawing SMILES from a file 17 | parser.add_argument("--smi", 18 | type=str, 19 | default="data/gdb13_1K/train.smi", 20 | help="SMILES file containing molecules to analyse.") 21 | args = parser.parse_args() 22 | 23 | 24 | def get_atom_types(smi_file : str) -> list: 25 | """ 26 | Determines the atom types present in an input SMILES file. 27 | 28 | Args: 29 | ---- 30 | smi_file (str) : Full path/filename to SMILES file.
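Because membership tests against a plain list are O(n) per lookup, the regeneration/novelty check above can equivalently be written around a set of canonical training SMILES, which scales much better for large generated sets (hypothetical rewrite of the same computation):

```python
from rdkit import Chem

def canonical_smiles(path: str) -> list:
    """Canonical SMILES from a .smi file, skipping header lines and unparsable rows."""
    out = []
    with open(path) as f:
        for line in f:
            tokens = line.split()
            if not tokens or tokens[0] == "SMILES":
                continue
            mol = Chem.MolFromSmiles(tokens[0])
            if mol is not None:
                out.append(Chem.MolToSmiles(mol))
    return out

def novelty_fraction(gen_path: str, train_path: str) -> float:
    """Fraction of generated molecules whose canonical SMILES is not in the training set."""
    train_set = set(canonical_smiles(train_path))
    generated = canonical_smiles(gen_path)
    return sum(smi not in train_set for smi in generated) / len(generated)
```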
31 | """ 32 | molecules = load_molecules(path=smi_file) 33 | 34 | # create a list of all the atom types 35 | atom_types = list() 36 | for mol in molecules: 37 | for atom in mol.GetAtoms(): 38 | atom_types.append(atom.GetAtomicNum()) 39 | 40 | # remove duplicate atom types then sort by atomic number 41 | set_of_atom_types = set(atom_types) 42 | atom_types_sorted = list(set_of_atom_types) 43 | atom_types_sorted.sort() 44 | 45 | # return the symbols, for convenience 46 | return [rdkit.Chem.Atom(atom).GetSymbol() for atom in atom_types_sorted] 47 | 48 | 49 | if __name__ == "__main__": 50 | atom_types = get_atom_types(smi_file=args.smi) 51 | print("* Atom types present in input file:", atom_types, flush=True) 52 | print("Done.", flush=True) 53 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/combine_HDFs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Combines preprocessed HDF files. Useful when preprocessing large datasets, as 3 | one can split the `{split}.smi` into multiple files (and directories), preprocess 4 | them separately, and then combine using this script. 5 | 6 | To use script, modify the variables below to automatically create a list of 7 | paths **assuming** HDFs were created with the following directory structure: 8 | data/ 9 | |-- {dataset}_1/ 10 | |-- {dataset}_2/ 11 | |-- {dataset}_3/ 12 | |... 13 | |-- {dataset}_{n_dirs}/ 14 | 15 | The variables are also used in setting the dimensions of the HDF datasets later on. 16 | 17 | If directories were not named as above, then simply replace `path_list` below 18 | with a list of the paths to all the HDFs to combine. 19 | 20 | Then, run: 21 | python combine_HDFs.py 22 | """ 23 | import csv 24 | import numpy as np 25 | import h5py 26 | import torch 27 | from typing import Union 28 | 29 | 30 | def load_ts_properties_from_csv(csv_path : str) -> Union[dict, None]: 31 | """ 32 | Loads CSV file containing training set properties and returns contents as a dictionary. 33 | """ 34 | print("* Loading training set properties.", flush=True) 35 | 36 | # read dictionaries from csv 37 | try: 38 | with open(csv_path, "r") as csv_file: 39 | reader = csv.reader(csv_file, delimiter=";") 40 | csv_dict = dict(reader) 41 | except: 42 | return None 43 | 44 | # fix file types within dict in going from `csv_dict` --> `properties_dict` 45 | properties_dict = {} 46 | for key, value in csv_dict.items(): 47 | 48 | # first determine if key is a tuple 49 | key = eval(key) 50 | if len(key) > 1: 51 | tuple_key = (str(key[0]), str(key[1])) 52 | else: 53 | tuple_key = key 54 | 55 | # then convert the values to the correct data type 56 | try: 57 | properties_dict[tuple_key] = eval(value) 58 | except (SyntaxError, NameError): 59 | properties_dict[tuple_key] = value 60 | 61 | # convert any `list`s to `torch.Tensor`s (for consistency) 62 | if type(properties_dict[tuple_key]) == list: 63 | properties_dict[tuple_key] = torch.Tensor(properties_dict[tuple_key]) 64 | 65 | return properties_dict 66 | 67 | def write_ts_properties_to_csv(ts_properties_dict : dict) -> None: 68 | """ 69 | Writes the training set properties in `ts_properties_dict` to a CSV file. 
70 | """ 71 | dict_path = f"data/{dataset}/{split}.csv" 72 | 73 | with open(dict_path, "w") as csv_file: 74 | 75 | csv_writer = csv.writer(csv_file, delimiter=";") 76 | for key, value in ts_properties_dict.items(): 77 | if "validity_tensor" in key: 78 | continue # skip writing the validity tensor because it is really long 79 | elif type(value) == np.ndarray: 80 | csv_writer.writerow([key, list(value)]) 81 | elif type(value) == torch.Tensor: 82 | try: 83 | csv_writer.writerow([key, float(value)]) 84 | except ValueError: 85 | csv_writer.writerow([key, [float(i) for i in value]]) 86 | else: 87 | csv_writer.writerow([key, value]) 88 | 89 | def get_dims() -> dict: 90 | """ 91 | Gets the dims corresponding to the three datasets in each preprocessed HDF 92 | file: "nodes", "edges", and "APDs". 93 | """ 94 | dims = {} 95 | dims["nodes"] = [max_n_nodes, n_atom_types + n_formal_charges] 96 | dims["edges"] = [max_n_nodes, max_n_nodes, n_bond_types] 97 | dim_f_add = [max_n_nodes, n_atom_types, n_formal_charges, n_bond_types] 98 | dim_f_conn = [max_n_nodes, n_bond_types] 99 | dims["APDs"] = [np.prod(dim_f_add) + np.prod(dim_f_conn) + 1] 100 | 101 | return dims 102 | 103 | def get_total_n_subgraphs(paths : list) -> int: 104 | """ 105 | Gets the total number of subgraphs saved in all the HDF files in the `paths`, 106 | where `paths` is a list of strings containing the path to each HDF file we want 107 | to combine. 108 | """ 109 | total_n_subgraphs = 0 110 | for path in paths: 111 | print("path:", path) 112 | hdf_file = h5py.File(path, "r") 113 | nodes = hdf_file.get("nodes") 114 | n_subgraphs = nodes.shape[0] 115 | total_n_subgraphs += n_subgraphs 116 | hdf_file.close() 117 | 118 | return total_n_subgraphs 119 | 120 | def main(paths : list, training_set : bool) -> None: 121 | """ 122 | Combine many small HDF files (their paths defined in `paths`) into one large HDF file. 
123 | """ 124 | total_n_subgraphs = get_total_n_subgraphs(paths) 125 | dims = get_dims() 126 | 127 | print(f"* Creating HDF file to contain {total_n_subgraphs} subgraphs") 128 | new_hdf_file = h5py.File(f"data/{dataset}/{split}.h5", "a") 129 | new_dataset_nodes = new_hdf_file.create_dataset("nodes", 130 | (total_n_subgraphs, *dims["nodes"]), 131 | dtype=np.dtype("int8")) 132 | new_dataset_edges = new_hdf_file.create_dataset("edges", 133 | (total_n_subgraphs, *dims["edges"]), 134 | dtype=np.dtype("int8")) 135 | new_dataset_APDs = new_hdf_file.create_dataset("APDs", 136 | (total_n_subgraphs, *dims["APDs"]), 137 | dtype=np.dtype("int8")) 138 | 139 | print("* Combining data from smaller HDFs into a new larger HDF.") 140 | init_index = 0 141 | for path in paths: 142 | print("path:", path) 143 | hdf_file = h5py.File(path, "r") 144 | 145 | nodes = hdf_file.get("nodes") 146 | edges = hdf_file.get("edges") 147 | APDs = hdf_file.get("APDs") 148 | 149 | n_subgraphs = nodes.shape[0] 150 | 151 | new_dataset_nodes[init_index:(init_index + n_subgraphs)] = nodes 152 | new_dataset_edges[init_index:(init_index + n_subgraphs)] = edges 153 | new_dataset_APDs[init_index:(init_index + n_subgraphs)] = APDs 154 | 155 | init_index += n_subgraphs 156 | hdf_file.close() 157 | 158 | new_hdf_file.close() 159 | 160 | if training_set: 161 | print(f"* Combining data from respective `{split}.csv` files into one.") 162 | csv_list = [f"{path[:-2]}csv" for path in paths] 163 | 164 | ts_properties_old = None 165 | csv_files_processed = 0 166 | for path in csv_list: 167 | ts_properties = load_ts_properties_from_csv(csv_path=path) 168 | ts_properties_new = {} 169 | if ts_properties_old and ts_properties: 170 | for key, value in ts_properties_old.items(): 171 | if type(value) == float: 172 | ts_properties_new[key] = ( 173 | value * csv_files_processed + ts_properties[key] 174 | )/(csv_files_processed + 1) 175 | else: 176 | new_list = [] 177 | for i, value_i in enumerate(value): 178 | new_list.append( 179 | float( 180 | value_i * csv_files_processed + ts_properties[key][i] 181 | )/(csv_files_processed + 1) 182 | ) 183 | ts_properties_new[key] = new_list 184 | else: 185 | ts_properties_new = ts_properties 186 | ts_properties_old = ts_properties_new 187 | csv_files_processed += 1 188 | 189 | write_ts_properties_to_csv(ts_properties_dict=ts_properties_new) 190 | 191 | 192 | if __name__ == "__main__": 193 | # combine the HDFs defined in `path_list` 194 | 195 | # set variables 196 | dataset = "ChEMBL" 197 | n_atom_types = 15 # number of atom types used in preprocessing the data 198 | n_formal_charges = 3 # number of formal charges used in preprocessing the data 199 | n_bond_types = 3 # number of bond types used in preprocessing the data 200 | max_n_nodes = 40 # maximum number of nodes in the data 201 | 202 | # combine the training files 203 | n_dirs = 12 # how many times was `{split}.smi` split? 204 | split = "train" # train, test, or valid 205 | path_list = [f"data/{dataset}_{i}/{split}.h5" for i in range(0, n_dirs)] 206 | main(path_list, training_set=True) 207 | 208 | # combine the test files 209 | n_dirs = 4 # how many times was `{split}.smi` split? 210 | split = "test" # train, test, or valid 211 | path_list = [f"data/{dataset}_{i}/{split}.h5" for i in range(0, n_dirs)] 212 | main(path_list, training_set=False) 213 | 214 | # combine the validation files 215 | n_dirs = 2 # how many times was `{split}.smi` split? 
216 | split = "valid" # train, test, or valid 217 | path_list = [f"data/{dataset}_{i}/{split}.h5" for i in range(0, n_dirs)] 218 | main(path_list, training_set=False) 219 | 220 | print("Done.", flush=True) 221 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/combine_generation_batches.py: -------------------------------------------------------------------------------- 1 | """ 2 | Combines .likelihood, .smi, and .valid files across batches, 3 | consolidating into a set of 3 files for each epoch 4 | 5 | Run: 6 | python combine_generation_batches.py 7 | """ 8 | 9 | import os 10 | 11 | path_to_folder = "/home/gridsan/dnori/GraphINVENT/output_protac_db/example-job-name/job_0/generation/" 12 | extensions = ['likelihood','smi','valid'] 13 | 14 | 15 | for filename in os.listdir(path_to_folder): 16 | if 'GEN140' in filename: 17 | try: 18 | file_extension = filename[filename.index(".")+1:] 19 | except: 20 | file_extension = 'folder' 21 | if file_extension == 'smi': 22 | path = path_to_folder + filename 23 | with open(path, 'r') as input: 24 | first_instance = filename.index("_") 25 | filename_sub = filename[first_instance+1:] 26 | second_instance = filename_sub.index('_') 27 | epoch_number = filename[first_instance+7:second_instance+1+first_instance] 28 | 29 | print('yes') 30 | print('next file') 31 | #with open(f"{path_to_folder}aggregated_generation/epoch_{epoch_number}.{file_extension}", 'a+') as f: 32 | new_path = path_to_folder + "aggregated_generation/epoch_GEN140_AGG.smi" 33 | with open(new_path, 'a+') as f: 34 | for line in input: 35 | str_line = str(line) 36 | if "SMILES" in str_line: 37 | pass 38 | else: 39 | f.write(line) -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/formal_charges.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gets the formal charges present in a set of molecules. 3 | 4 | To use script, run: 5 | python formal_charges.py --smi path/to/file.smi 6 | """ 7 | import argparse 8 | from utils import load_molecules 9 | 10 | 11 | # define the argument parser 12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 13 | add_help=False) 14 | 15 | # define two potential arguments to use when drawing SMILES from a file 16 | parser.add_argument("--smi", 17 | type=str, 18 | default="data/gdb13_1K/train.smi", 19 | help="SMILES file containing molecules to analyse.") 20 | args = parser.parse_args() 21 | 22 | 23 | def get_formal_charges(smi_file : str) -> list: 24 | """ 25 | Determines the formal charges present in an input SMILES file. 26 | 27 | Args: 28 | ---- 29 | smi_file (str) : Full path/filename to SMILES file. 
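As a side note on `combine_generation_batches.py` above: the same aggregation can be written more generally with `glob`, collecting every batch file for a chosen epoch tag instead of hard-coding one (hypothetical sketch with the same skip-the-header behaviour; paths are illustrative):

```python
import glob
import os

def aggregate_generation(gen_dir: str, epoch_tag: str) -> None:
    """Concatenates all `*{epoch_tag}*.smi` batch files into one aggregated .smi file."""
    out_dir = os.path.join(gen_dir, "aggregated_generation")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"epoch_{epoch_tag}.smi"), "w") as out_file:
        for batch_path in sorted(glob.glob(os.path.join(gen_dir, f"*{epoch_tag}*.smi"))):
            with open(batch_path) as batch_file:
                for line in batch_file:
                    if "SMILES" not in line:  # drop per-batch header rows
                        out_file.write(line)

# e.g. aggregate_generation("output_protac_db/example-job-name/job_0/generation/", "GEN140")
```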
30 | """ 31 | molecules = load_molecules(path=smi_file) 32 | 33 | # create a list of all the formal charges 34 | formal_charges = list() 35 | for mol in molecules: 36 | for atom in mol.GetAtoms(): 37 | formal_charges.append(atom.GetFormalCharge()) 38 | 39 | # remove duplicate formal charges then sort 40 | set_of_formal_charges = set(formal_charges) 41 | formal_charges_sorted = list(set_of_formal_charges) 42 | formal_charges_sorted.sort() 43 | 44 | return formal_charges_sorted 45 | 46 | 47 | if __name__ == "__main__": 48 | formal_charges = get_formal_charges(smi_file=args.smi) 49 | print("* Formal charges present in input file:", formal_charges, flush=True) 50 | print("Done.", flush=True) 51 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/max_n_nodes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gets the maximum number of nodes per molecule present in a set of molecules. 3 | 4 | To use script, run: 5 | python max_n_nodes.py --smi path/to/file.smi 6 | """ 7 | import argparse 8 | from utils import load_molecules 9 | 10 | 11 | # define the argument parser 12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 13 | add_help=False) 14 | 15 | # define two potential arguments to use when drawing SMILES from a file 16 | parser.add_argument("--smi", 17 | type=str, 18 | default="data/gdb13_1K/train.smi", 19 | help="SMILES file containing molecules to analyse.") 20 | args = parser.parse_args() 21 | 22 | 23 | def get_max_n_atoms(smi_file : str) -> int: 24 | """ 25 | Determines the maximum number of atoms per molecule in an input SMILES file. 26 | 27 | Args: 28 | ---- 29 | smi_file (str) : Full path/filename to SMILES file. 30 | """ 31 | molecules = load_molecules(path=smi_file) 32 | 33 | max_n_atoms = 0 34 | for mol in molecules: 35 | n_atoms = mol.GetNumAtoms() 36 | 37 | if n_atoms > max_n_atoms: 38 | max_n_atoms = n_atoms 39 | 40 | return max_n_atoms 41 | 42 | 43 | if __name__ == "__main__": 44 | max_n_atoms = get_max_n_atoms(smi_file=args.smi) 45 | print("* Max number of atoms in input file:", max_n_atoms, flush=True) 46 | print("Done.", flush=True) 47 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/split_filter_protac.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Filters data by max_n_nodes. 
Randomly splits into 80% test, 15% test, 5% valid 3 | 4 | python split_filter_protac.py --orig_smi path/to/file.smi --new_smi path/to/file.smi --threshold num 5 | ''' 6 | import random 7 | 8 | import argparse 9 | from utils import load_molecules 10 | from rdkit import Chem 11 | 12 | # define the argument parser 13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 14 | add_help=False) 15 | 16 | # define two potential arguments to use when drawing SMILES from a file 17 | parser.add_argument("--orig_smi", 18 | type=str, 19 | default="data/gdb13_1K/train.smi", 20 | help="path to original SMILES file.") 21 | parser.add_argument("--new_smi", 22 | type=str, 23 | default="data/gdb13_1K/train.smi", 24 | help="path to new SMILES file.") 25 | parser.add_argument("--threshold", 26 | type=int, 27 | default=100, 28 | help="All molecules with >n atoms will be filtered out.") 29 | args = parser.parse_args() 30 | 31 | def filter(threshold, original_smi_file, filtered_smi_file): 32 | #filter from original.smi to new smi file filtered by size 33 | #threshold num must be an integer less than 139 34 | 35 | molecules = load_molecules(path=original_smi_file) 36 | 37 | with open(filtered_smi_file, 'w+') as f: 38 | for mol in molecules: 39 | n_atoms = mol.GetNumAtoms() 40 | if n_atoms <= threshold: 41 | atoms = mol.GetAtoms() 42 | atom_types = set(atoms) 43 | atom_symbols = [Chem.Atom(atom).GetSymbol() for atom in atom_types] 44 | if 'I' not in atom_symbols and 'P' not in atom_symbols: 45 | smi = Chem.MolToSmiles(mol) 46 | f.write("{}\n".format(smi)) 47 | 48 | 49 | 50 | def split(filtered_smi_file): 51 | 52 | orig_smiles = [] 53 | with open(filtered_smi_file, 'r') as f: 54 | for line in f.readlines(): 55 | words = line.split() 56 | orig_smiles.append(words[0]) 57 | 58 | test_sz = int(.15*len(orig_smiles)) 59 | valid_sz = int(.05*len(orig_smiles)) 60 | train_sz = len(orig_smiles) - test_sz - valid_sz 61 | 62 | random.shuffle(orig_smiles) 63 | train_smiles = orig_smiles[:train_sz] 64 | test_smiles = orig_smiles[train_sz:len(orig_smiles)-valid_sz] 65 | valid_smiles = orig_smiles[len(orig_smiles)-valid_sz:] 66 | 67 | with open('data/pre-training/protac_db_subset_70/train.smi', 'w+') as f: 68 | for smi in train_smiles: 69 | f.write("{}\n".format(smi)) 70 | with open('data/pre-training/protac_db_subset_70/test.smi', 'w+') as f: 71 | for smi in test_smiles: 72 | f.write("{}\n".format(smi)) 73 | with open('data/pre-training/protac_db_subset_70/valid.smi', 'w+') as f: 74 | for smi in valid_smiles: 75 | f.write("{}\n".format(smi)) 76 | 77 | if __name__ == "__main__": 78 | filter(threshold=args.threshold, original_smi_file = args.orig_smi, filtered_smi_file=args.new_smi) 79 | split(filtered_smi_file = args.new_smi) 80 | print("Done.", flush=True) -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/submit-split-preprocessing-supercloud.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example submission script for a GraphINVENT preprocessing job (before distribution- 3 | based training, not fine-tuning/optimization job), when the dataset is large and 4 | we want to split one large preprocessing job into multiple smaller preprocessing 5 | jobs, aggregating the final HDF files at the end. The HDF dataset created can be 6 | used to pre-train a model before a fine-tuning (via reinforcement learning) job. 
7 | 8 | To run, you can first split the dataset as follows (do this within an interactive session): 9 | (graphinvent)$ python submit-split-preprocessing-supercloud.py --type split 10 | 11 | Then, submit the separate preprocessing jobs for the split dataset as follows: 12 | (graphinvent)$ python submit-split-preprocessing-supercloud.py --type submit 13 | 14 | When the above jobs have completed, aggregate the generated HDFs for each dataset split into the main dataset dir: 15 | (graphinvent)$ python submit-split-preprocessing-supercloud.py --type aggregate 16 | 17 | The above script also cleans up extra files. 18 | 19 | This script was modified to run on the MIT Supercloud. 20 | """ 21 | # load general packages and functions 22 | import csv 23 | import argparse 24 | import sys 25 | import os 26 | import shutil 27 | from pathlib import Path 28 | import subprocess 29 | import time 30 | from math import ceil 31 | from typing import Union 32 | import numpy as np 33 | import h5py 34 | import torch 35 | 36 | # define the argument parser 37 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 38 | add_help=False) 39 | 40 | # define potential arguments for using this script 41 | parser.add_argument("--type", 42 | type=str, 43 | default="split", 44 | help="Acceptable values include 'split', 'submit', 'aggregate', and 'cleanup'.") 45 | args = parser.parse_args() 46 | 47 | # define what you want to do for the specified job(s) 48 | DATASET = "MOSES" # dataset name in "./data/pre-training/" 49 | JOB_TYPE = "preprocess" # "preprocess", "train", "generate", or "test" 50 | JOBDIR_START_IDX = 0 # where to start indexing job dirs 51 | N_JOBS = 1 # number of jobs to run per model 52 | RESTART = True # whether or not this is a restart job 53 | FORCE_OVERWRITE = True # overwrite job directories which already exist 54 | JOBNAME = "preprocessing" # used to create a sub directory 55 | 56 | # if running using LLsub, specify params below 57 | USE_LLSUB = True # use LLsub or not 58 | MEM_GB = 20 # required RAM in GB 59 | 60 | # for LLsub jobs, set number of CPUs per task 61 | if JOB_TYPE == "preprocess": 62 | CPUS_PER_TASK = 20 63 | DEVICE = "cpu" 64 | else: 65 | CPUS_PER_TASK = 10 66 | DEVICE = "cuda" 67 | 68 | # set paths here 69 | HOME = str(Path.home()) 70 | PYTHON_PATH = f"{HOME}/.conda/envs/graphinvent/bin/python" 71 | GRAPHINVENT_PATH = "./graphinvent/" 72 | DATA_PATH = "./data/pre-training/" 73 | 74 | # define dataset-specific parameters 75 | params = { 76 | "atom_types" : ["C", "N", "O", "F", "S", "Cl", "Br"], 77 | "formal_charge": [-1, 0, +1], 78 | "max_n_nodes" : 27, 79 | "job_type" : JOB_TYPE, 80 | "restart" : RESTART, 81 | "model" : "GGNN", 82 | "sample_every" : 2, 83 | "init_lr" : 1e-4, 84 | "epochs" : 100, 85 | "batch_size" : 50, 86 | "block_size" : 1000, 87 | "device" : DEVICE, 88 | "n_samples" : 100, 89 | # additional paramaters can be defined here, if different from the "defaults" 90 | # for instance, for "generate" jobs, don't forget to specify "generation_epoch" 91 | # and "n_samples" 92 | } 93 | 94 | 95 | def submit() -> None: 96 | """ 97 | Creates and submits submission script. Uses global variables defined at top 98 | of this file. 
99 | """ 100 | check_paths() 101 | 102 | # create an output directory 103 | dataset_output_path = f"{HOME}/GraphINVENT/output_{DATASET}" 104 | tensorboard_path = os.path.join(dataset_output_path, "tensorboard") 105 | if JOBNAME != "": 106 | dataset_output_path = os.path.join(dataset_output_path, JOBNAME) 107 | tensorboard_path = os.path.join(tensorboard_path, JOBNAME) 108 | 109 | os.makedirs(dataset_output_path, exist_ok=True) 110 | os.makedirs(tensorboard_path, exist_ok=True) 111 | print(f"* Creating dataset directory {dataset_output_path}/", flush=True) 112 | 113 | # submit `N_JOBS` separate jobs 114 | jobdir_end_idx = JOBDIR_START_IDX + N_JOBS 115 | for job_idx in range(JOBDIR_START_IDX, jobdir_end_idx): 116 | 117 | # specify and create the job subdirectory if it does not exist 118 | params["job_dir"] = f"{dataset_output_path}/job_{job_idx}/" 119 | params["tensorboard_dir"] = f"{tensorboard_path}/job_{job_idx}/" 120 | 121 | # create the directory if it does not exist already, otherwise raises an 122 | # error, which is good because *might* not want to override data in our 123 | # existing directories! 124 | os.makedirs(params["tensorboard_dir"], exist_ok=True) 125 | try: 126 | job_dir_exists_already = bool( 127 | JOB_TYPE in ["generate", "test"] or FORCE_OVERWRITE 128 | ) 129 | os.makedirs(params["job_dir"], exist_ok=job_dir_exists_already) 130 | print( 131 | f"* Creating model subdirectory {dataset_output_path}/job_{job_idx}/", 132 | flush=True, 133 | ) 134 | except FileExistsError: 135 | print( 136 | f"-- Model subdirectory {dataset_output_path}/job_{job_idx}/ already exists.", 137 | flush=True, 138 | ) 139 | if not RESTART: 140 | continue 141 | 142 | # write the `input.csv` file 143 | write_input_csv(params_dict=params, filename="input.csv") 144 | 145 | # write `submit.sh` and submit 146 | if USE_LLSUB: 147 | print("* Writing submission script.", flush=True) 148 | write_submission_script(job_dir=params["job_dir"], 149 | job_idx=job_idx, 150 | job_type=params["job_type"], 151 | max_n_nodes=params["max_n_nodes"], 152 | cpu_per_task=CPUS_PER_TASK, 153 | python_bin_path=PYTHON_PATH) 154 | 155 | print("* Submitting batch job using LLsub.", flush=True) 156 | subprocess.run(["LLsub", params["job_dir"] + "submit.sh"], 157 | check=True) 158 | else: 159 | print("* Running job as a normal process.", flush=True) 160 | subprocess.run(["ls", f"{PYTHON_PATH}"], check=True) 161 | subprocess.run([f"{PYTHON_PATH}", 162 | f"{GRAPHINVENT_PATH}main.py", 163 | "--job-dir", 164 | params["job_dir"]], 165 | check=True) 166 | 167 | # sleep a few secs before submitting next job 168 | print("-- Sleeping 2 seconds.") 169 | time.sleep(2) 170 | 171 | 172 | def write_input_csv(params_dict : dict, filename : str="params.csv") -> None: 173 | """ 174 | Writes job parameters/hyperparameters in `params_dict` to CSV using the specified 175 | `filename`. 176 | """ 177 | dict_path = params_dict["job_dir"] + filename 178 | 179 | with open(dict_path, "w") as csv_file: 180 | 181 | writer = csv.writer(csv_file, delimiter=";") 182 | for key, value in params_dict.items(): 183 | writer.writerow([key, value]) 184 | 185 | 186 | def write_submission_script(job_dir : str, job_idx : int, job_type : str, max_n_nodes : int, 187 | cpu_per_task : int, python_bin_path : str) -> None: 188 | """ 189 | Writes a submission script (`submit.sh`). 190 | 191 | Args: 192 | ---- 193 | job_dir (str) : Job running directory. 194 | job_idx (int) : Job idx. 195 | job_type (str) : Type of job to run. 
196 | max_n_nodes (int) : Maximum number of nodes in dataset. 197 | cpu_per_task (int) : How many CPUs to use per task. 198 | python_bin_path (str) : Path to Python binary to use. 199 | """ 200 | submit_filename = job_dir + "submit.sh" 201 | with open(submit_filename, "w") as submit_file: 202 | submit_file.write("#!/bin/bash\n") 203 | submit_file.write(f"#SBATCH --job-name={job_type}{max_n_nodes}_{job_idx}\n") 204 | submit_file.write(f"#SBATCH --output={job_type}{max_n_nodes}_{job_idx}o\n") 205 | submit_file.write(f"#SBATCH --cpus-per-task={cpu_per_task}\n") 206 | if DEVICE == "cuda": 207 | submit_file.write("#SBATCH --gres=gpu:volta:1\n") 208 | submit_file.write("hostname\n") 209 | submit_file.write("export QT_QPA_PLATFORM='offscreen'\n") 210 | submit_file.write(f"{python_bin_path} {GRAPHINVENT_PATH}main.py --job-dir {job_dir}") 211 | submit_file.write(f" > {job_dir}output.o${{LLSUB_RANK}}\n") 212 | 213 | 214 | def check_paths() -> None: 215 | """ 216 | Checks that paths to Python binary, data, and GraphINVENT are properly 217 | defined before running a job, and tells the user to define them if not. 218 | """ 219 | for path in [PYTHON_PATH, GRAPHINVENT_PATH, DATA_PATH]: 220 | if "path/to/" in path: 221 | print("!!!") 222 | print("* Update the following paths in `submit.py` before running:") 223 | print("-- `PYTHON_PATH`\n-- `GRAPHINVENT_PATH`\n-- `DATA_PATH`") 224 | sys.exit(0) 225 | 226 | def split_file(filename : str, n_lines_per_split : int=100000) -> None: 227 | """ 228 | _summary_ 229 | 230 | Args: 231 | ---- 232 | filename (str) : The filename. 233 | n_lines_per_split (int) : Number of lines per file. 234 | 235 | Returns: 236 | ------- 237 | n_splits (int) : Number of splits. 238 | """ 239 | output_base = filename[:-4] 240 | input = open(filename, "r") 241 | extension = filename[-3:] 242 | 243 | count = 0 244 | at = 0 245 | dest = None 246 | for line in input: 247 | if count % n_lines_per_split == 0: 248 | if dest: dest.close() 249 | dest = open(f"{output_base}.{at}.{extension}", "w") 250 | at += 1 251 | dest.write(line) 252 | count += 1 253 | 254 | n_splits = at 255 | return n_splits 256 | 257 | def get_n_splits(filename : str, n_lines_per_split : int=100000) -> None: 258 | """ 259 | _summary_ 260 | 261 | Args: 262 | ---- 263 | filename (str) : The filename. 264 | n_lines_per_split (int) : Number of lines per file. 265 | 266 | Returns: 267 | ------- 268 | n_splits (int) : Number of splits. 269 | """ 270 | output_base = filename[:-4] 271 | input = open(filename, "r") 272 | n_splits = ceil(len(input.readlines()) / n_lines_per_split) 273 | return n_splits 274 | 275 | def load_ts_properties_from_csv(csv_path : str) -> Union[dict, None]: 276 | """ 277 | Loads CSV file containing training set properties and returns contents as a dictionary. 
278 | """ 279 | print("* Loading training set properties.", flush=True) 280 | 281 | # read dictionaries from csv 282 | try: 283 | with open(csv_path, "r") as csv_file: 284 | reader = csv.reader(csv_file, delimiter=";") 285 | csv_dict = dict(reader) 286 | except: 287 | return None 288 | 289 | # fix file types within dict in going from `csv_dict` --> `properties_dict` 290 | properties_dict = {} 291 | for key, value in csv_dict.items(): 292 | 293 | # first determine if key is a tuple 294 | key = eval(key) 295 | if len(key) > 1: 296 | tuple_key = (str(key[0]), str(key[1])) 297 | else: 298 | tuple_key = key 299 | 300 | # then convert the values to the correct data type 301 | try: 302 | properties_dict[tuple_key] = eval(value) 303 | except (SyntaxError, NameError): 304 | properties_dict[tuple_key] = value 305 | 306 | # convert any `list`s to `torch.Tensor`s (for consistency) 307 | if type(properties_dict[tuple_key]) == list: 308 | properties_dict[tuple_key] = torch.Tensor(properties_dict[tuple_key]) 309 | 310 | return properties_dict 311 | 312 | def write_ts_properties_to_csv(ts_properties_dict : dict, split : str) -> None: 313 | """ 314 | Writes the training set properties in `ts_properties_dict` to a CSV file. 315 | """ 316 | dict_path = f"data/pre-training/{dataset}/{split}.csv" 317 | 318 | with open(dict_path, "w") as csv_file: 319 | 320 | csv_writer = csv.writer(csv_file, delimiter=";") 321 | for key, value in ts_properties_dict.items(): 322 | if "validity_tensor" in key: 323 | continue # skip writing the validity tensor because it is really long 324 | elif type(value) == np.ndarray: 325 | csv_writer.writerow([key, list(value)]) 326 | elif type(value) == torch.Tensor: 327 | try: 328 | csv_writer.writerow([key, float(value)]) 329 | except ValueError: 330 | csv_writer.writerow([key, [float(i) for i in value]]) 331 | else: 332 | csv_writer.writerow([key, value]) 333 | 334 | def get_dims() -> dict: 335 | """ 336 | Gets the dims corresponding to the three datasets in each preprocessed HDF 337 | file: "nodes", "edges", and "APDs". 338 | """ 339 | dims = {} 340 | dims["nodes"] = [max_n_nodes, n_atom_types + n_formal_charges] 341 | dims["edges"] = [max_n_nodes, max_n_nodes, n_bond_types] 342 | dim_f_add = [max_n_nodes, n_atom_types, n_formal_charges, n_bond_types] 343 | dim_f_conn = [max_n_nodes, n_bond_types] 344 | dims["APDs"] = [np.prod(dim_f_add) + np.prod(dim_f_conn) + 1] 345 | 346 | return dims 347 | 348 | def get_total_n_subgraphs(paths : list) -> int: 349 | """ 350 | Gets the total number of subgraphs saved in all the HDF files in the `paths`, 351 | where `paths` is a list of strings containing the path to each HDF file we want 352 | to combine. 353 | """ 354 | total_n_subgraphs = 0 355 | for path in paths: 356 | print("path:", path) 357 | hdf_file = h5py.File(path, "r") 358 | nodes = hdf_file.get("nodes") 359 | n_subgraphs = nodes.shape[0] 360 | total_n_subgraphs += n_subgraphs 361 | hdf_file.close() 362 | 363 | return total_n_subgraphs 364 | 365 | def combine_HDFs(paths : list, training_set : bool, split : str) -> None: 366 | """ 367 | Combine many small HDF files (their paths defined in `paths`) into one large 368 | HDF file. Works assuming HDFs were created for the preprocessed dataset 369 | following the following directory structure: 370 | data/pre-training/ 371 | |-- {dataset}_1/ 372 | |-- {dataset}_2/ 373 | |-- {dataset}_3/ 374 | |... 
375 | |-- {dataset}_{n_dirs}/ 376 | """ 377 | total_n_subgraphs = get_total_n_subgraphs(paths) 378 | dims = get_dims() 379 | 380 | print(f"* Creating HDF file to contain {total_n_subgraphs} subgraphs") 381 | new_hdf_file = h5py.File(f"data/pre-training/{dataset}/{split}.h5", "a") 382 | new_dataset_nodes = new_hdf_file.create_dataset("nodes", 383 | (total_n_subgraphs, *dims["nodes"]), 384 | dtype=np.dtype("int8")) 385 | new_dataset_edges = new_hdf_file.create_dataset("edges", 386 | (total_n_subgraphs, *dims["edges"]), 387 | dtype=np.dtype("int8")) 388 | new_dataset_APDs = new_hdf_file.create_dataset("APDs", 389 | (total_n_subgraphs, *dims["APDs"]), 390 | dtype=np.dtype("int8")) 391 | 392 | print("* Combining data from smaller HDFs into a new larger HDF.") 393 | init_index = 0 394 | for path in paths: 395 | print("path:", path) 396 | hdf_file = h5py.File(path, "r") 397 | 398 | nodes = hdf_file.get("nodes") 399 | edges = hdf_file.get("edges") 400 | APDs = hdf_file.get("APDs") 401 | 402 | n_subgraphs = nodes.shape[0] 403 | 404 | new_dataset_nodes[init_index:(init_index + n_subgraphs)] = nodes 405 | new_dataset_edges[init_index:(init_index + n_subgraphs)] = edges 406 | new_dataset_APDs[init_index:(init_index + n_subgraphs)] = APDs 407 | 408 | init_index += n_subgraphs 409 | hdf_file.close() 410 | 411 | new_hdf_file.close() 412 | 413 | if training_set: 414 | print(f"* Combining data from respective `{split}.csv` files into one.") 415 | csv_list = [f"{path[:-2]}csv" for path in paths] 416 | 417 | ts_properties_old = None 418 | csv_files_processed = 0 419 | for path in csv_list: 420 | ts_properties = load_ts_properties_from_csv(csv_path=path) 421 | ts_properties_new = {} 422 | if ts_properties_old and ts_properties: 423 | for key, value in ts_properties_old.items(): 424 | if type(value) == float: 425 | ts_properties_new[key] = ( 426 | value * csv_files_processed + ts_properties[key] 427 | )/(csv_files_processed + 1) 428 | else: 429 | new_list = [] 430 | for i, value_i in enumerate(value): 431 | new_list.append( 432 | float( 433 | value_i * csv_files_processed + ts_properties[key][i] 434 | )/(csv_files_processed + 1) 435 | ) 436 | ts_properties_new[key] = new_list 437 | else: 438 | ts_properties_new = ts_properties 439 | ts_properties_old = ts_properties_new 440 | csv_files_processed += 1 441 | 442 | write_ts_properties_to_csv(ts_properties_dict=ts_properties_new, split=split) 443 | 444 | if __name__ == "__main__": 445 | dataset = DATASET 446 | 447 | if args.type == "split": 448 | # --------- SPLIT THE DATASET ---------- 449 | # 1) first, split the training set 450 | n_training_splits = split_file(filename=f"{DATA_PATH}{DATASET}/train.smi", 451 | n_lines_per_split=100000) 452 | 453 | # 2) then, split the test set (if necessary) 454 | n_test_splits = split_file(filename=f"{DATA_PATH}{DATASET}/test.smi", 455 | n_lines_per_split=100000) 456 | 457 | # 3) finally, split the validation set (if necessary) 458 | n_valid_splits = split_file(filename=f"{DATA_PATH}{DATASET}/valid.smi", 459 | n_lines_per_split=100000) 460 | 461 | elif args.type == "submit": 462 | # ---------- MOVE EACH SPLIT INTO ITS OWN DIRECTORY AND SUBMIT EACH AS SEPARATE JOB ---------- 463 | # first get the number of splits for each train/test/valid split if each 464 | # file is split into files of max 100000 lines 465 | n_training_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/train.smi", 466 | n_lines_per_split=100000) 467 | 468 | for split_idx in range(n_training_splits): 469 | if not 
os.path.exists(f"{DATA_PATH}{dataset}_{split_idx}/"): 470 | os.mkdir(f"{DATA_PATH}{dataset}_{split_idx}/") # make the dir 471 | 472 | # moving train split into folder for given index 473 | try: 474 | os.rename(f"{DATA_PATH}{dataset}/train.{split_idx}.smi", f"{DATA_PATH}{dataset}_{split_idx}/train.smi") # move the file to the dir and rename 475 | except: 476 | pass 477 | 478 | # moving test split into folder for given index, if test split exists 479 | try: 480 | os.rename(f"{DATA_PATH}{dataset}/test.{split_idx}.smi", f"{DATA_PATH}{dataset}_{split_idx}/test.smi") # move the file to the dir and rename 481 | except: 482 | pass 483 | 484 | # moving valid split into folder for given index, if valid split exists 485 | try: 486 | os.rename(f"{DATA_PATH}{dataset}/valid.{split_idx}.smi", f"{DATA_PATH}{dataset}_{split_idx}/valid.smi") # move the file to the dir and rename 487 | except: 488 | pass 489 | 490 | DATASET = f"{dataset}_{split_idx}/" 491 | params["dataset_dir"] = f"{DATA_PATH}{DATASET}" 492 | submit() 493 | 494 | 495 | elif args.type == "aggregate": 496 | # first get the number of splits for each train/test/valid split if each 497 | # file is split into files of max 100000 lines 498 | n_training_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/train.smi", 499 | n_lines_per_split=100000) 500 | n_test_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/test.smi", 501 | n_lines_per_split=100000) 502 | n_valid_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/valid.smi", 503 | n_lines_per_split=100000) 504 | # ---------- AGGREGATE THE RESULTS ---------- 505 | # set variables 506 | n_atom_types = len(params["atom_types"]) # number of atom types used in preprocessing the data 507 | n_formal_charges = len(params["formal_charge"]) # number of formal charges used in preprocessing the data 508 | n_bond_types = 3 # number of bond types used in preprocessing the data 509 | max_n_nodes = params["max_n_nodes"] # maximum number of nodes in the data 510 | 511 | # # 1) combine the training files 512 | # path_list = [f"data/pre-training/{dataset}_{i}/train.h5" for i in range(0, n_training_splits)] 513 | # combine_HDFs(path_list, training_set=False, split="train") 514 | 515 | # 2) combine the test files 516 | path_list = [f"data/pre-training/{dataset}_{i}/test.h5" for i in range(0, n_test_splits)] 517 | combine_HDFs(path_list, training_set=False, split="test") 518 | 519 | # 3) combine the validation files 520 | path_list = [f"data/pre-training/{dataset}_{i}/valid.h5" for i in range(0, n_valid_splits)] 521 | combine_HDFs(path_list, training_set=False, split="valid") 522 | 523 | # ---------- DELETE TEMPORARY FILES ---------- 524 | for split_idx in range(max(n_training_splits, n_test_splits, n_valid_splits)): 525 | shutil.rmtree(f"{DATA_PATH}{dataset}_{split_idx}/") # remove the dir and all files in it 526 | else: 527 | raise ValueError("Not a valid job type.") 528 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/tdc-create-dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uses the Therapeutics Data Commons (TDC) to get datasets for goal-directed 3 | molecular optimization tasks and then filters the molecules based on number 4 | of heavy atoms and formal charge. 
5 | 6 | See: 7 | * https://tdcommons.ai/ 8 | * https://github.com/mims-harvard/TDC 9 | 10 | To use script, run: 11 | (graphinvent)$ python tdc-create-dataset.py --dataset MOSES 12 | """ 13 | import os 14 | import argparse 15 | from pathlib import Path 16 | import shutil 17 | from tdc.generation import MolGen 18 | import rdkit 19 | from rdkit import Chem 20 | 21 | # define the argument parser 22 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 23 | add_help=False) 24 | 25 | # define two potential arguments to use when drawing SMILES from a file 26 | parser.add_argument("--dataset", 27 | type=str, 28 | default="ChEMBL", 29 | help="Specifies the dataset to use for creating the data. Options " 30 | "are: 'ChEMBL', 'MOSES', or 'ZINC'.") 31 | args = parser.parse_args() 32 | 33 | 34 | def save_smiles(smi_file : str, smi_list : list) -> None: 35 | """Saves input list of SMILES to the specified file path.""" 36 | smi_writer = rdkit.Chem.rdmolfiles.SmilesWriter(smi_file) 37 | for smi in smi_list: 38 | try: 39 | mol = rdkit.Chem.MolFromSmiles(smi[0]) 40 | if mol.GetNumAtoms() < 81: # filter out molecules with >= 81 atoms 41 | save = True 42 | for atom in mol.GetAtoms(): 43 | if atom.GetFormalCharge() not in [-1, 0, +1]: # filter out molecules with large formal charge 44 | save = False 45 | break 46 | if save: 47 | smi_writer.write(mol) 48 | except: # likely TypeError or AttributeError e.g. "smi[0]" is "nan" 49 | continue 50 | smi_writer.close() 51 | 52 | 53 | if __name__ == "__main__": 54 | print(f"* Loading {args.dataset} dataset using the TDC.") 55 | data = MolGen(name=args.dataset) 56 | split = data.get_split() 57 | HOME = str(Path.home()) 58 | DATA_PATH = f"./data/{args.dataset}/" 59 | try: 60 | os.mkdir(DATA_PATH) 61 | print(f"-- Creating dataset at {DATA_PATH}") 62 | except FileExistsError: 63 | shutil.rmtree(DATA_PATH) 64 | os.mkdir(DATA_PATH) 65 | print(f"-- Removed old directory at {DATA_PATH}") 66 | print(f"-- Creating new dataset at {DATA_PATH}") 67 | 68 | print(f"* Re-saving {args.dataset} dataset in a format GraphINVENT can parse.") 69 | print("-- Saving training data...") 70 | save_smiles(smi_file=f"{DATA_PATH}train.smi", smi_list=split["train"].values) 71 | print("-- Saving testing data...") 72 | save_smiles(smi_file=f"{DATA_PATH}test.smi", smi_list=split["test"].values) 73 | print("-- Saving validation data...") 74 | save_smiles(smi_file=f"{DATA_PATH}valid.smi", smi_list=split["valid"].values) 75 | 76 | # # delete the raw downloaded files 77 | # dir_path = "./data/" 78 | # shutil.rmtree(dir_path) 79 | print("Done.", flush=True) 80 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous functions. 3 | """ 4 | import rdkit 5 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier 6 | 7 | 8 | def load_molecules(path : str) -> rdkit.Chem.rdmolfiles.SmilesMolSupplier: 9 | """ 10 | Reads a SMILES file (full path/filename specified by `path`) and returns the 11 | `rdkit.Mol` object "supplier". 
12 | """ 13 | # check first line of SMILES file to see if contains header 14 | with open(path) as smi_file: 15 | first_line = smi_file.readline() 16 | has_header = bool("SMILES" in first_line) 17 | smi_file.close() 18 | 19 | # read file 20 | molecule_set = SmilesMolSupplier(path, sanitize=True, nameColumn=-1, titleLine=has_header) 21 | 22 | return molecule_set 23 | -------------------------------------------------------------------------------- /GraphINVENT_Protac/tools/visualization.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import rdkit 4 | from rdkit.Chem.Draw import MolsToGridImage 5 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier 6 | 7 | smi_file = "/home/gridsan/dnori/GraphINVENT/output_MOSES_subset/example-job-name/job_0/generation/step61_agent.smi" 8 | 9 | # load molecules from file 10 | mols = SmilesMolSupplier(smi_file, sanitize=True, nameColumn=-1,titleLine=True) 11 | 12 | n_samples = 8 13 | mols_list = [mol for mol in mols] 14 | mols_sampled = random.sample(mols_list, n_samples) # sample 100 random molecules to visualize 15 | 16 | mols_per_row = int(math.sqrt(n_samples)) # make a square grid 17 | 18 | png_filename=smi_file[:-3] + "png" # name of PNG file to create 19 | #labels=list(range(n_samples)) # label structures with a number 20 | 21 | labels = [i for i in range(n_samples)] 22 | 23 | # draw the molecules (creates a PIL image) 24 | img = MolsToGridImage(mols=mols_sampled, 25 | molsPerRow=mols_per_row, 26 | legends=[str(i) for i in labels]) 27 | 28 | img.save(png_filename) -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Divya Nori 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Protac-Design 2 | ## Description 3 | This repo contains the code behind our workshop paper at the NeurIPS 2022 AI4Science Workshop, [Link to Paper](https://openreview.net/pdf?id=pGyp4o9gky0). It is organized into the following notebooks: 4 | 5 | * [surrogate_model.ipynb](./surrogate_model.ipynb): Contains the code for processing the raw PROTAC data and training the DC50 surrogate model. 
Note that you will need to download the public PROTAC data from [PROTAC-DB](http://cadd.zju.edu.cn/protacdb/downloads) in order to reproduce the results. 6 | * [molecule_metrics.ipynb](./molecule_metrics.ipynb): Contains code for computing metrics on a set of generated molecules. Metrics include percentage predicted active, percentage of duplicate molecules, percentage of molecules regenerated from training set, average number of atoms, chemical diversity, and drug-likeness. 7 | * [binary_label_metrics.py](./binary_label_metrics.py): Contains useful functions for analyzing performance of binary classification models. 8 | 9 | Then there are additional files in the repo: 10 | * [surrogate_model.pkl](./surrogate_model.pkl): Contains the pre-trained surrogate model for DC50 prediction. 11 | * [features.pkl](./features.pkl): Contains list of features used in surrogate model training; required to reproduce reinforcement learning jobs using protac scoring function. 12 | 13 | 14 | ## Instructions 15 | 1. Before running any of the notebooks, you will need to download the PROTAC data from the public [PROTAC-DB](http://cadd.zju.edu.cn/protacdb) database. 16 | 2. You will then need to create a conda environment containing the following main packages: `rdkit`, `pandas`, `sklearn`, `scipy`, `ipython`, and `optuna`. See the instructions in the next section for setting this up. 17 | 3. Open the notebooks on your favorite platform, and make sure to select the right kernel before executing. 18 | 19 | ## Environment 20 | To set up the environment for running the notebooks in this repo, you can follow the following set of instructions: 21 | ``` 22 | conda create -n protacs-env -c conda-forge scikit-learn optuna rdkit 23 | conda activate protacs-env 24 | conda install pandas scipy 25 | ``` 26 | 27 | 28 | ## Citation 29 | Nori, Divya et al. (2022) "De novo PROTAC design using graph-based deep generative models." NeurIPS 2022 AI4Science Workshop. 30 | 31 | ## Additional data 32 | Additional data, including saved GraphINVENT model states, generated structures, analysis scripts, and training data, are available on Zenodo [here](https://doi.org/10.5281/zenodo.7278277). 33 | 34 | ## Authors 35 | * Divya Nori 36 | * Rocío Mercado 37 | -------------------------------------------------------------------------------- /features.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/features.pkl -------------------------------------------------------------------------------- /surrogate_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/surrogate_model.pkl --------------------------------------------------------------------------------
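For a quick sanity check outside the notebooks, the two pickled artifacts can be loaded directly. A minimal sketch, assuming both files were written with Python's `pickle` module and that the surrogate model follows the scikit-learn estimator API (as the environment above suggests):

```python
import pickle

with open("surrogate_model.pkl", "rb") as f:
    surrogate_model = pickle.load(f)   # pre-trained DC50 surrogate model
with open("features.pkl", "rb") as f:
    feature_names = pickle.load(f)     # list of features used during training

print(type(surrogate_model))
print(f"{len(feature_names)} features expected by the surrogate model")
```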