├── .DS_Store
├── GraphINVENT_Protac
│   ├── graphinvent
│   │   ├── Analyzer.py
│   │   ├── BlockDatasetLoader.py
│   │   ├── DataProcesser.py
│   │   ├── GraphGenerator.py
│   │   ├── GraphGeneratorRL.py
│   │   ├── MolecularGraph.py
│   │   ├── ProtacScoringFunction.py
│   │   ├── ScoringFunction.py
│   │   ├── Workflow.py
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── fine-tuning
│   │   │   │   └── README.md
│   │   │   ├── pre-training
│   │   │   │   └── MOSES
│   │   │   │       └── README.md
│   │   │   └── protac
│   │   │       └── README.md
│   │   ├── gnn
│   │   │   ├── aggregation_mpnn.py
│   │   │   ├── edge_mpnn.py
│   │   │   ├── modules.py
│   │   │   ├── mpnn.py
│   │   │   └── summation_mpnn.py
│   │   ├── main.py
│   │   ├── parameters
│   │   │   ├── args.py
│   │   │   ├── constants.py
│   │   │   ├── defaults.py
│   │   │   └── load.py
│   │   ├── training_stats.py
│   │   └── util.py
│   └── tools
│       ├── analyze_final_epoch.py
│       ├── atom_types.py
│       ├── combine_HDFs.py
│       ├── combine_generation_batches.py
│       ├── formal_charges.py
│       ├── max_n_nodes.py
│       ├── split_filter_protac.py
│       ├── submit-split-preprocessing-supercloud.py
│       ├── tdc-create-dataset.py
│       ├── utils.py
│       └── visualization.py
├── LICENSE.md
├── README.md
├── binary_label_metrics.py
├── features.pkl
├── molecule_metrics.ipynb
├── surrogate_model.ipynb
└── surrogate_model.pkl
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/.DS_Store
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/BlockDatasetLoader.py:
--------------------------------------------------------------------------------
1 | """
2 | The `BlockDatasetLoader` defines custom `DataLoader`s and `Dataset`s used to
3 | efficiently load data from HDF files in this work
4 | """
5 | # load general packages and functions
6 | from typing import Tuple
7 | import torch
8 | import h5py
9 |
10 |
11 | class BlockDataLoader(torch.utils.data.DataLoader):
12 | """
13 | Main `DataLoader` class which has been modified so as to read training data
14 | from disk in blocks, as opposed to a single line at a time (as is done in
15 | the original `DataLoader` class).
16 | """
17 | def __init__(self, dataset : torch.utils.data.Dataset, batch_size : int=100,
18 | block_size : int=10000, shuffle : bool=True, n_workers : int=0,
19 | pin_memory : bool=True) -> None:
20 |
21 | # define variables to be used throughout dataloading
22 | self.dataset = dataset # `HDFDataset` object
23 | self.batch_size = batch_size # `int`
24 | self.block_size = block_size # `int`
25 | self.shuffle = shuffle # `bool`
26 | self.n_workers = n_workers # `int`
27 | self.pin_memory = pin_memory # `bool`
28 | self.block_dataset = BlockDataset(self.dataset,
29 | batch_size=self.batch_size,
30 | block_size=self.block_size)
31 |
32 | def __iter__(self) -> torch.Tensor:
33 |
34 | # define a regular `DataLoader` using the `BlockDataset`
35 | block_loader = torch.utils.data.DataLoader(self.block_dataset,
36 | shuffle=self.shuffle,
37 | num_workers=self.n_workers)
38 |
39 | # define a condition for determining whether to drop the last block; this
40 | # is done if the remainder block is very small (less than a tenth the
41 | # size of a normal block)
42 | condition = bool(
43 | len(self.dataset) // self.block_size > 1 and
44 | len(self.dataset) % self.block_size < self.block_size / 10
45 | )
46 |
47 | # loop through and load BLOCKS of data every iteration
48 | for block in block_loader:
49 | block = [torch.squeeze(b) for b in block]
50 |
51 | # wrap each block in a `ShuffleBlock` so that data can be shuffled
52 | # within blocks
53 | batch_loader = torch.utils.data.DataLoader(
54 | dataset=ShuffleBlockWrapper(block),
55 | shuffle=self.shuffle,
56 | batch_size=self.batch_size,
57 | num_workers=self.n_workers,
58 | pin_memory=self.pin_memory,
59 | drop_last=condition
60 | )
61 |
62 | for batch in batch_loader:
63 | yield batch
64 |
65 | def __len__(self) -> int:
66 | # returns the number of graphs in the DataLoader
67 | n_blocks = len(self.dataset) // self.block_size
68 | n_rem = len(self.dataset) % self.block_size
69 | n_batch_per_block = self.__ceil__(self.block_size, self.batch_size)
70 | n_last = self.__ceil__(n_rem, self.batch_size)
71 | return n_batch_per_block * n_blocks + n_last
72 |
73 | def __ceil__(self, i : int, j : int) -> int:
74 | return (i + j - 1) // j
75 |
76 |
77 | class BlockDataset(torch.utils.data.Dataset):
78 | """
79 | Modified `Dataset` class which returns BLOCKS of data when `__getitem__()`
80 | is called.
81 | """
82 | def __init__(self, dataset : torch.utils.data.Dataset, batch_size : int=100,
83 | block_size : int=10000) -> None:
84 |
85 | assert block_size >= batch_size, "Block size should be >= batch size."
86 |
87 | self.block_size = block_size # `int`
88 | self.batch_size = batch_size # `int`
89 | self.dataset = dataset # `HDFDataset`
90 |
91 | def __getitem__(self, idx : int) -> torch.Tensor:
92 | # returns a block of data from the dataset
93 | start = idx * self.block_size
94 | end = min((idx + 1) * self.block_size, len(self.dataset))
95 | return self.dataset[start:end]
96 |
97 | def __len__(self) -> int:
98 | # returns the number of blocks in the dataset
99 | return (len(self.dataset) + self.block_size - 1) // self.block_size
100 |
101 |
102 | class ShuffleBlockWrapper:
103 | """
104 | Extra class used to wrap a block of data, enabling data to get shuffled
105 | *within* a block.
106 | """
107 | def __init__(self, data : torch.Tensor) -> None:
108 | self.data = data
109 |
110 | def __getitem__(self, idx : int) -> torch.Tensor:
111 | return [d[idx] for d in self.data]
112 |
113 | def __len__(self) -> int:
114 | return len(self.data[0])
115 |
116 |
117 | class HDFDataset(torch.utils.data.Dataset):
118 | """
119 | Reads and collects data from an HDF file with three datasets: "nodes",
120 | "edges", and "APDs".
121 | """
122 | def __init__(self, path : str) -> None:
123 |
124 | self.path = path
125 | hdf_file = h5py.File(self.path, "r+", swmr=True)
126 |
127 | # load each HDF dataset
128 | self.nodes = hdf_file.get("nodes")
129 | self.edges = hdf_file.get("edges")
130 | self.apds = hdf_file.get("APDs")
131 |
132 | # get the number of elements in the dataset
133 | self.n_subgraphs = self.nodes.shape[0]
134 |
135 | def __getitem__(self, idx : int) -> \
136 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
137 |
138 | # returns specific graph elements
139 | nodes_i = torch.from_numpy(self.nodes[idx]).type(torch.float32)
140 | edges_i = torch.from_numpy(self.edges[idx]).type(torch.float32)
141 | apd_i = torch.from_numpy(self.apds[idx]).type(torch.float32)
142 |
143 | return (nodes_i, edges_i, apd_i)
144 |
145 | def __len__(self) -> int:
146 | # returns the number of graphs in the dataset
147 | return self.n_subgraphs
148 |
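149 | # The block below is an added usage sketch (not part of the original module):
150 | # it assumes a preprocessed HDF file (e.g. "train.h5") containing the "nodes",
151 | # "edges", and "APDs" datasets written by `DataProcesser`, and the example
152 | # path is only a placeholder.
153 | if __name__ == "__main__":
154 |     dataset = HDFDataset(path="data/pre-training/MOSES/train.h5")
155 |     loader = BlockDataLoader(dataset, batch_size=100, block_size=10000)
156 |     # each batch is a [nodes, edges, APDs] triple of float32 tensors
157 |     for nodes, edges, apds in loader:
158 |         print(nodes.shape, edges.shape, apds.shape)
159 |         break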
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/DataProcesser.py:
--------------------------------------------------------------------------------
1 | """
2 | The `DataProcesser` class contains functions for pre-processing training data.
3 | """
4 | # load general packages and functions
5 | import os
6 | import numpy as np
7 | import rdkit
8 | import h5py
9 | from tqdm import tqdm
10 |
11 | # load GraphINVENT-specific functions
12 | from Analyzer import Analyzer
13 | from parameters.constants import constants
14 | import parameters.load as load
15 | from MolecularGraph import PreprocessingGraph
16 | import util
17 |
18 |
19 | class DataProcesser:
20 | """
21 | A class for preprocessing molecular sets and writing them to HDF files.
22 | """
23 | def __init__(self, path : str, is_training_set : bool=False, molset = []) -> None:
24 | """
25 | Args:
26 | ----
27 | path (string) : Full path/filename to SMILES file containing
28 | molecules.
29 | is_training_set (bool) : Indicates if this is the training set, as we
30 | calculate a few additional things for the training
31 | set.
32 | """
33 | # define some variables for later use
34 | self.path = path
35 | print("path defined", flush = True)
36 | self.is_training_set = is_training_set
37 | print("training set defined", flush = True)
38 | self.dataset_names = ["nodes", "edges", "APDs"]
39 | print("dataset names defined", flush = True)
40 | self.get_dataset_dims() # creates `self.dims`
41 | print("get dims done", flush = True)
42 |
43 | # load the molecules
44 | if constants.compute_train_csv:
45 | self.molecule_set = molset
46 | else:
47 | self.molecule_set = load.molecules(self.path)
48 | print("molecules loaded", flush = True)
49 | print(f"len of loaded molecules {len(self.molecule_set)}", flush=True)
50 | #self.molecule_set = molset
51 |
52 | # placeholders
53 | self.molecule_subset = None
54 | self.dataset = None
55 | self.skip_collection = None
56 | self.resume_idx = None
57 | self.ts_properties = None
58 | self.restart_index_file = None
59 | self.hdf_file = None
60 | self.dataset_size = None
61 |
62 | print("placeholders set", flush = True)
63 |
64 | # get total number of molecules, and total number of subgraphs in their
65 | # decoding routes
66 | self.n_molecules = len(self.molecule_set)
67 | print("got len", flush = True)
68 | self.total_n_subgraphs = self.get_n_subgraphs()
69 | print("get subgraphs done", flush = True)
70 | print(f"-- {self.n_molecules} molecules in set.", flush=True)
71 | print(f"-- {self.total_n_subgraphs} total subgraphs in set.",
72 | flush=True)
73 |
74 | def preprocess(self) -> None:
75 | """
76 | Prepares an HDF file to save three different datasets to it (`nodes`,
77 | `edges`, `APDs`), and slowly fills it in by looping over all the
78 | molecules in the data in groups (or "mini-batches").
79 | """
80 | with h5py.File(f"{self.path[:-3]}h5.chunked", "a") as self.hdf_file:
81 |
82 | self.restart_index_file = constants.dataset_dir + "index.restart"
83 |
84 | if constants.restart and os.path.exists(self.restart_index_file):
85 | self.restart_preprocessing_job()
86 | else:
87 | self.start_new_preprocessing_job()
88 |
89 | # keep track of the dataset size (to resize later)
90 | self.dataset_size = 0
91 |
92 | self.ts_properties = None
93 |
94 | # this is where we fill the datasets with actual data by looping
95 | # over subgraphs in blocks of size `constants.batch_size`
96 | for idx in tqdm(range(0, self.total_n_subgraphs, constants.batch_size)):
97 |
98 | if not self.skip_collection:
99 |
100 | # add `constants.batch_size` subgraphs from
101 | # `self.molecule_subset` to the dataset (and if training
102 | # set, calculate their properties and add these to
103 | # `self.ts_properties`)
104 | self.get_subgraphs(init_idx=idx)
105 |
106 | util.write_last_molecule_idx(
107 | last_molecule_idx=self.resume_idx,
108 | dataset_size=self.dataset_size,
109 | restart_file_path=constants.dataset_dir
110 | )
111 |
112 |
113 | if self.resume_idx == self.n_molecules:
114 | # all molecules have been processed
115 |
116 | self.resize_datasets() # remove padding from initialization
117 | print("Datasets resized.", flush=True)
118 |
119 | print(f"mol set length is {len(self.molecule_set)}",flush=True)
120 |
121 | if self.is_training_set and not constants.restart:
122 | print("Writing training set properties.", flush=True)
123 | util.write_ts_properties(
124 | training_set_properties=self.ts_properties
125 | )
126 |
127 | break
128 |
129 | print("* Resaving datasets in unchunked format.")
130 | self.resave_datasets_unchunked()
131 |
132 | def restart_preprocessing_job(self) -> None:
133 | """
134 | Restarts a preprocessing job. Uses an index specified in the dataset
135 | directory to know where to resume preprocessing.
136 | """
137 | try:
138 | self.resume_idx, self.dataset_size = util.read_last_molecule_idx(
139 | restart_file_path=constants.dataset_dir
140 | )
141 | except:
142 | self.resume_idx, self.dataset_size = 0, 0
143 | self.skip_collection = bool(
144 | self.resume_idx == self.n_molecules and self.is_training_set
145 | )
146 |
147 | # load dictionary of previously created datasets (`self.dataset`)
148 | self.load_datasets(hdf_file=self.hdf_file)
149 |
150 | def start_new_preprocessing_job(self) -> None:
151 | """
152 | Starts a fresh preprocessing job.
153 | """
154 | self.resume_idx = 0
155 | self.skip_collection = False
156 |
157 | # create a dictionary of empty HDF datasets (`self.dataset`)
158 | self.create_datasets(hdf_file=self.hdf_file)
159 |
160 | def resave_datasets_unchunked(self) -> None:
161 | """
162 | Resaves the HDF datasets in an unchunked format to remove initial
163 | padding.
164 | """
165 | with h5py.File(f"{self.path[:-3]}h5.chunked", "r", swmr=True) as chunked_file:
166 | keys = list(chunked_file.keys())
167 | data = [chunked_file.get(key)[:] for key in keys]
168 | data_zipped = tuple(zip(data, keys))
169 |
170 | with h5py.File(f"{self.path[:-3]}h5", "w") as unchunked_file:
171 | for d, k in tqdm(data_zipped):
172 | unchunked_file.create_dataset(
173 | k, chunks=None, data=d, dtype=np.dtype("int8")
174 | )
175 |
176 | # remove the restart file and chunked file (don't need them anymore)
177 | os.remove(self.restart_index_file)
178 | os.remove(f"{self.path[:-3]}h5.chunked")
179 |
180 | def get_subgraphs(self, init_idx = 0) -> None:
181 | """
182 | Adds `constants.batch_size` subgraphs from `self.molecule_subset` to the
183 | HDF dataset (and if currently processing the training set, also
184 | calculates the full graphs' properties and adds these to
185 | `self.ts_properties`).
186 |
187 | Args:
188 | ----
189 | init_idx (int) : As analysis is done in blocks/slices, `init_idx` is
190 | the start index for the next block/slice to be taken
191 | from `self.molecule_subset`.
192 | """
193 | data_subgraphs, data_apds, molecular_graph_list = [], [], [] # initialize
194 |
195 | # convert all molecules in `self.molecule_set` to `PreprocessingGraph`s
196 | molecular_graph_generator = map(self.get_graph, self.molecule_set)
197 |
198 | molecules_processed = 0 # keep track of the number of molecules processed
199 |
200 | # # loop over all the `PreprocessingGraph`s
201 | for graph in molecular_graph_generator:
202 | molecules_processed += 1
203 |
204 | # store `PreprocessingGraph` object
205 | molecular_graph_list.append(graph)
206 |
207 | # get the number of decoding graphs
208 | n_subgraphs = graph.get_decoding_route_length()
209 |
210 | for new_subgraph_idx in range(n_subgraphs):
211 |
212 | # `get_decoding_route_state()` returns a list of [`subgraph`, `apd`]
213 | subgraph, apd = graph.get_decoding_route_state(
214 | subgraph_idx=new_subgraph_idx
215 | )
216 |
217 | # "collect" all APDs corresponding to pre-existing subgraphs,
218 | # otherwise append both new subgraph and new APD
219 | count = 0
220 | for idx, existing_subgraph in enumerate(data_subgraphs):
221 |
222 | count += 1
223 | # check if subgraph `subgraph` is "already" in
224 | # `data_subgraphs` as `existing_subgraph`, and if so, add
225 | # the "new" APD to the "old"
226 | try: # first compare the node feature matrices
227 | nodes_equal = (subgraph[0] == existing_subgraph[0]).all()
228 | except AttributeError:
229 | nodes_equal = False
230 | try: # then compare the edge feature tensors
231 | edges_equal = (subgraph[1] == existing_subgraph[1]).all()
232 | except AttributeError:
233 | edges_equal = False
234 |
235 | # if both matrices have a match, then subgraphs are the same
236 | if nodes_equal and edges_equal:
237 | existing_apd = data_apds[idx]
238 | existing_apd += apd
239 | break
240 |
241 | # if subgraph is not already in `data_subgraphs`, append it
242 | if count == len(data_subgraphs) or count == 0:
243 | data_subgraphs.append(subgraph)
244 | data_apds.append(apd)
245 |
246 | # if `constants.batch_size` unique subgraphs have been
247 | # processed, save group to the HDF dataset
248 | len_data_subgraphs = len(data_subgraphs)
249 | if len_data_subgraphs == constants.batch_size:
250 | self.save_group(data_subgraphs=data_subgraphs,
251 | data_apds=data_apds,
252 | group_size=len_data_subgraphs,
253 | init_idx=init_idx)
254 |
255 | # get molecular properties for group iff it's the training set
256 | self.get_ts_properties(molecular_graphs=molecular_graph_list,
257 | group_size=constants.batch_size)
258 |
259 | # keep track of the last molecule to be processed in
260 | # `self.resume_idx`
261 | # number of molecules processed:
262 | self.resume_idx += molecules_processed
263 | # subgraphs processed:
264 | self.dataset_size += constants.batch_size
265 |
266 | return None
267 |
268 | n_processed_subgraphs = len(data_subgraphs)
269 |
270 | # save group with < `constants.batch_size` subgraphs (e.g. last block)
271 | self.save_group(data_subgraphs=data_subgraphs,
272 | data_apds=data_apds,
273 | group_size=n_processed_subgraphs,
274 | init_idx=init_idx)
275 |
276 | # get molecular properties for this group iff it's the training set
277 | self.get_ts_properties(molecular_graphs=molecular_graph_list,
278 | group_size=constants.batch_size)
279 |
280 | # # keep track of the last molecule to be processed in `self.resume_idx`
281 | # self.resume_idx += molecules_processed # number of molecules processed
282 | # self.dataset_size += molecules_processed # subgraphs processed
283 |
284 | return None
285 |
286 | def create_datasets(self, hdf_file : h5py._hl.files.File) -> None:
287 | """
288 | Creates a dictionary of HDF5 datasets (`self.dataset`).
289 |
290 | Args:
291 | ----
292 | hdf_file (h5py._hl.files.File) : HDF5 file which will contain the datasets.
293 | """
294 | self.dataset = {} # initialize
295 |
296 | for ds_name in self.dataset_names:
297 | self.dataset[ds_name] = hdf_file.create_dataset(
298 | ds_name,
299 | (self.total_n_subgraphs, *self.dims[ds_name]),
300 | chunks=True, # must be True for resizing later
301 | dtype=np.dtype("int8")
302 | )
303 |
304 | def resize_datasets(self) -> None:
305 | """
306 | Resizes the HDF datasets, since much longer datasets are initialized
307 | when first creating the HDF datasets (as it is impossible to predict
308 | how many graphs will be equivalent beforehand).
309 | """
310 | for dataset_name in self.dataset_names:
311 | try:
312 | self.dataset[dataset_name].resize(
313 | (self.dataset_size, *self.dims[dataset_name]))
314 | except KeyError: # `f_term` has no extra dims
315 | self.dataset[dataset_name].resize((self.dataset_size,))
316 |
317 | def get_dataset_dims(self) -> None:
318 | """
319 | Calculates the dimensions of the node features, edge features, and APDs,
320 | and stores them as lists in a dict (`self.dims`), where keys are the
321 | dataset name.
322 |
323 | Shapes:
324 | ------
325 | dims["nodes"] : [max N nodes, N atom types + N formal charges]
326 | dims["edges"] : [max N nodes, max N nodes, N bond types]
327 | dims["APDs"] : [APD length = f_add length + f_conn length + f_term length]
328 | """
329 | self.dims = {}
330 | self.dims["nodes"] = constants.dim_nodes
331 | self.dims["edges"] = constants.dim_edges
332 | self.dims["APDs"] = constants.dim_apd
333 |
334 | def get_graph(self, mol : rdkit.Chem.Mol) -> PreprocessingGraph:
335 | """
336 | Converts an `rdkit.Chem.Mol` object to `PreprocessingGraph`.
337 |
338 | Args:
339 | ----
340 | mol (rdkit.Chem.Mol) : Molecule to convert.
341 |
342 | Returns:
343 | -------
344 | molecular_graph (PreprocessingGraph) : Molecule, now as a graph.
345 | """
346 | if mol is not None:
347 | print(rdkit.Chem.MolToSmiles(mol),flush=True)
348 | if not constants.use_aromatic_bonds:
349 | rdkit.Chem.Kekulize(mol, clearAromaticFlags=True)
350 | molecular_graph = PreprocessingGraph(molecule=mol,
351 | constants=constants)
352 | return molecular_graph
353 |
354 | def get_molecule_subset(self) -> None:
355 | """
356 | Slices `self.molecule_set` into a subset of molecules of size
357 | `constants.batch_size`, starting from `self.resume_idx`.
358 | `self.n_molecules` is the number of molecules in the full
359 | `self.molecule_set`.
360 | """
361 | init_idx = self.resume_idx
362 | subset_size = constants.batch_size
363 | self.molecule_subset = []
364 | max_idx = min(init_idx + subset_size, self.n_molecules)
365 |
366 | count = -1
367 | for mol in self.molecule_set:
368 | if mol is not None:
369 | count += 1
370 | if count < init_idx:
371 | continue
372 | elif count >= max_idx:
373 | return self.molecule_subset
374 | else:
375 | self.molecule_subset.append(mol)
376 |
377 | def get_n_subgraphs(self) -> int:
378 | """
379 | Calculates the total number of subgraphs in the decoding routes of all
380 | molecules in `self.molecule_set`. First, the `PreprocessingGraph` for each
381 | molecule is obtained, and then the length of its decoding route is
382 | calculated.
383 |
384 | Returns:
385 | -------
386 | n_subgraphs (int) : Sum of number of subgraphs in decoding routes for
387 | all molecules in `self.molecule_set`.
388 | """
389 | n_subgraphs = 0 # start the count
390 | print("initialized n_subgraphs",flush=True)
391 |
392 | # convert molecules in `self.molecule_set` to `PreprocessingGraph`s
393 | molecular_graph_generator = map(self.get_graph, self.molecule_set)
394 | print("converted mol set to graph",flush=True)
395 |
396 | # loop over all the `PreprocessingGraph`s
397 | for molecular_graph in molecular_graph_generator:
398 | # get the number of decoding graphs (i.e. the decoding route length)
399 | # and add them to the running count
400 | n_subgraphs += molecular_graph.get_decoding_route_length()
401 | print("got route length",flush=True)
402 |
403 | print("done w graph loop",flush=True)
404 |
405 | return int(n_subgraphs)
406 |
407 | def get_ts_properties(self, molecular_graphs : list, group_size : int) -> \
408 | None:
409 | """
410 | Gets molecular properties for group of molecular graphs, only for the
411 | training set.
412 |
413 | Args:
414 | ----
415 | molecular_graphs (list) : Contains `PreprocessingGraph`s.
416 | group_size (int) : Size of "group" (i.e. slice of graphs).
417 | """
418 | if self.is_training_set:
419 | print("is training set so calculating training set props", flush=True)
420 |
421 | analyzer = Analyzer()
422 | ts_properties = analyzer.evaluate_training_set(
423 | preprocessing_graphs=molecular_graphs
424 | )
425 | print("initialized ts analyzer",flush=True)
426 |
427 | # merge properties of current group with the previous group analyzed
428 | if self.ts_properties: # `self.ts_properties` is a dictionary
429 | print("in self.ts_properties conditional", flush=True)
430 | if not constants.compute_train_csv:
431 | self.ts_properties = analyzer.combine_ts_properties(
432 | prev_properties=self.ts_properties,
433 | next_properties=ts_properties,
434 | weight_next=group_size
435 | )
436 | else:
437 | print("setting self.ts_properties", flush=True)
438 | self.ts_properties = ts_properties
439 | else: # `self.ts_properties` is None (has not been calculated yet)
440 | self.ts_properties = ts_properties
441 | else:
442 | self.ts_properties = None
443 | print(self.ts_properties)
444 |
445 |
446 | def load_datasets(self, hdf_file : h5py._hl.files.File) -> None:
447 | """
448 | Creates a dictionary of HDF datasets (`self.dataset`) which have been
449 | previously created (for restart jobs only).
450 |
451 | Args:
452 | ----
453 | hdf_file (h5py._hl.files.File) : HDF file containing all the datasets.
454 | """
455 | self.dataset = {} # initialize dictionary of datasets
456 |
457 | # use the names of the datasets as the keys in `self.dataset`
458 | for ds_name in self.dataset_names:
459 | self.dataset[ds_name] = hdf_file.get(ds_name)
460 |
461 | def save_group(self, data_subgraphs : list, data_apds : list,
462 | group_size : int, init_idx : int) -> None:
463 | """
464 | Saves a group of padded subgraphs and their corresponding APDs to the HDF
465 | datasets as `numpy.ndarray`s.
466 |
467 | Args:
468 | ----
469 | data_subgraphs (list) : Contains molecular subgraphs.
470 | data_apds (list) : Contains APDs.
471 | group_size (int) : Size of HDF "slice".
472 | init_idx (int) : Index to begin slicing.
473 | """
474 | # convert to `np.ndarray`s
475 | nodes = np.array([graph_tuple[0] for graph_tuple in data_subgraphs])
476 | edges = np.array([graph_tuple[1] for graph_tuple in data_subgraphs])
477 | apds = np.array(data_apds)
478 |
479 | end_idx = init_idx + group_size # idx to end slicing
480 |
481 | # once data is padded, save it to dataset slice
482 | # Broadcasting errors happen in the final group;
483 | # this skips them so processing can proceed
484 | try:
485 | self.dataset["nodes"][init_idx:end_idx] = nodes
486 | self.dataset["edges"][init_idx:end_idx] = edges
487 | self.dataset["APDs"][init_idx:end_idx] = apds
488 | except:
489 | print("\nBroadcasting error, skipping.\n")
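490 |
491 | # The block below is an added usage sketch (not part of the original module):
492 | # preprocessing relies on `parameters.constants` (batch size, dataset directory,
493 | # atom types, ...), so it assumes a job has already been configured, and the
494 | # example path is only a placeholder.
495 | if __name__ == "__main__":
496 |     processer = DataProcesser(path="data/pre-training/MOSES/train.smi",
497 |                               is_training_set=True)
498 |     processer.preprocess()  # writes train.h5 and the training-set properties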
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/ProtacScoringFunction.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains a function that scores a molecule's potential to degrade its target
3 | protein, using a pre-trained boosted-tree PROTAC activity model.
4 | """
5 | import rdkit
6 | from collections import namedtuple
7 | from rdkit import Chem
8 | from rdkit.Chem import AllChem, DataStructs
9 | from rdkit.Chem import rdchem
10 | import numpy as np
11 | import pickle
12 | from parameters import constants
13 | import requests as r
14 |
15 | def predictProteinDegradation(mol : rdkit.Chem.Mol, constants : namedtuple, cellType='SRD15', receptor='Q9Y616', e3Ligase='CRBN'):
16 | '''
17 | Returns 0 or 1 based on a PROTAC molecule's predicted protein degradation potential:
18 | 1 -> degrades
19 | 0 -> does not degrade
20 | '''
21 | try:
22 | scoring_model = constants.qsar_models["protac_qsar_model"]
23 | features = constants.activity_model_features
24 | Chem.SanitizeMol(mol)
25 |
26 | ngrams_array = np.zeros((1,7841), dtype=np.int8)
27 | # baseUrl="http://www.uniprot.org/uniprot/"
28 | # currentUrl=baseUrl+receptor+".fasta"
29 | # response = r.post(currentUrl)
30 | # cData=''.join(response.text)
31 | # i = cData.index('\n')+1
32 | # seq = cData[i:].strip().lower()
33 | seq = "MAGNCGARGALSAHTLLFDLPPALLGELCAVLDSCDGALGWRGLAERLSSSWLDVRHIEKYVDQGKSGTRELLWSWAQKNKTIGDLLQVLQEMGHRRAIHLITNYGAVLSPSEKSYQEGGFPNILFKETANVTVDNVLIPEHNEKGILLKSSISFQNIIEGTRNFHKDFLIGEGEIFEVYRVEIQNLTYAVKLFKQEKKMQCKKHWKRFLSELEVLLLFHHPNILELAAYFTETEKFCLIYPYMRNGTLFDRLQCVGDTAPLPWHIRIGILIGISKAIHYLHNVQPCSVICGSISSANILLDDQFQPKLTDFAMAHFRSHLEHQSCTINMTSSSSKHLWYMPEEYIRQGKLSIKTDVYSFGIVIMEVLTGCRVVLDDPKHIQLRDLLRELMEKRGLDSCLSFLDKKVPPCPRNFSAKLFCLAGRCAATRAKLRPSMDEVLNTLESTQASLYFAEDPPTSLKSFRCPSPLFLENVPSIPVEDDESQNNNLLPSDEGLRIDRMTQKTPFECSQSEVMFLSLDKKPESKRNEEACNMPSSSCEESWFPKYIVPSQDLRPYKVNIDPSSEAPGHSCRSRPVESSCSSKFSWDEYEQYKKE".lower()
34 | ngrams = features[1237:]
35 | for i in range(len(ngrams)):
36 | n = seq.count(ngrams[i])
37 | ngrams_array[0][i] = n
38 |
39 | fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
40 | fp_array = np.zeros((0,), dtype=np.int8)
41 | DataStructs.ConvertToNumpyArray(fingerprint, fp_array)
42 |
43 | ct_ind = features.index("ct_"+cellType)
44 | e3_ind = features.index("e3_"+e3Ligase)
45 |
46 | model_input = list(0 for i in range(5)) + list(fp_array) + list(0 for i in range(207)) + list(ngrams_array[0])
47 | model_input[ct_ind] = 1
48 | model_input[e3_ind] = 1
49 |
50 | output = scoring_model.predict([model_input])
51 | smi = Chem.MolToSmiles(mol)
52 | letters = "".join(c for c in smi if c not in "=()-")  # strip bond/branch characters before the length check
53 | if output[0][0]-output[0][1] < 0 and len(letters)>=30:
54 | return 1
55 | else:
56 | return 0
57 | except Exception as e:
58 | print('EXCEPTION IN PREDICT PROTEIN DEGRADATION',flush=True)
59 | print(e, flush=True)
60 | return 0
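61 |
62 | # Added usage sketch (not part of the original module): assuming `job_constants`
63 | # is the namedtuple of job parameters (with the pre-trained model under
64 | # `qsar_models["protac_qsar_model"]` and feature names under
65 | # `activity_model_features`), a single molecule could be scored as:
66 | #
67 | #     mol = Chem.MolFromSmiles(protac_smiles)  # `protac_smiles` is hypothetical
68 | #     degrades = predictProteinDegradation(mol, job_constants,
69 | #                                          cellType="SRD15", receptor="Q9Y616",
70 | #                                          e3Ligase="CRBN")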
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/ScoringFunction.py:
--------------------------------------------------------------------------------
1 | """
2 | This class is used for defining the scoring function(s) which can be used during
3 | fine-tuning.
4 | """
5 | # load general packages and functions
6 | from collections import namedtuple
7 | import torch
8 | from rdkit import DataStructs
9 | from rdkit import Chem
10 | from rdkit.Chem import QED, AllChem
11 | import numpy as np
12 | import sklearn
13 | from sklearn import svm
14 | from ProtacScoringFunction import *
15 |
16 | class ScoringFunction:
17 | """
18 | A class for defining the scoring function components.
19 | """
20 | def __init__(self, constants : namedtuple) -> None:
21 | """
22 | Args:
23 | ----
24 | constants (namedtuple) : Contains job parameters as well as global
25 | constants.
26 | """
27 | self.score_components = constants.score_components # list
28 | self.score_type = constants.score_type # list
29 | self.qsar_models = constants.qsar_models # dict
30 | self.device = constants.device
31 | self.max_n_nodes = constants.max_n_nodes
32 | self.score_thresholds = constants.score_thresholds
33 |
34 | self.n_graphs = None # placeholder
35 | self.constants = constants
36 |
37 | assert len(self.score_components) == len(self.score_thresholds), \
38 | "`score_components` and `score_thresholds` do not match."
39 |
40 | def compute_score(self, graphs : list, termination : torch.Tensor,
41 | validity : torch.Tensor, uniqueness : torch.Tensor) -> \
42 | torch.Tensor:
43 | """
44 | Computes the overall score for the input molecular graphs.
45 |
46 | Args:
47 | ----
48 | graphs (list) : Contains molecular graphs to evaluate.
49 | termination (torch.Tensor) : Termination status of input molecular
50 | graphs.
51 | validity (torch.Tensor) : Validity of input molecular graphs.
52 | uniqueness (torch.Tensor) : Uniqueness of input molecular graphs.
53 |
54 | Returns:
55 | -------
56 | final_score (torch.Tensor) : The final scores for each input graph.
57 | """
58 | self.n_graphs = len(graphs)
59 | contributions_to_score = self.get_contributions_to_score(graphs=graphs)
60 |
61 | print(f"contributions_to_score {contributions_to_score}")
62 |
63 | if len(self.score_components) == 1:
64 | print('len is 1')
65 | final_score = contributions_to_score[0]
66 |
67 | elif self.score_type == "continuous":
68 | print('in continuous')
69 | final_score = contributions_to_score[0]
70 | print(f"structure score tensor {contributions_to_score[1]}")
71 | for component in contributions_to_score[1:]:
72 | final_score *= component
73 |
74 | elif self.score_type == "binary":
75 | print('in binary')
76 | component_masks = []
77 | for idx, score_component in enumerate(contributions_to_score):
78 | component_mask = torch.where(
79 | score_component > self.score_thresholds[idx],
80 | torch.ones(self.n_graphs, device=self.device, dtype=torch.uint8),
81 | torch.zeros(self.n_graphs, device=self.device, dtype=torch.uint8)
82 | )
83 | component_masks.append(component_mask)
84 |
85 | final_score = component_masks[0]
86 | for mask in component_masks[1:]:
87 | final_score *= mask
88 | final_score = final_score.float()
89 |
90 | else:
91 | raise NotImplementedError
92 |
93 | print(f"final score before {final_score}")
94 |
95 | # remove contribution of duplicate molecules to the score
96 | final_score *= uniqueness
97 | print(f"uniqueness score {uniqueness}")
98 |
99 | # remove contribution of invalid molecules to the score
100 | final_score *= validity
101 | print(f"validity score {validity}")
102 |
103 | # remove contribution of improperly-terminated molecules to the score
104 | final_score *= termination
105 | print(f"termination score {termination}")
106 |
107 | print(f"final score after {final_score}")
108 |
109 | return final_score
110 |
111 | def get_contributions_to_score(self, graphs : list) -> list:
112 | """
113 | Returns the different elements of the score.
114 |
115 | Args:
116 | ----
117 | graphs (list) : Contains molecular graphs to evaluate.
118 |
119 | Returns:
120 | -------
121 | contributions_to_score (list) : Contains elements of the score due to
122 | each scoring function component.
123 | """
124 | contributions_to_score = []
125 |
126 | for score_component in self.score_components:
127 | if "target_size" in score_component:
128 |
129 | target_size = int(score_component[12:])
130 |
131 | assert target_size <= self.max_n_nodes, \
132 | "Target size > largest possible size (`max_n_nodes`)."
133 | assert 0 < target_size, "Target size must be greater than 0."
134 |
135 | target_size *= torch.ones(self.n_graphs, device=self.device)
136 | n_nodes = torch.tensor([graph.n_nodes for graph in graphs],
137 | device=self.device)
138 | max_nodes = self.max_n_nodes
139 | score = (
140 | torch.ones(self.n_graphs, device=self.device)
141 | - torch.abs(n_nodes - target_size)
142 | / (max_nodes - target_size)
143 | )
144 |
145 | contributions_to_score.append(score)
146 |
147 | elif score_component == "QED":
148 | mols = [graph.molecule for graph in graphs]
149 |
150 | # compute the QED score for each molecule (if possible)
151 | qed = []
152 | for mol in mols:
153 | try:
154 | qed.append(QED.qed(mol))
155 | except:
156 | qed.append(0.0)
157 | score = torch.tensor(qed, device=self.device)
158 |
159 | contributions_to_score.append(score)
160 |
161 |
162 | elif "protac_activity" in score_component:
163 | mols = [graph.molecule for graph in graphs]
164 |
165 | # `score_component` has to be the key to the QSAR model in the
166 | # `self.qsar_models` dict
167 | #qsar_model = self.qsar_models[score_component]
168 | #score = self.compute_activity(mols, qsar_model)
169 | score = self.compute_protac_activity(mols)
170 |
171 | print(f"activity scores {score}")
172 |
173 | contributions_to_score.append(score)
174 |
175 | elif "structure" in score_component:
176 | mols = [graph.molecule for graph in graphs]
177 | score = self.compute_structure(mols)
178 | print(f"structure score {score}")
179 | contributions_to_score.append(score)
180 |
181 | # elif "solubility" in score_component:
182 | # mols = [graph.molecule for graph in graphs]
183 | # score = self.compute_solubility(mols)
184 | # print(f"structure score {score}")
185 | # contributions_to_score.append(score)
186 |
187 | elif "activity" in score_component:
188 | mols = [graph.molecule for graph in graphs]
189 |
190 | # `score_component` has to be the key to the QSAR model in the
191 | # `self.qsar_models` dict
192 | qsar_model = self.qsar_models[score_component]
193 | score = self.compute_activity(mols, qsar_model)
194 |
195 | print(f"activity scores {score}")
196 |
197 | contributions_to_score.append(score)
198 |
199 | else:
200 | raise NotImplementedError("The score component is not defined. "
201 | "You can define it in "
202 | "`ScoringFunction.py`.")
203 |
204 | return contributions_to_score
205 |
206 | def compute_protac_activity(self, mols : list) -> list:
207 | """
208 | Args:
209 | ----
210 | mols (list) : Contains `rdkit.Mol` objects corresponding to molecular
211 | graphs sampled.
212 |
213 | Returns:
214 | -------
215 | activity (list) : Contains predicted protac activity for input molecules.
216 | """
217 |
218 | n_mols = len(mols)
219 | activity = torch.zeros(n_mols, device=self.device)
220 |
221 | for idx, mol in enumerate(mols):
222 | score = predictProteinDegradation(mol, self.constants)
223 | activity[idx] = score
224 |
225 | print(f"shape of activity tensor is {activity.shape}")
226 | print(f"sum of activity tensor is {sum(activity.tolist())}")
227 |
228 | return activity
229 |
230 | def compute_structure(self, mols : list) -> list:
231 | """
232 | Args:
233 | ----
234 | mols (list) : Contains `rdkit.Mol` objects corresponding to molecular
235 | graphs sampled.
236 |
237 | Returns:
238 | -------
239 | activity (list) : Contains structure predictions for input molecules (0 for non-PROTAC, 1 for PROTAC).
240 | """
241 |
242 | n_mols = len(mols)
243 | activity = torch.zeros(n_mols, device=self.device)
244 |
245 | for idx, mol in enumerate(mols):
246 | score = predictStructure(mol, self.constants)
247 | activity[idx] = score
248 |
249 | return activity
250 |
251 |
252 |
253 | def compute_activity(self, mols : list,
254 | activity_model : sklearn.svm.SVC) -> list:
255 | """
256 | Note: this function may have to be tuned/replicated depending on how
257 | the activity model is saved.
258 |
259 | Args:
260 | ----
261 | mols (list) : Contains `rdkit.Mol` objects corresponding to molecular
262 | graphs sampled.
263 | activity_model (sklearn.svm.classes.SVC) : Pre-trained QSAR model.
264 |
265 | Returns:
266 | -------
267 | activity (list) : Contains predicted activities for input molecules.
268 | """
269 | n_mols = len(mols)
270 | activity = torch.zeros(n_mols, device=self.device)
271 |
272 | for idx, mol in enumerate(mols):
273 | try:
274 | fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol,
275 | 2,
276 | nBits=2048)
277 | ecfp4 = np.zeros((2048,))
278 | DataStructs.ConvertToNumpyArray(fingerprint, ecfp4)
279 | activity[idx] = activity_model.predict_proba([ecfp4])[0][1]
280 | except:
281 | pass # activity[idx] will remain 0.0
282 |
283 | return activity
284 |
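285 | # Added usage sketch (not part of the original module): `ScoringFunction` is
286 | # driven entirely by the job constants (score_components, score_type,
287 | # score_thresholds, qsar_models, device, max_n_nodes). Given sampled graphs and
288 | # their termination/validity/uniqueness tensors (all hypothetical names below):
289 | #
290 | #     scoring_function = ScoringFunction(constants=job_constants)
291 | #     scores = scoring_function.compute_score(graphs=sampled_graphs,
292 | #                                             termination=termination,
293 | #                                             validity=validity,
294 | #                                             uniqueness=uniqueness)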
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/GraphINVENT_Protac/graphinvent/__init__.py
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/data/fine-tuning/README.md:
--------------------------------------------------------------------------------
1 |
2 | Place the features.pkl and surrogate_model.pkl files here.
3 |
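4 | A quick way to check that both pickles load (a sketch; run from this folder):
5 |
6 | ```python
7 | import pickle
8 |
9 | with open("features.pkl", "rb") as f:
10 |     features = pickle.load(f)          # feature-name list used by the surrogate model
11 | with open("surrogate_model.pkl", "rb") as f:
12 |     surrogate_model = pickle.load(f)   # pre-trained boosted-tree activity model
13 | print(len(features), type(surrogate_model))
14 | ```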
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/data/pre-training/MOSES/README.md:
--------------------------------------------------------------------------------
1 |
2 | Put the following files in this folder:
3 |
4 | - train.smi
5 | - valid.smi
6 | - test.smi
7 |
8 | Once you've generated train.csv using training_stats.py, put that here as well.
9 |
10 | After main.py has been run with job_type = 'preprocess' (set in defaults.py), the following files should be present in this directory:
11 |
12 | - train.h5
13 | - valid.h5
14 | - test.h5
15 |
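16 | A quick way to sanity-check the preprocessed HDF files (a sketch; run from this folder after preprocessing):
17 |
18 | ```python
19 | import h5py
20 |
21 | for split in ("train", "valid", "test"):
22 |     with h5py.File(f"{split}.h5", "r") as hdf_file:
23 |         # each file holds the padded "nodes", "edges", and "APDs" datasets
24 |         print(split, {key: hdf_file[key].shape for key in ("nodes", "edges", "APDs")})
25 | ```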
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/data/protac/README.md:
--------------------------------------------------------------------------------
1 |
2 | Download protac.csv from http://cadd.zju.edu.cn/protacdb/statics/binaryDownload/csv/protac/protac.csv and save it here under the name data.csv.
3 |
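4 | For example, the download can be scripted (a sketch, assuming `requests` is installed):
5 |
6 | ```python
7 | import requests
8 |
9 | url = ("http://cadd.zju.edu.cn/protacdb/statics/"
10 |        "binaryDownload/csv/protac/protac.csv")
11 | with open("data.csv", "wb") as f:
12 |     f.write(requests.get(url).content)
13 | ```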
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/gnn/aggregation_mpnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the `AggregationMPNN` class.
3 | """
4 | # load general packages and functions
5 | from collections import namedtuple
6 | import torch
7 |
8 |
9 | class AggregationMPNN(torch.nn.Module):
10 | """
11 | Abstract `AggregationMPNN` class. Specific models using this class are
12 | defined in `mpnn.py`; these are the attention networks AttS2V and AttGGNN.
13 | """
14 | def __init__(self, constants : namedtuple) -> None:
15 | super().__init__()
16 |
17 | self.hidden_node_features = constants.hidden_node_features
18 | self.edge_features = constants.n_edge_features
19 | self.message_size = constants.message_size
20 | self.message_passes = constants.message_passes
21 | self.constants = constants
22 |
23 | def aggregate_message(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
24 | edges : torch.Tensor, mask : torch.Tensor) -> None:
25 | """
26 | Message aggregation function, to be implemented in all `AggregationMPNN` subclasses.
27 |
28 | Args:
29 | ----
30 | nodes (torch.Tensor) : Batch of node feature vectors.
31 | node_neighbours (torch.Tensor) : Batch of node feature vectors for neighbors.
32 | edges (torch.Tensor) : Batch of edge feature vectors.
33 | mask (torch.Tensor) : Mask for non-existing neighbors, where
34 | elements are 1 if corresponding element
35 | exists and 0 otherwise.
36 |
37 | Shapes:
38 | ------
39 | nodes : (total N nodes in batch, N node features)
40 | node_neighbours : (total N nodes in batch, max node degree, N node features)
41 | edges : (total N nodes in batch, max node degree, N edge features)
42 | mask : (total N nodes in batch, max node degree)
43 | """
44 | raise NotImplementedError
45 |
46 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> None:
47 | """
48 | Message update function, to be implemented in all `AggregationMPNN` subclasses.
49 |
50 | Args:
51 | ----
52 | nodes (torch.Tensor) : Batch of node feature vectors.
53 | messages (torch.Tensor) : Batch of incoming messages.
54 |
55 | Shapes:
56 | ------
57 | nodes : (total N nodes in batch, N node features)
58 | messages : (total N nodes in batch, N node features)
59 | """
60 | raise NotImplementedError
61 |
62 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
63 | node_mask : torch.Tensor) -> None:
64 | """
65 | Local readout function, to be implemented in all `AggregationMPNN` subclasses.
66 |
67 | Args:
68 | ----
69 | hidden_nodes (torch.Tensor) : Batch of node feature vectors.
70 | input_nodes (torch.Tensor) : Batch of node feature vectors.
71 | node_mask (torch.Tensor) : Mask for non-existing neighbors, where
72 | elements are 1 if corresponding element
73 | exists and 0 otherwise.
74 |
75 | Shapes:
76 | ------
77 | hidden_nodes : (total N nodes in batch, N node features)
78 | input_nodes : (total N nodes in batch, N node features)
79 | node_mask : (total N nodes in batch, N features)
80 | """
81 | raise NotImplementedError
82 |
83 | def forward(self, nodes : torch.Tensor, edges : torch.Tensor) -> torch.Tensor:
84 | """
85 | Defines forward pass.
86 |
87 | Args:
88 | ----
89 | nodes (torch.Tensor) : Batch of node feature matrices.
90 | edges (torch.Tensor) : Batch of edge feature tensors.
91 |
92 | Shapes:
93 | ------
94 | nodes : (batch size, N nodes, N node features)
95 | edges : (batch size, N nodes, N nodes, N edge features)
96 |
97 | Returns:
98 | -------
99 | output (torch.Tensor) : This would normally be the learned graph
100 | representation, but in all MPNN readout functions
101 | in this work, the last layer is used to predict
102 | the action probability distribution for a batch
103 | of graphs from the learned graph representation.
104 | """
105 | adjacency = torch.sum(edges, dim=3)
106 |
107 | # **note: "idc" == "indices", "nghb{s}" == "neighbour(s)"
108 | edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc = \
109 | adjacency.nonzero(as_tuple=True)
110 |
111 | node_batch_batch_idc, node_batch_node_idc = adjacency.sum(-1).nonzero(as_tuple=True)
112 | node_batch_adj = adjacency[node_batch_batch_idc, node_batch_node_idc, :]
113 | node_batch_size = node_batch_batch_idc.shape[0]
114 | node_degrees = node_batch_adj.sum(-1).long()
115 | max_node_degree = node_degrees.max()
116 |
117 | node_batch_node_nghbs = torch.zeros(node_batch_size,
118 | max_node_degree,
119 | self.hidden_node_features,
120 | device=self.constants.device)
121 | node_batch_edges = torch.zeros(node_batch_size,
122 | max_node_degree,
123 | self.edge_features,
124 | device=self.constants.device)
125 |
126 | node_batch_nghb_nghb_idc = torch.cat(
127 | [torch.arange(i) for i in node_degrees]
128 | ).long()
129 |
130 | edge_batch_node_batch_idc = torch.cat(
131 | [i * torch.ones(degree) for i, degree in enumerate(node_degrees)]
132 | ).long()
133 |
134 | node_batch_node_nghb_mask = torch.zeros(node_batch_size,
135 | max_node_degree,
136 | device=self.constants.device)
137 |
138 | node_batch_node_nghb_mask[edge_batch_node_batch_idc, node_batch_nghb_nghb_idc] = 1
139 |
140 | node_batch_edges[edge_batch_node_batch_idc, node_batch_nghb_nghb_idc, :] = \
141 | edges[edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc, :]
142 |
143 | # pad up the hidden nodes
144 | hidden_nodes = torch.zeros(nodes.shape[0],
145 | nodes.shape[1],
146 | self.hidden_node_features,
147 | device=self.constants.device)
148 | hidden_nodes[:nodes.shape[0], :nodes.shape[1], :nodes.shape[2]] = nodes.clone()
149 |
150 | for _ in range(self.message_passes):
151 |
152 | node_batch_nodes = hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :]
153 | node_batch_node_nghbs[edge_batch_node_batch_idc, node_batch_nghb_nghb_idc, :] = \
154 | hidden_nodes[edge_batch_batch_idc, edge_batch_nghb_idc, :]
155 |
156 | messages = self.aggregate_message(nodes=node_batch_nodes,
157 | node_neighbours=node_batch_node_nghbs.clone(),
158 | edges=node_batch_edges,
159 | mask=node_batch_node_nghb_mask)
160 |
161 | hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] = \
162 | self.update(node_batch_nodes.clone(), messages)
163 |
164 | node_mask = (adjacency.sum(-1) != 0)
165 |
166 | output = self.readout(hidden_nodes, nodes, node_mask)
167 |
168 | return output
169 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/gnn/edge_mpnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the `EdgeMPNN` class.
3 | """# load general packages and functions
4 | from collections import namedtuple
5 | import torch
6 |
7 |
8 | class EdgeMPNN(torch.nn.Module):
9 | """
10 | Abstract `EdgeMPNN` class. A specific model using this class is defined
11 | in `mpnn.py`; this is the EMN.
12 | """
13 | def __init__(self, constants : namedtuple) -> None:
14 | super().__init__()
15 |
16 | self.edge_features = constants.edge_features
17 | self.edge_embedding_size = constants.edge_embedding_size
18 | self.message_passes = constants.message_passes
19 | self.n_nodes_largest_graph = constants.max_n_nodes
20 | self.constants = constants
21 |
22 | def preprocess_edges(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
23 | edges : torch.Tensor) -> None:
24 | """
25 | Edge preprocessing step, to be implemented in all `EdgeMPNN` subclasses.
26 |
27 | Args:
28 | ----
29 | nodes (torch.Tensor) : Batch of node feature vectors.
30 | node_neighbours (torch.Tensor) : Batch of node feature vectors for neighbors.
31 | edges (torch.Tensor) : Batch of edge feature vectors.
32 |
33 | Shapes:
34 | ------
35 | nodes : (total N nodes in batch, N node features)
36 | node_neighbours : (total N nodes in batch, max node degree, N node features)
37 | edges : (total N nodes in batch, max node degree, N edge features)
38 | """
39 | raise NotImplementedError
40 |
41 | def propagate_edges(self, edges : torch.Tensor, ingoing_edge_memories : torch.Tensor,
42 | ingoing_edges_mask : torch.Tensor) -> None:
43 | """
44 | Edge propagation rule, to be implemented in all `EdgeMPNN` subclasses.
45 |
46 | Args:
47 | ----
48 | edges (torch.Tensor) : Batch of edge feature tensors.
49 | ingoing_edge_memories (torch.Tensor) : Batch of memories for all
50 | ingoing edges.
51 | ingoing_edges_mask (torch.Tensor) : Mask for ingoing edges.
52 |
53 | Shapes:
54 | ------
55 | edges : (batch size, N nodes, N nodes, total N edge features)
56 | ingoing_edge_memories : (total N edges in batch, total N edge features)
57 | ingoing_edges_mask : (total N edges in batch, max node degree, total N edge features)
58 | """
59 | raise NotImplementedError
60 |
61 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
62 | node_mask : torch.Tensor) -> None:
63 | """
64 | Local readout function, to be implemented in all `EdgeMPNN` subclasses.
65 |
66 | Args:
67 | ----
68 | hidden_nodes (torch.Tensor) : Batch of node feature vectors.
69 | input_nodes (torch.Tensor) : Batch of node feature vectors.
70 | node_mask (torch.Tensor) : Mask for non-existing neighbors, where
71 | elements are 1 if corresponding element
72 | exists and 0 otherwise.
73 |
74 | Shapes:
75 | ------
76 | hidden_nodes : (total N nodes in batch, N node features)
77 | input_nodes : (total N nodes in batch, N node features)
78 | node_mask : (total N nodes in batch, N features)
79 | """
80 | raise NotImplementedError
81 |
82 | def forward(self, nodes : torch.Tensor, edges : torch.Tensor) -> torch.Tensor:
83 | """
84 | Defines forward pass.
85 |
86 | Args:
87 | ----
88 | nodes (torch.Tensor) : Batch of node feature matrices.
89 | edges (torch.Tensor) : Batch of edge feature tensors.
90 |
91 | Shapes:
92 | ------
93 | nodes : (batch size, N nodes, N node features)
94 | edges : (batch size, N nodes, N nodes, N edge features)
95 |
96 | Returns:
97 | -------
98 | output (torch.Tensor) : This would normally be the learned graph representation,
99 | but in all MPNN readout functions in this work,
100 | the last layer is used to predict the action
101 | probability distribution for a batch of graphs from
102 | the learned graph representation.
103 | """
104 | adjacency = torch.sum(edges, dim=3)
105 |
106 | # indices for finding edges in batch; `edges_b_idx` is batch index,
107 | # `edges_n_idx` is the node index, and `edges_nghb_idx` is the index
108 | # that each node in `edges_n_idx` is bound to
109 | edges_b_idx, edges_n_idx, edges_nghb_idx = adjacency.nonzero(as_tuple=True)
110 |
111 | n_edges = edges_n_idx.shape[0]
112 | adj_of_edge_batch_idc = adjacency.clone().long()
113 |
114 | # +1 to distinguish idx 0 from empty elements, subtracted few lines down
115 | r = torch.arange(1, n_edges + 1, device=self.constants.device)
116 |
117 | adj_of_edge_batch_idc[edges_b_idx, edges_n_idx, edges_nghb_idx] = r
118 |
119 | ingoing_edges_eb_idx = (
120 | torch.cat([row[row.nonzero()] for row in
121 | adj_of_edge_batch_idc[edges_b_idx, edges_nghb_idx, :]]) - 1
122 | ).squeeze()
123 |
124 | edge_degrees = adjacency[edges_b_idx, edges_nghb_idx, :].sum(-1).long()
125 | ingoing_edges_igeb_idx = torch.cat(
126 | [i * torch.ones(d) for i, d in enumerate(edge_degrees)]
127 | ).long()
128 | ingoing_edges_ige_idx = torch.cat([torch.arange(i) for i in edge_degrees]).long()
129 |
130 |
131 | batch_size = adjacency.shape[0]
132 | n_nodes = adjacency.shape[1]
133 | max_node_degree = adjacency.sum(-1).max().int()
134 | edge_memories = torch.zeros(n_edges,
135 | self.edge_embedding_size,
136 | device=self.constants.device)
137 |
138 | ingoing_edge_memories = torch.zeros(n_edges, max_node_degree,
139 | self.edge_embedding_size,
140 | device=self.constants.device)
141 | ingoing_edges_mask = torch.zeros(n_edges,
142 | max_node_degree,
143 | device=self.constants.device)
144 |
145 | edge_batch_nodes = nodes[edges_b_idx, edges_n_idx, :]
146 | # **note: "nghb{s}" == "neighbour(s)"
147 | edge_batch_nghbs = nodes[edges_b_idx, edges_nghb_idx, :]
148 | edge_batch_edges = edges[edges_b_idx, edges_n_idx, edges_nghb_idx, :]
149 | edge_batch_edges = self.preprocess_edges(nodes=edge_batch_nodes,
150 | node_neighbours=edge_batch_nghbs,
151 | edges=edge_batch_edges)
152 |
153 | # remove h_ji:s influence on h_ij
154 | ingoing_edges_nghb_idx = edges_nghb_idx[ingoing_edges_eb_idx]
155 | ingoing_edges_receiving_edge_n_idx = edges_n_idx[ingoing_edges_igeb_idx]
156 | diff_idx = (ingoing_edges_receiving_edge_n_idx != ingoing_edges_nghb_idx).nonzero()
157 |
158 | try:
159 | ingoing_edges_eb_idx = ingoing_edges_eb_idx[diff_idx].squeeze()
160 | ingoing_edges_ige_idx = ingoing_edges_ige_idx[diff_idx].squeeze()
161 | ingoing_edges_igeb_idx = ingoing_edges_igeb_idx[diff_idx].squeeze()
162 | except:
163 | pass
164 |
165 | ingoing_edges_mask[ingoing_edges_igeb_idx, ingoing_edges_ige_idx] = 1
166 |
167 | for _ in range(self.message_passes):
168 | ingoing_edge_memories[ingoing_edges_igeb_idx, ingoing_edges_ige_idx, :] = \
169 | edge_memories[ingoing_edges_eb_idx, :]
170 | edge_memories = self.propagate_edges(
171 | edges=edge_batch_edges,
172 | ingoing_edge_memories=ingoing_edge_memories.clone(),
173 | ingoing_edges_mask=ingoing_edges_mask
174 | )
175 |
176 | node_mask = (adjacency.sum(-1) != 0)
177 |
178 | node_sets = torch.zeros(batch_size,
179 | n_nodes,
180 | max_node_degree,
181 | self.edge_embedding_size,
182 | device=self.constants.device)
183 |
184 | edge_batch_edge_memory_idc = torch.cat(
185 | [torch.arange(row.sum()) for row in adjacency.view(-1, n_nodes)]
186 | ).long()
187 |
188 | node_sets[edges_b_idx, edges_n_idx, edge_batch_edge_memory_idc, :] = edge_memories
189 | graph_sets = node_sets.sum(2)
190 |
191 | output = self.readout(graph_sets, graph_sets, node_mask)
192 | return output
193 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/gnn/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines MPNN modules and readout functions, and APD readout functions.
3 | """
4 | # load general packages and functions
5 | from collections import namedtuple
6 | import torch
7 |
8 | # load GraphINVENT-specific functions
9 | # (none)
10 |
11 |
12 | class GraphGather(torch.nn.Module):
13 | """
14 | GGNN readout function.
15 | """
16 | def __init__(self, node_features : int, hidden_node_features : int,
17 | out_features : int, att_depth : int, att_hidden_dim : int,
18 | att_dropout_p : float, emb_depth : int, emb_hidden_dim : int,
19 | emb_dropout_p : float, big_positive : float) -> None:
20 |
21 | super().__init__()
22 |
23 | self.big_positive = big_positive
24 |
25 | self.att_nn = MLP(
26 | in_features=node_features + hidden_node_features,
27 | hidden_layer_sizes=[att_hidden_dim] * att_depth,
28 | out_features=out_features,
29 | dropout_p=att_dropout_p
30 | )
31 |
32 | self.emb_nn = MLP(
33 | in_features=hidden_node_features,
34 | hidden_layer_sizes=[emb_hidden_dim] * emb_depth,
35 | out_features=out_features,
36 | dropout_p=emb_dropout_p
37 | )
38 |
39 | def forward(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
40 | node_mask : torch.Tensor) -> torch.Tensor:
41 | """
42 | Defines forward pass.
43 | """
44 | Softmax = torch.nn.Softmax(dim=1)
45 |
46 | cat = torch.cat((hidden_nodes, input_nodes), dim=2)
47 | energy_mask = (node_mask == 0).float() * self.big_positive
48 | energies = self.att_nn(cat) - energy_mask.unsqueeze(-1)
49 | attention = Softmax(energies)
50 | embedding = self.emb_nn(hidden_nodes)
51 |
52 | return torch.sum(attention * embedding, dim=1)
53 |
54 |
55 | class Set2Vec(torch.nn.Module):
56 | """
57 | S2V readout function.
58 | """
59 | def __init__(self, node_features : int, hidden_node_features : int,
60 | lstm_computations : int, memory_size : int,
61 | constants : namedtuple) -> None:
62 |
63 | super().__init__()
64 |
65 | self.constants = constants
66 | self.lstm_computations = lstm_computations
67 | self.memory_size = memory_size
68 |
69 | self.embedding_matrix = torch.nn.Linear(
70 | in_features=node_features + hidden_node_features,
71 | out_features=self.memory_size,
72 | bias=True
73 | )
74 |
75 | self.lstm = torch.nn.LSTMCell(
76 | input_size=self.memory_size,
77 | hidden_size=self.memory_size,
78 | bias=True
79 | )
80 |
81 | def forward(self, hidden_output_nodes : torch.Tensor, input_nodes : torch.Tensor,
82 | node_mask : torch.Tensor) -> torch.Tensor:
83 | """
84 | Defines forward pass.
85 | """
86 | Softmax = torch.nn.Softmax(dim=1)
87 |
88 | batch_size = input_nodes.shape[0]
89 | energy_mask = torch.bitwise_not(node_mask).float() * self.constants.big_negative
90 | lstm_input = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
91 | cat = torch.cat((hidden_output_nodes, input_nodes), dim=2)
92 | memory = self.embedding_matrix(cat)
93 | hidden_state = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
94 | cell_state = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
95 |
96 | for _ in range(self.lstm_computations):
97 | query, cell_state = self.lstm(lstm_input, (hidden_state, cell_state))
98 |
99 | # dot product query x memory
100 | energies = (query.view(batch_size, 1, self.memory_size) * memory).sum(dim=-1)
101 | attention = Softmax(energies + energy_mask)
102 | read = (attention.unsqueeze(-1) * memory).sum(dim=1)
103 |
104 | hidden_state = query
105 | lstm_input = read
106 |
107 | cat = torch.cat((query, read), dim=1)
108 | return cat
109 |
110 |
111 | class MLP(torch.nn.Module):
112 | """
113 | Multi-layer perceptron. Applies SELU after every linear layer.
114 |
115 | Args:
116 | ----
117 | in_features (int) : Size of each input sample.
118 | hidden_layer_sizes (list) : Hidden layer sizes.
119 | out_features (int) : Size of each output sample.
120 | dropout_p (float) : Probability of dropping a weight.
121 | """
122 |
123 | def __init__(self, in_features : int, hidden_layer_sizes : list, out_features : int,
124 | dropout_p : float) -> None:
125 | super().__init__()
126 |
127 | activation_function = torch.nn.SELU
128 |
129 | # create list of all layer feature sizes
130 | fs = [in_features, *hidden_layer_sizes, out_features]
131 |
132 | # create list of linear_blocks
133 | layers = [self._linear_block(in_f, out_f,
134 | activation_function,
135 | dropout_p)
136 | for in_f, out_f in zip(fs, fs[1:])]
137 |
138 | # concatenate modules in all sequentials in layers list
139 | layers = [module for sq in layers for module in sq.children()]
140 |
141 | # add modules to sequential container
142 | self.seq = torch.nn.Sequential(*layers)
143 |
144 | def _linear_block(self, in_f : int, out_f : int, activation : torch.nn.Module,
145 | dropout_p : float) -> torch.nn.Sequential:
146 | """
147 |         Returns a linear block: a linear layer followed by the activation function
148 |         (SELU) and an (optional) alpha dropout layer.
149 |
150 | Args:
151 | ----
152 | in_f (int) : Size of each input sample.
153 | out_f (int) : Size of each output sample.
154 | activation (torch.nn.Module) : Activation function.
155 | dropout_p (float) : Probability of dropping a weight.
156 |
157 | Returns:
158 | -------
159 | torch.nn.Sequential : The linear block.
160 | """
161 | # bias must be used in most MLPs in our models to learn from empty graphs
162 | linear = torch.nn.Linear(in_f, out_f, bias=True)
163 | torch.nn.init.xavier_uniform_(linear.weight)
164 | return torch.nn.Sequential(linear, activation(), torch.nn.AlphaDropout(dropout_p))
165 |
166 |     def forward(self, layers_input : torch.Tensor) -> torch.Tensor:
167 | """
168 | Defines forward pass.
169 | """
170 | return self.seq(layers_input)
171 |
172 |
173 | class GlobalReadout(torch.nn.Module):
174 | """
175 | Global readout function class. Used to predict the action probability distributions
176 | (APDs) for molecular graphs.
177 |
178 |     The first tier of two `MLP`s takes as input, for each graph in the batch, the
179 | final transformed node feature vectors. These feed-forward networks correspond
180 | to the preliminary "f_add" and "f_conn" distributions.
181 |
182 | The second tier of three `MLP`s takes as input the output of the first tier
183 | of `MLP`s (the "preliminary" APDs) as well as the graph embeddings for all
184 |     graphs in the batch. The outputs are the final APD components, which are then
185 |     flattened and concatenated. No activation function is applied after the final
186 |     layer, so that it can be applied externally (e.g. in the loss function, or before sampling).
187 | """
188 | def __init__(self, f_add_elems : int, f_conn_elems : int, f_term_elems : int,
189 | mlp1_depth : int, mlp1_dropout_p : float, mlp1_hidden_dim : int,
190 | mlp2_depth : int, mlp2_dropout_p : float, mlp2_hidden_dim : int,
191 | graph_emb_size : int, max_n_nodes : int, node_emb_size : int,
192 | device : str) -> None:
193 | super().__init__()
194 |
195 | self.device = device
196 |
197 | # preliminary f_add
198 | self.fAddNet1 = MLP(
199 | in_features=node_emb_size,
200 | hidden_layer_sizes=[mlp1_hidden_dim] * mlp1_depth,
201 | out_features=f_add_elems,
202 | dropout_p=mlp1_dropout_p
203 | )
204 |
205 | # preliminary f_conn
206 | self.fConnNet1 = MLP(
207 | in_features=node_emb_size,
208 | hidden_layer_sizes=[mlp1_hidden_dim] * mlp1_depth,
209 | out_features=f_conn_elems,
210 | dropout_p=mlp1_dropout_p
211 | )
212 |
213 | # final f_add
214 | self.fAddNet2 = MLP(
215 | in_features=(max_n_nodes * f_add_elems + graph_emb_size),
216 | hidden_layer_sizes=[mlp2_hidden_dim] * mlp2_depth,
217 | out_features=f_add_elems * max_n_nodes,
218 | dropout_p=mlp2_dropout_p
219 | )
220 |
221 | # final f_conn
222 | self.fConnNet2 = MLP(
223 | in_features=(max_n_nodes * f_conn_elems + graph_emb_size),
224 | hidden_layer_sizes=[mlp2_hidden_dim] * mlp2_depth,
225 | out_features=f_conn_elems * max_n_nodes,
226 | dropout_p=mlp2_dropout_p
227 | )
228 |
229 | # final f_term (only takes as input graph embeddings)
230 | self.fTermNet2 = MLP(
231 | in_features=graph_emb_size,
232 | hidden_layer_sizes=[mlp2_hidden_dim] * mlp2_depth,
233 | out_features=f_term_elems,
234 | dropout_p=mlp2_dropout_p
235 | )
236 |
237 | def forward(self, node_level_output : torch.Tensor,
238 | graph_embedding_batch : torch.Tensor) -> torch.Tensor:
239 | """
240 | Defines forward pass.
241 | """
242 | if self.device == "cuda":
243 | self.fAddNet1 = self.fAddNet1.to("cuda", non_blocking=True)
244 | self.fConnNet1 = self.fConnNet1.to("cuda", non_blocking=True)
245 | self.fAddNet2 = self.fAddNet2.to("cuda", non_blocking=True)
246 | self.fConnNet2 = self.fConnNet2.to("cuda", non_blocking=True)
247 | self.fTermNet2 = self.fTermNet2.to("cuda", non_blocking=True)
248 |
249 | # get preliminary f_add and f_conn
250 | f_add_1 = self.fAddNet1(node_level_output)
251 | f_conn_1 = self.fConnNet1(node_level_output)
252 |
253 | if self.device == "cuda":
254 | f_add_1 = f_add_1.to("cuda", non_blocking=True)
255 | f_conn_1 = f_conn_1.to("cuda", non_blocking=True)
256 |
257 |         # reshape preliminary APDs into flattened vectors (i.e. one vector per
258 |         # graph in batch)
259 | f_add_1_size = f_add_1.size()
260 | f_conn_1_size = f_conn_1.size()
261 | f_add_1 = f_add_1.view((f_add_1_size[0], f_add_1_size[1] * f_add_1_size[2]))
262 | f_conn_1 = f_conn_1.view((f_conn_1_size[0], f_conn_1_size[1] * f_conn_1_size[2]))
263 |
264 | # get final f_add, f_conn, and f_term
265 | f_add_2 = self.fAddNet2(
266 | torch.cat((f_add_1, graph_embedding_batch), dim=1).unsqueeze(dim=1)
267 | )
268 | f_conn_2 = self.fConnNet2(
269 | torch.cat((f_conn_1, graph_embedding_batch), dim=1).unsqueeze(dim=1)
270 | )
271 | f_term_2 = self.fTermNet2(graph_embedding_batch)
272 |
273 | if self.device == "cuda":
274 | f_add_2 = f_add_2.to("cuda", non_blocking=True)
275 | f_conn_2 = f_conn_2.to("cuda", non_blocking=True)
276 | f_term_2 = f_term_2.to("cuda", non_blocking=True)
277 |
278 | # flatten and concatenate
279 | cat = torch.cat((f_add_2.squeeze(dim=1), f_conn_2.squeeze(dim=1), f_term_2), dim=1)
280 |
281 | return cat # note: no activation function before returning
282 |
--------------------------------------------------------------------------------
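
The `GlobalReadout` above produces a single flattened APD vector per graph. A minimal shape-check sketch (editor's illustration: the constructor arguments are the ones defined in `gnn/modules.py`, but every size below is a toy value, and the import assumes `graphinvent/` is on the Python path):

    import torch
    from gnn.modules import GlobalReadout   # assumes graphinvent/ is on sys.path

    readout = GlobalReadout(f_add_elems=8, f_conn_elems=4, f_term_elems=1,
                            mlp1_depth=2, mlp1_dropout_p=0.0, mlp1_hidden_dim=64,
                            mlp2_depth=2, mlp2_dropout_p=0.0, mlp2_hidden_dim=64,
                            graph_emb_size=32, max_n_nodes=10, node_emb_size=16,
                            device="cpu")

    node_level_output     = torch.randn(3, 10, 16)   # (batch, max_n_nodes, node_emb_size)
    graph_embedding_batch = torch.randn(3, 32)       # (batch, graph_emb_size)

    apd = readout(node_level_output, graph_embedding_batch)
    # flattened APD per graph: 10*8 (f_add) + 10*4 (f_conn) + 1 (f_term) = 121 elements
    assert apd.shape == (3, 121)
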
/GraphINVENT_Protac/graphinvent/gnn/mpnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines specific MPNN implementations.
3 | """
4 | # load general packages and functions
5 | from collections import namedtuple
6 | import math
7 | import torch
8 |
9 | # load GraphINVENT-specific functions
10 | import gnn.aggregation_mpnn
11 | import gnn.edge_mpnn
12 | import gnn.summation_mpnn
13 | import gnn.modules
14 |
15 |
16 | class MNN(gnn.summation_mpnn.SummationMPNN):
17 | """
18 | The "message neural network" model.
19 | """
20 | def __init__(self, constants : namedtuple) -> None:
21 | super().__init__(constants)
22 |
23 | self.constants = constants
24 | message_weights = torch.Tensor(self.constants.message_size,
25 | self.constants.hidden_node_features,
26 | self.constants.n_edge_features)
27 | if self.constants.device == "cuda":
28 | message_weights = message_weights.to("cuda", non_blocking=True)
29 |
30 | self.message_weights = torch.nn.Parameter(message_weights)
31 |
32 | self.gru = torch.nn.GRUCell(
33 | input_size=self.constants.message_size,
34 | hidden_size=self.constants.hidden_node_features,
35 | bias=True
36 | )
37 |
38 | self.APDReadout = gnn.modules.GlobalReadout(
39 | node_emb_size=self.constants.hidden_node_features,
40 | graph_emb_size=self.constants.hidden_node_features,
41 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim,
42 | mlp1_depth=self.constants.mlp1_depth,
43 | mlp1_dropout_p=self.constants.mlp1_dropout_p,
44 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim,
45 | mlp2_depth=self.constants.mlp2_depth,
46 | mlp2_dropout_p=self.constants.mlp2_dropout_p,
47 | f_add_elems=self.constants.len_f_add_per_node,
48 | f_conn_elems=self.constants.len_f_conn_per_node,
49 | f_term_elems=1,
50 | max_n_nodes=self.constants.max_n_nodes,
51 | device=self.constants.device,
52 | )
53 |
54 | self.reset_parameters()
55 |
56 | def reset_parameters(self) -> None:
57 | stdev = 1.0 / math.sqrt(self.message_weights.size(1))
58 | self.message_weights.data.uniform_(-stdev, stdev)
59 |
60 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
61 | edges : torch.Tensor) -> torch.Tensor:
62 | edges_view = edges.view(-1, 1, 1, self.constants.n_edge_features)
63 | weights_for_each_edge = (edges_view * self.message_weights.unsqueeze(0)).sum(3)
64 | return torch.matmul(weights_for_each_edge,
65 | node_neighbours.unsqueeze(-1)).squeeze()
66 |
67 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor:
68 | return self.gru(messages, nodes)
69 |
70 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
71 | node_mask : torch.Tensor) -> torch.Tensor:
72 | graph_embeddings = torch.sum(hidden_nodes, dim=1)
73 | output = self.APDReadout(hidden_nodes, graph_embeddings)
74 | return output
75 |
76 |
77 | class S2V(gnn.summation_mpnn.SummationMPNN):
78 | """
79 | The "set2vec" model.
80 | """
81 | def __init__(self, constants : namedtuple) -> None:
82 | super().__init__(constants)
83 |
84 | self.constants = constants
85 |
86 | self.enn = gnn.modules.MLP(
87 | in_features=self.constants.n_edge_features,
88 | hidden_layer_sizes=[self.constants.enn_hidden_dim] * self.constants.enn_depth,
89 | out_features=self.constants.hidden_node_features * self.constants.message_size,
90 | dropout_p=self.constants.enn_dropout_p
91 | )
92 |
93 | self.gru = torch.nn.GRUCell(
94 | input_size=self.constants.message_size,
95 | hidden_size=self.constants.hidden_node_features,
96 | bias=True
97 | )
98 |
99 | self.s2v = gnn.modules.Set2Vec(
100 | node_features=self.constants.n_node_features,
101 | hidden_node_features=self.constants.hidden_node_features,
102 | lstm_computations=self.constants.s2v_lstm_computations,
103 | memory_size=self.constants.s2v_memory_size
104 | )
105 |
106 | self.APDReadout = gnn.modules.GlobalReadout(
107 | node_emb_size=self.constants.hidden_node_features,
108 | graph_emb_size=self.constants.s2v_memory_size * 2,
109 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim,
110 | mlp1_depth=self.constants.mlp1_depth,
111 | mlp1_dropout_p=self.constants.mlp1_dropout_p,
112 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim,
113 | mlp2_depth=self.constants.mlp2_depth,
114 | mlp2_dropout_p=self.constants.mlp2_dropout_p,
115 | f_add_elems=self.constants.len_f_add_per_node,
116 | f_conn_elems=self.constants.len_f_conn_per_node,
117 | f_term_elems=1,
118 | max_n_nodes=self.constants.max_n_nodes,
119 | device=self.constants.device,
120 | )
121 |
122 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
123 | edges : torch.Tensor) -> torch.Tensor:
124 | enn_output = self.enn(edges)
125 | matrices = enn_output.view(-1,
126 | self.constants.message_size,
127 | self.constants.hidden_node_features)
128 | msg_terms = torch.matmul(matrices,
129 | node_neighbours.unsqueeze(-1)).squeeze(-1)
130 | return msg_terms
131 |
132 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor:
133 | return self.gru(messages, nodes)
134 |
135 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
136 | node_mask : torch.Tensor) -> torch.Tensor:
137 | graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask)
138 | output = self.APDReadout(hidden_nodes, graph_embeddings)
139 | return output
140 |
141 |
142 | class AttentionS2V(gnn.aggregation_mpnn.AggregationMPNN):
143 | """
144 | The "set2vec with attention" model.
145 | """
146 | def __init__(self, constants : namedtuple) -> None:
147 |
148 | super().__init__(constants)
149 |
150 | self.constants = constants
151 |
152 | self.enn = gnn.modules.MLP(
153 | in_features=self.constants.n_edge_features,
154 | hidden_layer_sizes=[self.constants.enn_hidden_dim] * self.constants.enn_depth,
155 | out_features=self.constants.hidden_node_features * self.constants.message_size,
156 | dropout_p=self.constants.enn_dropout_p
157 | )
158 |
159 | self.att_enn = gnn.modules.MLP(
160 | in_features=self.constants.hidden_node_features + self.constants.n_edge_features,
161 | hidden_layer_sizes=[self.constants.att_hidden_dim] * self.constants.att_depth,
162 | out_features=self.constants.message_size,
163 | dropout_p=self.constants.att_dropout_p
164 | )
165 |
166 | self.gru = torch.nn.GRUCell(
167 | input_size=self.constants.message_size,
168 | hidden_size=self.constants.hidden_node_features,
169 | bias=True
170 | )
171 |
172 | self.s2v = gnn.modules.Set2Vec(
173 | node_features=self.constants.n_node_features,
174 | hidden_node_features=self.constants.hidden_node_features,
175 | lstm_computations=self.constants.s2v_lstm_computations,
176 | memory_size=self.constants.s2v_memory_size,
177 | )
178 |
179 | self.APDReadout = gnn.modules.GlobalReadout(
180 | node_emb_size=self.constants.hidden_node_features,
181 | graph_emb_size=self.constants.s2v_memory_size * 2,
182 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim,
183 | mlp1_depth=self.constants.mlp1_depth,
184 | mlp1_dropout_p=self.constants.mlp1_dropout_p,
185 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim,
186 | mlp2_depth=self.constants.mlp2_depth,
187 | mlp2_dropout_p=self.constants.mlp2_dropout_p,
188 | f_add_elems=self.constants.len_f_add_per_node,
189 | f_conn_elems=self.constants.len_f_conn_per_node,
190 | f_term_elems=1,
191 | max_n_nodes=self.constants.max_n_nodes,
192 | device=self.constants.device,
193 | )
194 |
195 | def aggregate_message(self, nodes : torch.Tensor,
196 | node_neighbours : torch.Tensor,
197 | edges : torch.Tensor,
198 | mask : torch.Tensor) -> torch.Tensor:
199 | Softmax = torch.nn.Softmax(dim=1)
200 | max_node_degree = node_neighbours.shape[1]
201 |
202 | enn_output = self.enn(edges)
203 | matrices = enn_output.view(-1,
204 | max_node_degree,
205 | self.constants.message_size,
206 | self.constants.hidden_node_features)
207 | message_terms = torch.matmul(matrices, node_neighbours.unsqueeze(-1)).squeeze()
208 |
209 | att_enn_output = self.att_enn(torch.cat((edges, node_neighbours), dim=2))
210 | energies = att_enn_output.view(-1, max_node_degree, self.constants.message_size)
211 | energy_mask = (1 - mask).float() * self.constants.big_negative
212 | weights = Softmax(energies + energy_mask.unsqueeze(-1))
213 |
214 | return (weights * message_terms).sum(1)
215 |
216 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor:
217 | if self.constants.device == "cuda":
218 | messages = messages + torch.zeros(self.constants.message_size, device="cuda")
219 | return self.gru(messages, nodes)
220 |
221 | def readout(self, hidden_nodes : torch.Tensor,
222 | input_nodes : torch.Tensor,
223 | node_mask : torch.Tensor) -> torch.Tensor:
224 | graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask)
225 | output = self.APDReadout(hidden_nodes, graph_embeddings)
226 | return output
227 |
228 |
229 | class GGNN(gnn.summation_mpnn.SummationMPNN):
230 | """
231 | The "gated-graph neural network" model.
232 | """
233 | def __init__(self, constants : namedtuple) -> None:
234 | super().__init__(constants)
235 |
236 | self.constants = constants
237 |
238 | self.msg_nns = torch.nn.ModuleList()
239 | for _ in range(self.constants.n_edge_features):
240 | self.msg_nns.append(
241 | gnn.modules.MLP(
242 | in_features=self.constants.hidden_node_features,
243 | hidden_layer_sizes=[self.constants.enn_hidden_dim] * self.constants.enn_depth,
244 | out_features=self.constants.message_size,
245 | dropout_p=self.constants.enn_dropout_p,
246 | )
247 | )
248 |
249 | self.gru = torch.nn.GRUCell(
250 | input_size=self.constants.message_size,
251 | hidden_size=self.constants.hidden_node_features,
252 | bias=True
253 | )
254 |
255 | self.gather = gnn.modules.GraphGather(
256 | node_features=self.constants.n_node_features,
257 | hidden_node_features=self.constants.hidden_node_features,
258 | out_features=self.constants.gather_width,
259 | att_depth=self.constants.gather_att_depth,
260 | att_hidden_dim=self.constants.gather_att_hidden_dim,
261 | att_dropout_p=self.constants.gather_att_dropout_p,
262 | emb_depth=self.constants.gather_emb_depth,
263 | emb_hidden_dim=self.constants.gather_emb_hidden_dim,
264 | emb_dropout_p=self.constants.gather_emb_dropout_p,
265 | big_positive=self.constants.big_positive
266 | )
267 |
268 | self.APDReadout = gnn.modules.GlobalReadout(
269 | node_emb_size=self.constants.hidden_node_features,
270 | graph_emb_size=self.constants.gather_width,
271 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim,
272 | mlp1_depth=self.constants.mlp1_depth,
273 | mlp1_dropout_p=self.constants.mlp1_dropout_p,
274 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim,
275 | mlp2_depth=self.constants.mlp2_depth,
276 | mlp2_dropout_p=self.constants.mlp2_dropout_p,
277 | f_add_elems=self.constants.len_f_add_per_node,
278 | f_conn_elems=self.constants.len_f_conn_per_node,
279 | f_term_elems=1,
280 | max_n_nodes=self.constants.max_n_nodes,
281 | device=self.constants.device,
282 | )
283 |
284 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
285 | edges : torch.Tensor) -> torch.Tensor:
286 | edges_v = edges.view(-1, self.constants.n_edge_features, 1)
287 | node_neighbours_v = edges_v * node_neighbours.view(-1,
288 | 1,
289 | self.constants.hidden_node_features)
290 | terms_masked_per_edge = [
291 | edges_v[:, i, :] * self.msg_nns[i](node_neighbours_v[:, i, :])
292 | for i in range(self.constants.n_edge_features)
293 | ]
294 | return sum(terms_masked_per_edge)
295 |
296 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor:
297 | return self.gru(messages, nodes)
298 |
299 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
300 | node_mask : torch.Tensor) -> torch.Tensor:
301 | graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask)
302 | output = self.APDReadout(hidden_nodes, graph_embeddings)
303 | return output
304 |
305 |
306 | class AttentionGGNN(gnn.aggregation_mpnn.AggregationMPNN):
307 | """
308 | The "GGNN with attention" model.
309 | """
310 | def __init__(self, constants : namedtuple) -> None:
311 | super().__init__(constants)
312 |
313 | self.constants = constants
314 | self.msg_nns = torch.nn.ModuleList()
315 | self.att_nns = torch.nn.ModuleList()
316 |
317 | for _ in range(self.constants.n_edge_features):
318 | self.msg_nns.append(
319 | gnn.modules.MLP(
320 | in_features=self.constants.hidden_node_features,
321 | hidden_layer_sizes=[self.constants.msg_hidden_dim] * self.constants.msg_depth,
322 | out_features=self.constants.message_size,
323 | dropout_p=self.constants.msg_dropout_p,
324 | )
325 | )
326 | self.att_nns.append(
327 | gnn.modules.MLP(
328 | in_features=self.constants.hidden_node_features,
329 | hidden_layer_sizes=[self.constants.att_hidden_dim] * self.constants.att_depth,
330 | out_features=self.constants.message_size,
331 | dropout_p=self.constants.att_dropout_p,
332 | )
333 | )
334 |
335 | self.gru = torch.nn.GRUCell(
336 | input_size=self.constants.message_size,
337 | hidden_size=self.constants.hidden_node_features,
338 | bias=True
339 | )
340 |
341 | self.gather = gnn.modules.GraphGather(
342 | node_features=self.constants.n_node_features,
343 | hidden_node_features=self.constants.hidden_node_features,
344 | out_features=self.constants.gather_width,
345 | att_depth=self.constants.gather_att_depth,
346 | att_hidden_dim=self.constants.gather_att_hidden_dim,
347 | att_dropout_p=self.constants.gather_att_dropout_p,
348 | emb_depth=self.constants.gather_emb_depth,
349 | emb_hidden_dim=self.constants.gather_emb_hidden_dim,
350 | emb_dropout_p=self.constants.gather_emb_dropout_p,
351 | big_positive=self.constants.big_positive
352 | )
353 |
354 | self.APDReadout = gnn.modules.GlobalReadout(
355 | node_emb_size=self.constants.hidden_node_features,
356 | graph_emb_size=self.constants.gather_width,
357 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim,
358 | mlp1_depth=self.constants.mlp1_depth,
359 | mlp1_dropout_p=self.constants.mlp1_dropout_p,
360 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim,
361 | mlp2_depth=self.constants.mlp2_depth,
362 | mlp2_dropout_p=self.constants.mlp2_dropout_p,
363 | f_add_elems=self.constants.len_f_add_per_node,
364 | f_conn_elems=self.constants.len_f_conn_per_node,
365 | f_term_elems=1,
366 | max_n_nodes=self.constants.max_n_nodes,
367 | device=self.constants.device,
368 | )
369 |
370 | def aggregate_message(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
371 | edges : torch.Tensor, mask : torch.Tensor) -> torch.Tensor:
372 | Softmax = torch.nn.Softmax(dim=1)
373 |
374 | energy_mask = (mask == 0).float() * self.constants.big_positive
375 |
376 | embeddings_masked_per_edge = [
377 | edges[:, :, i].unsqueeze(-1) * self.msg_nns[i](node_neighbours)
378 | for i in range(self.constants.n_edge_features)
379 | ]
380 | energies_masked_per_edge = [
381 | edges[:, :, i].unsqueeze(-1) * self.att_nns[i](node_neighbours)
382 | for i in range(self.constants.n_edge_features)
383 | ]
384 |
385 | embedding = sum(embeddings_masked_per_edge)
386 | energies = sum(energies_masked_per_edge) - energy_mask.unsqueeze(-1)
387 | attention = Softmax(energies)
388 |
389 | return torch.sum(attention * embedding, dim=1)
390 |
391 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor:
392 | return self.gru(messages, nodes)
393 |
394 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
395 | node_mask : torch.Tensor) -> torch.Tensor:
396 | graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask)
397 | output = self.APDReadout(hidden_nodes, graph_embeddings)
398 | return output
399 |
400 |
401 | class EMN(gnn.edge_mpnn.EdgeMPNN):
402 | """
403 | The "edge memory network" model.
404 | """
405 | def __init__(self, constants : namedtuple) -> None:
406 | super().__init__(constants)
407 |
408 | self.constants = constants
409 |
410 | self.embedding_nn = gnn.modules.MLP(
411 | in_features=self.constants.n_node_features * 2 + self.constants.n_edge_features,
412 |             hidden_layer_sizes=[self.constants.edge_emb_hidden_dim] * self.constants.edge_emb_depth,
413 | out_features=self.constants.edge_emb_size,
414 | dropout_p=self.constants.edge_emb_dropout_p,
415 | )
416 |
417 | self.emb_msg_nn = gnn.modules.MLP(
418 | in_features=self.constants.edge_emb_size,
419 | hidden_layer_sizes=[self.constants.msg_hidden_dim] * self.constants.msg_depth,
420 | out_features=self.constants.edge_emb_size,
421 | dropout_p=self.constants.msg_dropout_p,
422 | )
423 |
424 | self.att_msg_nn = gnn.modules.MLP(
425 | in_features=self.constants.edge_emb_size,
426 | hidden_layer_sizes=[self.constants.att_hidden_dim] * self.constants.att_depth,
427 | out_features=self.constants.edge_emb_size,
428 | dropout_p=self.constants.att_dropout_p,
429 | )
430 |
431 | self.gru = torch.nn.GRUCell(
432 | input_size=self.constants.edge_emb_size,
433 | hidden_size=self.constants.edge_emb_size,
434 | bias=True
435 | )
436 |
437 | self.gather = gnn.modules.GraphGather(
438 | node_features=self.constants.edge_emb_size,
439 | hidden_node_features=self.constants.edge_emb_size,
440 | out_features=self.constants.gather_width,
441 | att_depth=self.constants.gather_att_depth,
442 | att_hidden_dim=self.constants.gather_att_hidden_dim,
443 | att_dropout_p=self.constants.gather_att_dropout_p,
444 | emb_depth=self.constants.gather_emb_depth,
445 | emb_hidden_dim=self.constants.gather_emb_hidden_dim,
446 | emb_dropout_p=self.constants.gather_emb_dropout_p,
447 | big_positive=self.constants.big_positive
448 | )
449 |
450 | self.APDReadout = gnn.modules.GlobalReadout(
451 | node_emb_size=self.constants.edge_emb_size,
452 | graph_emb_size=self.constants.gather_width,
453 | mlp1_hidden_dim=self.constants.mlp1_hidden_dim,
454 | mlp1_depth=self.constants.mlp1_depth,
455 | mlp1_dropout_p=self.constants.mlp1_dropout_p,
456 | mlp2_hidden_dim=self.constants.mlp2_hidden_dim,
457 | mlp2_depth=self.constants.mlp2_depth,
458 | mlp2_dropout_p=self.constants.mlp2_dropout_p,
459 | f_add_elems=self.constants.len_f_add_per_node,
460 | f_conn_elems=self.constants.len_f_conn_per_node,
461 | f_term_elems=1,
462 | max_n_nodes=self.constants.max_n_nodes,
463 | device=self.constants.device,
464 | )
465 |
466 | def preprocess_edges(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
467 | edges : torch.Tensor) -> torch.Tensor:
468 | cat = torch.cat((nodes, node_neighbours, edges), dim=1)
469 | return torch.tanh(self.embedding_nn(cat))
470 |
471 | def propagate_edges(self, edges : torch.Tensor, ingoing_edge_memories : torch.Tensor,
472 | ingoing_edges_mask : torch.Tensor) -> torch.Tensor:
473 | Softmax = torch.nn.Softmax(dim=1)
474 |
475 | energy_mask = (
476 | (1 - ingoing_edges_mask).float() * self.constants.big_negative
477 | ).unsqueeze(-1)
478 | cat = torch.cat((edges.unsqueeze(1), ingoing_edge_memories), dim=1)
479 | embeddings = self.emb_msg_nn(cat)
480 | edge_energy = self.att_msg_nn(edges)
481 | ing_memory_energies = self.att_msg_nn(ingoing_edge_memories) + energy_mask
482 | energies = torch.cat((edge_energy.unsqueeze(1), ing_memory_energies), dim=1)
483 | attention = Softmax(energies)
484 |
485 |         # attention-weighted aggregation over the given edge embedding and its ingoing edge memories
486 | message = (attention * embeddings).sum(dim=1)
487 |
488 | return self.gru(message) # return hidden state
489 |
490 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
491 | node_mask : torch.Tensor) -> torch.Tensor:
492 | graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask)
493 | output = self.APDReadout(hidden_nodes, graph_embeddings)
494 | return output
495 |
--------------------------------------------------------------------------------
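
The per-bond-type gating used in `GGNN.message_terms` above is easiest to see in isolation. A hedged toy sketch (not repository code; plain `torch.nn.Linear` layers stand in for the repository's `MLP`, and all sizes are invented):

    import torch

    n_edge_features, hidden_node_features, message_size = 3, 5, 4
    msg_nns = torch.nn.ModuleList(
        [torch.nn.Linear(hidden_node_features, message_size)
         for _ in range(n_edge_features)]
    )

    # three directed edges, each with a one-hot bond-type vector, plus neighbour states
    edges           = torch.eye(n_edge_features)[torch.tensor([0, 2, 1])]
    node_neighbours = torch.randn(3, hidden_node_features)

    edges_v           = edges.view(-1, n_edge_features, 1)
    node_neighbours_v = edges_v * node_neighbours.view(-1, 1, hidden_node_features)
    terms = [edges_v[:, i, :] * msg_nns[i](node_neighbours_v[:, i, :])
             for i in range(n_edge_features)]
    messages = sum(terms)   # (3, message_size); only the matching bond-type MLP contributes per edge
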
/GraphINVENT_Protac/graphinvent/gnn/summation_mpnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the `SummationMPNN` class.
3 | """
4 | # load general packages and functions
5 | from collections import namedtuple
6 | import torch
7 |
8 |
9 | class SummationMPNN(torch.nn.Module):
10 | """
11 | Abstract `SummationMPNN` class. Specific models using this class are
12 | defined in `mpnn.py`; these are MNN, S2V, and GGNN.
13 | """
14 | def __init__(self, constants : namedtuple):
15 |
16 | super().__init__()
17 |
18 | self.hidden_node_features = constants.hidden_node_features
19 | self.edge_features = constants.n_edge_features
20 | self.message_size = constants.message_size
21 | self.message_passes = constants.message_passes
22 | self.constants = constants
23 |
24 | def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
25 | edges : torch.Tensor) -> None:
26 | """
27 | Message passing function, to be implemented in all `SummationMPNN` subclasses.
28 |
29 | Args:
30 | ----
31 | nodes (torch.Tensor) : Batch of node feature vectors.
32 | node_neighbours (torch.Tensor) : Batch of node feature vectors for neighbors.
33 | edges (torch.Tensor) : Batch of edge feature vectors.
34 |
35 | Shapes:
36 | ------
37 | nodes : (total N nodes in batch, N node features)
38 | node_neighbours : (total N nodes in batch, max node degree, N node features)
39 | edges : (total N nodes in batch, max node degree, N edge features)
40 | """
41 | raise NotImplementedError
42 |
43 | def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> None:
44 | """
45 | Message update function, to be implemented in all `SummationMPNN` subclasses.
46 |
47 | Args:
48 | ----
49 | nodes (torch.Tensor) : Batch of node feature vectors.
50 | messages (torch.Tensor) : Batch of incoming messages.
51 |
52 | Shapes:
53 | ------
54 | nodes : (total N nodes in batch, N node features)
55 | messages : (total N nodes in batch, N node features)
56 | """
57 | raise NotImplementedError
58 |
59 | def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
60 | node_mask : torch.Tensor) -> None:
61 | """
62 | Local readout function, to be implemented in all `SummationMPNN` subclasses.
63 |
64 | Args:
65 | ----
66 | hidden_nodes (torch.Tensor) : Batch of node feature vectors.
67 | input_nodes (torch.Tensor) : Batch of node feature vectors.
68 | node_mask (torch.Tensor) : Mask for non-existing neighbors, where elements
69 | are 1 if corresponding element exists and 0
70 | otherwise.
71 |
72 | Shapes:
73 | ------
74 | hidden_nodes : (total N nodes in batch, N node features)
75 | input_nodes : (total N nodes in batch, N node features)
76 | node_mask : (total N nodes in batch, N features)
77 | """
78 | raise NotImplementedError
79 |
80 |     def forward(self, nodes : torch.Tensor, edges : torch.Tensor) -> torch.Tensor:
81 | """
82 | Defines forward pass.
83 |
84 | Args:
85 | ----
86 | nodes (torch.Tensor) : Batch of node feature matrices.
87 | edges (torch.Tensor) : Batch of edge feature tensors.
88 |
89 | Shapes:
90 | ------
91 | nodes : (batch size, N nodes, N node features)
92 | edges : (batch size, N nodes, N nodes, N edge features)
93 |
94 | Returns:
95 | -------
96 | output (torch.Tensor) : This would normally be the learned graph representation,
97 | but in all MPNN readout functions in this work,
98 | the last layer is used to predict the action
99 | probability distribution for a batch of graphs
100 | from the learned graph representation.
101 | """
102 | adjacency = torch.sum(edges, dim=3)
103 |
104 | # **note: "idc" == "indices", "nghb{s}" == "neighbour(s)"
105 | (edge_batch_batch_idc,
106 | edge_batch_node_idc,
107 | edge_batch_nghb_idc) = adjacency.nonzero(as_tuple=True)
108 |
109 | (node_batch_batch_idc, node_batch_node_idc) = adjacency.sum(-1).nonzero(as_tuple=True)
110 |
111 | same_batch = node_batch_batch_idc.view(-1, 1) == edge_batch_batch_idc
112 | same_node = node_batch_node_idc.view(-1, 1) == edge_batch_node_idc
113 |
114 | # element ij of `message_summation_matrix` is 1 if `edge_batch_edges[j]`
115 | # is connected with `node_batch_nodes[i]`, else 0
116 | message_summation_matrix = (same_batch * same_node).float()
117 |
118 | edge_batch_edges = edges[edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc, :]
119 |
120 | # pad up the hidden nodes
121 | hidden_nodes = torch.zeros(nodes.shape[0],
122 | nodes.shape[1],
123 | self.hidden_node_features,
124 | device=self.constants.device)
125 | hidden_nodes[:nodes.shape[0], :nodes.shape[1], :nodes.shape[2]] = nodes.clone()
126 | node_batch_nodes = hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :]
127 |
128 | for _ in range(self.message_passes):
129 | edge_batch_nodes = hidden_nodes[edge_batch_batch_idc, edge_batch_node_idc, :]
130 |
131 | edge_batch_nghbs = hidden_nodes[edge_batch_batch_idc, edge_batch_nghb_idc, :]
132 |
133 | message_terms = self.message_terms(edge_batch_nodes,
134 | edge_batch_nghbs,
135 | edge_batch_edges)
136 |
137 | if len(message_terms.size()) == 1: # if a single graph in batch
138 | message_terms = message_terms.unsqueeze(0)
139 |
140 | # the summation in eq. 1 of the NMPQC paper happens here
141 | messages = torch.matmul(message_summation_matrix, message_terms)
142 |
143 | node_batch_nodes = self.update(node_batch_nodes, messages)
144 | hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] = node_batch_nodes.clone()
145 |
146 | node_mask = adjacency.sum(-1) != 0
147 | output = self.readout(hidden_nodes, nodes, node_mask)
148 |
149 | return output
150 |
--------------------------------------------------------------------------------
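
The index bookkeeping in `SummationMPNN.forward` above is easiest to follow on a tiny case. An editor's sketch (one graph, three nodes, two bonds, a single edge-feature channel; values are made up):

    import torch

    edges = torch.zeros(1, 3, 3, 1)   # (batch, nodes, nodes, edge features)
    edges[0, 0, 1, 0] = 1             # bond 0-1 (stored in both directions)
    edges[0, 1, 0, 0] = 1
    edges[0, 1, 2, 0] = 1             # bond 1-2
    edges[0, 2, 1, 0] = 1

    adjacency = torch.sum(edges, dim=3)
    (e_batch, e_node, e_nghb) = adjacency.nonzero(as_tuple=True)           # 4 directed edges
    (n_batch, n_node)         = adjacency.sum(-1).nonzero(as_tuple=True)   # 3 connected nodes

    same_batch = n_batch.view(-1, 1) == e_batch
    same_node  = n_node.view(-1, 1) == e_node
    message_summation_matrix = (same_batch * same_node).float()

    # row i marks the directed edges whose message terms are summed into node i;
    # node 1 has two incident edges, so its row contains two ones
    print(message_summation_matrix)
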
/GraphINVENT_Protac/graphinvent/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main function for running GraphINVENT jobs.
3 |
4 | Examples:
5 | --------
6 | * If you define an "input.csv" with desired job parameters in job_dir/:
7 |       (graphinvent) ~/GraphINVENT$ python main.py --job-dir path/to/job_dir/
8 | * If you instead want to run your job using the submission scripts:
9 | (graphinvent) ~/GraphINVENT$ python submit-fine-tuning.py
10 | """
11 | # load general packages and functions
12 | import datetime
13 |
14 | # load GraphINVENT-specific functions
15 | import util
16 | from parameters.constants import constants
17 | from Workflow import Workflow
18 |
19 | # suppress minor warnings
20 | util.suppress_warnings()
21 |
22 |
23 | def main():
24 | """
25 | Defines the type of job (preprocessing, training, generation, testing, or
26 | fine-tuning), writes the job parameters (for future reference), and runs
27 | the job.
28 | """
29 | _ = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # fix date/time
30 |
31 | workflow = Workflow(constants=constants)
32 |
33 | job_type = constants.job_type
34 | print(f"* Run mode: '{job_type}'", flush=True)
35 |
36 | if job_type == "preprocess":
37 | # write preprocessing parameters
38 | util.write_preprocessing_parameters(params=constants)
39 | print("done writing preprocessing params to csv", flush=True)
40 |
41 | # preprocess all datasets
42 | workflow.preprocess_phase()
43 |
44 | elif job_type == "train":
45 | # write training parameters
46 | util.write_job_parameters(params=constants)
47 |
48 | # train model and generate graphs
49 | workflow.training_phase()
50 |
51 | elif job_type == "generate":
52 | # write generation parameters
53 | util.write_job_parameters(params=constants)
54 |
55 | # generate molecules only
56 | workflow.generation_phase()
57 |
58 | elif job_type == "test":
59 | # write testing parameters
60 | util.write_job_parameters(params=constants)
61 |
62 | # evaluate best model using the test set data
63 | workflow.testing_phase()
64 |
65 | elif job_type == "fine-tune":
66 | # write training parameters
67 | util.write_job_parameters(params=constants)
68 |
69 | # fine-tune the model and generate graphs
70 | workflow.learning_phase()
71 |
72 | else:
73 | raise NotImplementedError("Not a valid `job_type`.")
74 |
75 |
76 | if __name__ == "__main__":
77 | main()
78 |
--------------------------------------------------------------------------------
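
As a concrete illustration of the job setup described in the docstring of `main.py` above (the parameter names come from `parameters/defaults.py`; the directory layout and values are only an example), a job directory might contain a semicolon-delimited `input.csv` such as:

    job_type;train
    model;GGNN
    dataset_dir;data/pre-training/MOSES/
    epochs;100

which is then run with the `--job-dir` flag defined in `parameters/args.py`:

    (graphinvent) ~/GraphINVENT$ python main.py --job-dir path/to/job_dir/
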
/GraphINVENT_Protac/graphinvent/parameters/args.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines `ArgumentParser` for specifying job directory using command-line.
3 | """# load general packages and functions
4 | import argparse
5 |
6 |
7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
8 | add_help=False)
9 | parser.add_argument("--job-dir",
10 | type=str,
11 | default="../output/",
12 | help="Directory in which to write all output.")
13 |
14 |
15 | args = parser.parse_args()
16 |
17 | args_dict = vars(args)
18 | job_dir = args_dict["job_dir"]
19 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/graphinvent/parameters/constants.py:
--------------------------------------------------------------------------------
1 | """
2 | Loads input parameters from `defaults.py`, and defines other global constants
3 | that depend on the input features, creating a `namedtuple` from them;
4 | additionally, if there exists an `input.csv` in the job directory, loads those
5 | arguments and overrides default values in `defaults.py`.
6 | """
7 | # load general packages and functions
8 | from collections import namedtuple
9 | import pickle
10 | import csv
11 | import os
12 | import sys
13 | from typing import Tuple
14 | import numpy as np
15 | from rdkit.Chem.rdchem import BondType
16 |
17 | # load GraphINVENT-specific functions
18 | sys.path.insert(1, "./parameters/") # search "parameters/" directory
19 | import parameters.args as args
20 | import parameters.defaults as defaults
21 |
22 |
23 | def get_feature_dimensions(parameters : dict) -> Tuple[int, int, int, int]:
24 | """
25 | Returns dimensions of all node features.
26 | """
27 | n_atom_types = len(parameters["atom_types"])
28 | n_formal_charge = len(parameters["formal_charge"])
29 | n_numh = (
30 | int(not parameters["use_explicit_H"] and not parameters["ignore_H"])
31 | * len(parameters["imp_H"])
32 | )
33 | n_chirality = int(parameters["use_chirality"]) * len(parameters["chirality"])
34 |
35 | return n_atom_types, n_formal_charge, n_numh, n_chirality
36 |
37 |
38 | def get_tensor_dimensions(n_atom_types : int, n_formal_charge : int, n_num_h : int,
39 | n_chirality : int, n_node_features : int, n_edge_features : int,
40 | parameters : dict) -> Tuple[list, list, list, list, int]:
41 | """
42 | Returns dimensions for all tensors that describe molecular graphs. Tensor dimensions
43 | are `list`s, except for `dim_f_term` which is simply an `int`. Each element
44 |     of the lists indicates the corresponding dimension of a particular subgraph matrix
45 | (i.e. `nodes`, `f_add`, etc).
46 | """
47 | max_nodes = parameters["max_n_nodes"]
48 |
49 | # define the matrix dimensions as `list`s
50 | # first for the graph reps...
51 | dim_nodes = [max_nodes, n_node_features]
52 |
53 | dim_edges = [max_nodes, max_nodes, n_edge_features]
54 |
55 | # ... then for the APDs
56 | if parameters["use_chirality"]:
57 | if parameters["use_explicit_H"] or parameters["ignore_H"]:
58 | dim_f_add = [
59 | parameters["max_n_nodes"],
60 | n_atom_types,
61 | n_formal_charge,
62 | n_chirality,
63 | n_edge_features,
64 | ]
65 | else:
66 | dim_f_add = [
67 | parameters["max_n_nodes"],
68 | n_atom_types,
69 | n_formal_charge,
70 | n_num_h,
71 | n_chirality,
72 | n_edge_features,
73 | ]
74 | else:
75 | if parameters["use_explicit_H"] or parameters["ignore_H"]:
76 | dim_f_add = [
77 | parameters["max_n_nodes"],
78 | n_atom_types,
79 | n_formal_charge,
80 | n_edge_features,
81 | ]
82 | else:
83 | dim_f_add = [
84 | parameters["max_n_nodes"],
85 | n_atom_types,
86 | n_formal_charge,
87 | n_num_h,
88 | n_edge_features,
89 | ]
90 |
91 | dim_f_conn = [parameters["max_n_nodes"], n_edge_features]
92 |
93 | dim_f_term = 1
94 |
95 | return dim_nodes, dim_edges, dim_f_add, dim_f_conn, dim_f_term
96 |
97 |
98 | def load_params(input_csv_path : str) -> dict:
99 | """
100 | Loads job parameters/hyperparameters from CSV (in `input_csv_path`).
101 | """
102 | params_to_override = {}
103 | with open(input_csv_path, "r") as csv_file:
104 |
105 | params_reader = csv.reader(csv_file, delimiter=";")
106 |
107 | for key, value in params_reader:
108 | try:
109 | params_to_override[key] = eval(value)
110 | except NameError: # `value` is a `str`
111 | params_to_override[key] = value
112 | except SyntaxError: # to avoid "unexpected `EOF`"
113 | params_to_override[key] = value
114 |
115 | return params_to_override
116 |
117 |
118 | def override_params(all_params : dict) -> dict:
119 | """
120 | If there exists an `input.csv` in the job directory, loads those arguments
121 |     and overrides their default values from `defaults.py`.
122 | """
123 | input_csv_path = all_params["job_dir"] + "input.csv"
124 |
125 |
126 |     # check if there exists an `input.csv` in the job directory
127 | if os.path.exists(input_csv_path):
128 | # override default values for parameters in `input.csv`
129 | params_to_override_dict = load_params(input_csv_path)
130 | for key, value in params_to_override_dict.items():
131 | all_params[key] = value
132 |
133 | return all_params
134 |
135 |
136 | def collect_global_constants(parameters : dict, job_dir : str) -> namedtuple:
137 | """
138 |     Collects constants defined in `defaults.py` with those defined by the
139 | ArgParser (`args.py`), and returns the bundle as a `namedtuple`.
140 |
141 | Args:
142 | ----
143 |         parameters (dict) : Dictionary of parameters defined in `defaults.py`.
144 | job_dir (str) : Current job directory, defined on the command line.
145 |
146 | Returns:
147 | -------
148 | constants (namedtuple) : Collected constants.
149 | """
150 |     # first override any arguments from `input.csv`:
151 | parameters["job_dir"] = job_dir
152 | print("job directory is now", parameters["job_dir"])
153 | parameters = override_params(all_params=parameters)
154 | print("compute_train_csv is now", parameters["compute_train_csv"])
155 |
156 | # then calculate any global constants below:
157 | if parameters["use_explicit_H"] and parameters["ignore_H"]:
158 | raise ValueError("Cannot use explicit Hs and ignore Hs at "
159 | "the same time. Please fix flags.")
160 |
161 | # define edge feature (rdkit `GetBondType()` result -> `int`) constants
162 | bondtype_to_int = {BondType.SINGLE: 0, BondType.DOUBLE: 1, BondType.TRIPLE: 2}
163 |
164 | if parameters["use_aromatic_bonds"]:
165 | bondtype_to_int[BondType.AROMATIC] = 3
166 |
167 | int_to_bondtype = dict(map(reversed, bondtype_to_int.items()))
168 |
169 | n_edge_features = len(bondtype_to_int)
170 |
171 | # define node feature constants
172 | n_atom_types, n_formal_charge, n_imp_H, n_chirality = get_feature_dimensions(parameters)
173 |
174 | n_node_features = n_atom_types + n_formal_charge + n_imp_H + n_chirality
175 |
176 | # define matrix dimensions
177 | (dim_nodes, dim_edges, dim_f_add,
178 | dim_f_conn, dim_f_term) = get_tensor_dimensions(n_atom_types,
179 | n_formal_charge,
180 | n_imp_H,
181 | n_chirality,
182 | n_node_features,
183 | n_edge_features,
184 | parameters)
185 |
186 | len_f_add = np.prod(dim_f_add[:])
187 | len_f_add_per_node = np.prod(dim_f_add[1:])
188 | len_f_conn = np.prod(dim_f_conn[:])
189 | len_f_conn_per_node = np.prod(dim_f_conn[1:])
190 |
191 | # create a dictionary of global constants, and add `job_dir` to it; this
192 | # will ultimately be converted to a `namedtuple`
193 | constants_dict = {
194 | "big_negative" : -1e6,
195 | "big_positive" : 1e6,
196 | "bondtype_to_int" : bondtype_to_int,
197 | "int_to_bondtype" : int_to_bondtype,
198 | "n_edge_features" : n_edge_features,
199 | "n_atom_types" : n_atom_types,
200 | "n_formal_charge" : n_formal_charge,
201 | "n_imp_H" : n_imp_H,
202 | "n_chirality" : n_chirality,
203 | "n_node_features" : n_node_features,
204 | "dim_nodes" : dim_nodes,
205 | "dim_edges" : dim_edges,
206 | "dim_f_add" : dim_f_add,
207 | "dim_f_conn" : dim_f_conn,
208 | "dim_f_term" : dim_f_term,
209 | "dim_apd" : [np.prod(dim_f_add) + np.prod(dim_f_conn) + 1],
210 | "len_f_add" : len_f_add,
211 | "len_f_add_per_node" : len_f_add_per_node,
212 | "len_f_conn" : len_f_conn,
213 | "len_f_conn_per_node": len_f_conn_per_node,
214 | }
215 |
216 |     # merge in the full parameter dictionary (defaults plus any `input.csv` overrides)
217 | constants_dict.update(parameters)
218 |
219 | # define path to dataset splits
220 | constants_dict["test_set"] = parameters["dataset_dir"] + "test.smi"
221 | constants_dict["training_set"] = parameters["dataset_dir"] + "train.smi"
222 | constants_dict["validation_set"] = parameters["dataset_dir"] + "valid.smi"
223 |
224 | # check (if a job is not a preprocessing job) that parameters match those for
225 | # the original preprocessing job
226 | if constants_dict["job_type"] != "preprocess":
227 | print(
228 | "* Running job using HDF datasets located at "
229 | + parameters["dataset_dir"],
230 | flush=True,
231 | )
232 | print(
233 | "* Checking that the relevant parameters match "
234 | "those used in preprocessing the dataset.",
235 | flush=True,
236 | )
237 |
238 | try:
239 | #load preprocessing parameters for comparison (if they exist already)
240 | csv_file = parameters["dataset_dir"] + "preprocessing_params.csv"
241 | params_to_check = load_params(input_csv_path=csv_file)
242 |
243 | for key, value in params_to_check.items():
244 | if key in constants_dict.keys() and value != constants_dict[key]:
245 | raise ValueError(
246 | f"Check that training job parameters match those used in "
247 | f"preprocessing. {key} does not match."
248 | )
249 |
250 | # if above error never raised, then all relevant parameters match! :)
251 | print("-- Job parameters match preprocessing parameters.", flush=True)
252 |         except FileNotFoundError:
253 |             print("-- Preprocessing parameters file does not exist for comparison.", flush=True)
254 |
255 | # load QSAR models (sklearn activity model)
256 | if constants_dict["job_type"] == "fine-tune":
257 | # print("-- Loading pre-trained scikit-learn activity model.", flush=True)
258 | # for qsar_model_name, qsar_model_path in constants_dict["qsar_models"].items():
259 | # with open(qsar_model_path, 'rb') as file:
260 | # model_dict = pickle.load(file)
261 | # activity_model = model_dict["classifier_sv"]
262 | # constants_dict["qsar_models"][qsar_model_name] = activity_model
263 | print("-- Loading pre-trained gbm activity model.", flush=True)
264 | scoring_model_activity = pickle.load(open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/Protac_Scoring_Model_1024_100nM.pkl', 'rb'))
265 | scoring_model_structure = pickle.load(open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/Protac_Scoring_Model_Structure.pkl', 'rb'))
266 | with open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/features_1024_100nM.pkl','rb') as fp:
267 | features_activity = pickle.load(fp)
268 | with open('/home/gridsan/dnori/GraphINVENT/data/protac_scoring_models/features_structure.pkl','rb') as fp:
269 | features_structure = pickle.load(fp)
270 | constants_dict["qsar_models"]["protac_qsar_model"] = scoring_model_activity
271 | constants_dict["protac_structure_model"] = scoring_model_structure
272 | constants_dict["activity_model_features"] = features_activity
273 | constants_dict["structure_model_features"] = features_structure
274 |
275 | # convert `CONSTANTS` dictionary into a namedtuple (immutable + cleaner)
276 | Constants = namedtuple("CONSTANTS", sorted(constants_dict))
277 | constants = Constants(**constants_dict)
278 |
279 | return constants
280 |
281 | # collect the constants using the functions defined above
282 | #change job_dir back to args.job_dir - ask Rocio how to get from arg parser
283 | constants = collect_global_constants(parameters=defaults.parameters, job_dir=args.job_dir)
--------------------------------------------------------------------------------
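
A small, self-contained sketch of the `load_params` parsing behaviour defined in `constants.py` above (an editor's illustration using an in-memory CSV rather than a file on disk; the key/value pairs are made up):

    import csv, io

    text = "epochs;400\njob_type;fine-tune\nformal_charge;[-1, 0, 1]\n"
    params = {}
    for key, value in csv.reader(io.StringIO(text), delimiter=";"):
        try:
            params[key] = eval(value)         # numbers and list literals parse normally
        except (NameError, SyntaxError):      # plain strings are kept as-is
            params[key] = value

    print(params)   # {'epochs': 400, 'job_type': 'fine-tune', 'formal_charge': [-1, 0, 1]}
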
/GraphINVENT_Protac/graphinvent/parameters/defaults.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines default model parameters, hyperparameters, and settings.
3 | Recommended not to modify the default settings here, but rather create an input
4 | file with the modified parameters in a new job directory (see README). **Used
5 | as an alternative to using argparser, as there are many variables.**
6 | """
7 | # load general packages and functions
8 | import sys
9 |
10 | # load GraphINVENT-specific functions
11 | sys.path.insert(1, "./parameters/") # search "parameters/" directory
12 | import parameters.args as args
13 | import parameters.load as load
14 |
15 |
16 | # default parameters defined below
17 | """
18 | General settings for the generative model:
19 | atom_types (list) : Contains atom types (str) to encode in node features.
20 | formal_charge (list) : Contains formal charges (int) to encode in node
21 | features.
22 | imp_H (list) : Contains number of implicit hydrogens (int) to encode
23 | in node features.
24 | chirality (list) : Contains chiral states (str) to encode in node features.
25 | device (str) : Specifies type of architecture to run on. Options:
26 |                             "cuda" or "cpu".
27 | generation_epoch (int) : Epoch to sample during a 'generation' job.
28 | n_samples (int) : Number of molecules to generate during each sampling
29 | epoch. Note: if `n_samples` > 100000 molecules, these
30 | will be generated in batches of 100000.
31 | n_workers (int) : Number of subprocesses to use during data loading.
32 | restart (bool) : If specified, will restart training from previous saved
33 | state. Can only be used for preprocessing or training
34 | jobs.
35 | max_n_nodes (int) : Maximum number of allowed nodes in graph. Must be
36 | greater than or equal to the number of nodes in
37 | largest graph in training set.
38 | job_type (str) : Type of job to run; options: 'preprocess', 'train',
39 | 'generate', 'test', or 'fine-tune'.
40 | sample_every (int) : Specifies when to sample the model (i.e. epochs
41 | between sampling).
42 | dataset_dir (str) : Full path to directory containing testing ("test.smi"),
43 | training ("train.smi"), and validation ("valid.smi")
44 | sets.
45 | use_aromatic_bonds (bool) : If specified, aromatic bond types will be used.
46 | use_canon (bool) : If specified, uses canonical RDKit ordering in graph
47 | representations.
48 | use_chirality (bool) : If specified, includes chirality in the atomic
49 | representations.
50 | use_explicit_H (bool) : If specified, uses explicit Hs in molecular
51 | representations (not recommended for most applications).
52 | ignore_H (bool) : If specified, ignores H's completely in graph
53 |                             representations (treats them neither as explicit nor
54 | implicit). When generating graphs, H's are added to
55 | graphs after generation is terminated.
56 | use_tensorboard (bool) : If specified, enables the use of tensorboard during
57 | training.
58 | tensorboard_dir (str) : Path to directory in which to write tensorboard
59 |                             logs.
60 | batch_size (int) : Number of graphs in a mini-batch. When preprocessing
61 | graphs, this is the size of the preprocessing groups
62 | (e.g. how many subgraphs preprocessed at once).
63 | epochs (int) : Number of training epochs.
64 | init_lr (float) : Initial learning rate.
65 | max_rel_lr (float) : Maximum allowed learning rate relative to the initial
66 | (used for learning rate ramp-up).
67 | model (str) : MPNN model to use ('MNN', 'S2V', 'AttS2V', 'GGNN',
68 | 'AttGGNN', or 'EMN').
69 | decoding_route (str) : Breadth-first search ("bfs") or depth-first search
70 | ("dfs").
71 | score_components (list) : A list of all the components to use in the RL scoring
72 | function. Can include "target_size={int}", "QED",
73 | "{name}_activity".
74 | score_thresholds (list) : Acceptable thresholds for the above score components.
75 | score_type (str) : If there are multiple components used in the scoring
76 | function, determines if the final score should be
77 | "continuous" (in which case, the above thresholds
78 | are ignored), or "binary" (in which case a generated
79 | molecule will receive a score of 1 iff all its score
80 | components are greater than the specified thresholds).
81 | qsar_models (dict) : A dictionary containing the path to each activity
82 | model specified in `score_components`. Note that
83 | the key in this dict must correspond to the name
84 | of the score component.
85 | sigma (float) : Can take any value. Tunes the contribution of the
86 | score in the augmented log-likelihood. See Atance
87 | et al (2021) https://doi.org/10.33774/chemrxiv-2021-9w3tc
88 | for suitable values.
89 | alpha (float) : Can take values between [0.0, 1.0]. Tunes the contribution
90 | from the best agent so far (BASF) in the loss.
91 | """
92 | # general job parameters
93 | parameters = {
94 | "atom_types" : ["C", "N", "O", "F", "S", "Cl", "Br"],
95 | "formal_charge" : [-1, 0, 1],
96 | "imp_H" : [0, 1, 2, 3],
97 | "chirality" : ["None", "R", "S"],
98 | "device" : "cuda",
99 | "generation_epoch" : 30,
100 | "n_samples" : 2000,
101 | "n_workers" : 2,
102 | "restart" : True,
103 | "max_n_nodes" : 27,
104 | "job_type" : "preprocess", # "preprocess" -> "train" -> "fine-tune" -> "generate"
105 | "sample_every" : 10,
106 | "dataset_dir" : "data/pre-training/MOSES/",
107 | "use_aromatic_bonds" : False,
108 | "use_canon" : True,
109 | "use_chirality" : False,
110 | "use_explicit_H" : False,
111 | "ignore_H" : True,
112 | "compute_train_csv" : True,
113 | "tensorboard_dir" : "tensorboard/",
114 | "batch_size" : 1000,
115 | "block_size" : 100000,
116 | "epochs" : 100,
117 | "init_lr" : 1e-4,
118 | "max_rel_lr" : 1,
119 | "min_rel_lr" : 0.0001,
120 | "decoding_route" : "bfs",
121 | "activity_model_dir" : "data/fine-tuning/",
122 | "score_components" : ["QED", "drd2_activity", "target_size=13"],
123 | "score_thresholds" : [0.5, 0.5, 0.0], # 0.0 essentially means no threshold
124 | "score_type" : "binary",
125 | "qsar_models" : {"drd2_activity": "data/fine-tuning/qsar_model.pickle"},
126 | "pretrained_model_dir": "output/",
127 | "sigma" : 20,
128 | "alpha" : 0.5,
129 | }
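# Note on the scoring settings above (based on the docstring): "score_components"
# and "score_thresholds" pair up element-wise, so with score_type = "binary" a
# generated molecule scores 1 only if QED > 0.5, drd2_activity > 0.5, and the
# target_size=13 component > 0.0 (effectively unthresholded), and 0 otherwise;
# with score_type = "continuous" the thresholds are ignored.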
130 |
131 | # make sure job dir ends in "/"
132 | if args.job_dir[-1] != "/":
133 | print("* Adding '/' to end of `job_dir`.")
134 | args.job_dir += "/"
135 |
136 | # get the model before loading model-specific hyperparameters
137 | try:
138 | input_csv_path = args.job_dir + "input.csv"
139 | model = load.which_model(input_csv_path=input_csv_path)
140 | except:
141 | model = "GGNN" # default model
142 | parameters["model"] = model
143 |
144 |
145 | # model-specific hyperparameters (implementation-specific)
146 | if parameters["model"] == "MNN":
147 | """
148 | MNN hyperparameters:
149 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`.
150 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`.
151 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier MLP
152 | in `APDReadout`.
153 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`.
154 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`.
155 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier MLP
156 | in `APDReadout`.
157 | message_passes (int) : Number of message passing steps.
158 | message_size (int) : Size of message passed ('enn' MLP output size).
159 | """
160 | hyperparameters = {
161 | "mlp1_depth" : 4,
162 | "mlp1_dropout_p" : 0.0,
163 | "mlp1_hidden_dim" : 500,
164 | "mlp2_depth" : 4,
165 | "mlp2_dropout_p" : 0.0,
166 | "mlp2_hidden_dim" : 500,
167 | "hidden_node_features": 100,
168 | "message_passes" : 3,
169 | "message_size" : 100,
170 | }
171 | elif parameters["model"] == "S2V":
172 | """
173 | S2V hyperparameters:
174 | enn_depth (int) : Num layers in 'enn' MLP.
175 | enn_dropout_p (float) : Dropout probability in 'enn' MLP.
176 | enn_hidden_dim (int) : Number of weights (layer width) in 'enn' MLP.
177 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`.
178 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`.
179 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier
180 | MLP in `APDReadout`.
181 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`.
182 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`.
183 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier
184 | MLP in `APDReadout`.
185 | message_passes (int) : Number of message passing steps.
186 | message_size (int) : Size of message passed (input size to `GRU`).
187 | s2v_lstm_computations (int) : Number of LSTM computations (loop) in S2V readout.
188 | s2v_memory_size (int) : Number of input features and hidden state size
189 | in LSTM cell in S2V readout.
190 | """
191 | hyperparameters = {
192 | "enn_depth" : 4,
193 | "enn_dropout_p" : 0.0,
194 | "enn_hidden_dim" : 250,
195 | "mlp1_depth" : 4,
196 | "mlp1_dropout_p" : 0.0,
197 | "mlp1_hidden_dim" : 500,
198 | "mlp2_depth" : 4,
199 | "mlp2_dropout_p" : 0.0,
200 | "mlp2_hidden_dim" : 500,
201 | "hidden_node_features" : 100,
202 | "message_passes" : 3,
203 | "message_size" : 100,
204 | "s2v_lstm_computations": 3,
205 | "s2v_memory_size" : 100,
206 | }
207 | elif parameters["model"] == "AttS2V":
208 | """
209 | AttS2V hyperparameters:
210 | att_depth (int) : Num layers in 'att_enn' MLP.
211 | att_dropout_p (float) : Dropout probability in 'att_enn' MLP.
212 | att_hidden_dim (int) : Number of weights (layer width) in 'att_enn'
213 | MLP.
214 | enn_depth (int) : Num layers in 'enn' MLP.
215 | enn_dropout_p (float) : Dropout probability in 'enn' MLP.
216 | enn_hidden_dim (int) : Number of weights (layer width) in 'enn' MLP.
217 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`.
218 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`.
219 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier
220 | MLP in `APDReadout`.
221 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`.
222 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`.
223 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier
224 | MLP in `APDReadout`.
225 | message_passes (int) : Number of message passing steps.
226 | message_size (int) : Size of message passed (output size of 'att_enn'
227 | MLP, input size to `GRU`).
228 | s2v_lstm_computations (int) : Number of LSTM computations (loop) in S2V readout.
229 | s2v_memory_size (int) : Number of input features and hidden state size
230 | in LSTM cell in S2V readout.
231 | """
232 | hyperparameters = {
233 | "att_depth" : 4,
234 | "att_dropout_p" : 0.0,
235 | "att_hidden_dim" : 250,
236 | "enn_depth" : 4,
237 | "enn_dropout_p" : 0.0,
238 | "enn_hidden_dim" : 250,
239 | "mlp1_depth" : 4,
240 | "mlp1_dropout_p" : 0.0,
241 | "mlp1_hidden_dim" : 500,
242 | "mlp2_depth" : 4,
243 | "mlp2_dropout_p" : 0.0,
244 | "mlp2_hidden_dim" : 500,
245 | "hidden_node_features" : 100,
246 | "message_passes" : 3,
247 | "message_size" : 100,
248 | "s2v_lstm_computations": 3,
249 | "s2v_memory_size" : 100,
250 | }
251 | elif parameters["model"] == "GGNN":
252 | """
253 | GGNN hyperparameters:
254 | enn_depth (int) : Num layers in 'enn' MLP.
255 | enn_dropout_p (float) : Dropout probability in 'enn' MLP.
256 | enn_hidden_dim (int) : Number of weights (layer width) in 'enn' MLP.
257 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`.
258 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`.
259 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier
260 | MLP in `APDReadout`.
261 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`.
262 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`.
263 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier
264 | MLP in `APDReadout`.
265 | gather_att_depth (int) : Num layers in 'gather_att' MLP in `GraphGather`.
266 | gather_att_dropout_p (float) : Dropout probability in 'gather_att' MLP in
267 | `GraphGather`.
268 | gather_att_hidden_dim (int) : Number of weights (layer width) in 'gather_att'
269 | MLP in `GraphGather`.
270 | gather_emb_depth (int) : Num layers in 'gather_emb' MLP in `GraphGather`.
271 | gather_emb_dropout_p (float) : Dropout probability in 'gather_emb' MLP in
272 | `GraphGather`.
273 | gather_emb_hidden_dim (int) : Number of weights (layer width) in 'gather_emb'
274 | MLP in `GraphGather`.
275 | gather_width (int) : Output size of `GraphGather` block.
276 | message_passes (int) : Number of message passing steps.
277 | message_size (int) : Size of message passed (output size of all
278 | MLPs in message aggregation step, input size
279 | to `GRU`).
280 | """
281 | hyperparameters = {
282 | "enn_depth" : 4,
283 | "enn_dropout_p" : 0.0,
284 | "enn_hidden_dim" : 250,
285 | "mlp1_depth" : 4,
286 | "mlp1_dropout_p" : 0.0,
287 | "mlp1_hidden_dim" : 500,
288 | "mlp2_depth" : 4,
289 | "mlp2_dropout_p" : 0.0,
290 | "mlp2_hidden_dim" : 500,
291 | "gather_att_depth" : 4,
292 | "gather_att_dropout_p" : 0.0,
293 | "gather_att_hidden_dim": 250,
294 | "gather_emb_depth" : 4,
295 | "gather_emb_dropout_p" : 0.0,
296 | "gather_emb_hidden_dim": 250,
297 | "gather_width" : 100,
298 | "hidden_node_features" : 100,
299 | "message_passes" : 3,
300 | "message_size" : 100,
301 | }
302 | elif parameters["model"] == "AttGGNN":
303 | """
304 | AttGGNN hyperparameters:
305 | att_depth (int) : Num layers in 'att_nns' MLP (message aggregation
306 | step).
307 | att_dropout_p (float) : Dropout probability in 'att_nns' MLP (message
308 | aggregation step).
309 | att_hidden_dim (int) : Number of weights (layer width) in 'att_nns'
310 | MLP (message aggregation step).
311 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`.
312 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`.
313 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier
314 | MLP in `APDReadout`.
315 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`.
316 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`.
317 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier
318 | MLP in `APDReadout`.
319 | gather_att_depth (int) : Num layers in 'gather_att' MLP in `GraphGather`.
320 | gather_att_dropout_p (float) : Dropout probability in 'gather_att' MLP in
321 | `GraphGather`.
322 | gather_att_hidden_dim (int) : Number of weights (layer width) in 'gather_att'
323 | MLP in `GraphGather`.
324 | gather_emb_depth (int) : Num layers in 'gather_emb' MLP in `GraphGather`.
325 | gather_emb_dropout_p (float) : Dropout probability in 'gather_emb' MLP in
326 | `GraphGather`.
327 | gather_emb_hidden_dim (int) : Number of weights (layer width) in 'gather_emb'
328 | MLP in `GraphGather`.
329 | gather_width (int) : Output size of `GraphGather` block.
330 | message_passes (int) : Number of message passing steps.
331 | message_size (int) : Size of message passed (output size of all
332 | MLPs in message aggregation step, input size
333 | to `GRU`).
334 | msg_depth (int) : Num layers in 'msg_nns' MLP (message aggregation
335 | step).
336 | msg_dropout_p (float) : Dropout probability in 'msg_nns' MLP (message
337 | aggregation step).
338 | msg_hidden_dim (int) : Number of weights (layer width) in 'msg_nns'
339 | MLP (message aggregation step).
340 | """
341 | hyperparameters = {
342 | "att_depth" : 4,
343 | "att_dropout_p" : 0.0,
344 | "att_hidden_dim" : 250,
345 | "mlp1_depth" : 4,
346 | "mlp1_dropout_p" : 0.0,
347 | "mlp1_hidden_dim" : 500,
348 | "mlp2_depth" : 4,
349 | "mlp2_dropout_p" : 0.0,
350 | "mlp2_hidden_dim" : 500,
351 | "gather_att_depth" : 4,
352 | "gather_att_dropout_p" : 0.0,
353 | "gather_att_hidden_dim": 250,
354 | "gather_emb_depth" : 4,
355 | "gather_emb_dropout_p" : 0.0,
356 | "gather_emb_hidden_dim": 250,
357 | "gather_width" : 100,
358 | "hidden_node_features" : 100,
359 | "message_passes" : 3,
360 | "message_size" : 100,
361 | "msg_depth" : 4,
362 | "msg_dropout_p" : 0.0,
363 | "msg_hidden_dim" : 250,
364 | }
365 | elif parameters["model"] == "EMN":
366 | """
367 | EMN hyperparameters:
368 | att_depth (int) : Num layers in 'att_msg_nn' MLP (edge propagation
369 | step).
370 | att_dropout_p (float) : Dropout probability in 'att_msg_nn' MLP (edge
371 | propagation step).
372 | att_hidden_dim (int) : Number of weights (layer width) in 'att_msg_nn'
373 | MLP (edge propagation step).
374 | edge_emb_depth (int) : Num layers in 'embedding_nn' MLP (edge processing
375 | step).
376 | edge_emb_dropout_p (float) : Dropout probability in 'embedding_nn' MLP (edge
377 | processing step).
378 | edge_emb_hidden_dim (int) : Number of weights (layer width) in 'embedding_nn'
379 | MLP (edge processing step).
380 | edge_emb_size (int) : Output size of all MLPs in edge propagation
381 | and processing steps (input size to `GraphGather`).
382 | mlp1_depth (int) : Num layers in first-tier MLP in `APDReadout`.
383 | mlp1_dropout_p (float) : Dropout probability in first-tier MLP in `APDReadout`.
384 | mlp1_hidden_dim (int) : Number of weights (layer width) in first-tier
385 | MLP in `APDReadout`.
386 | mlp2_depth (int) : Num layers in second-tier MLP in `APDReadout`.
387 | mlp2_dropout_p (float) : Dropout probability in second-tier MLP in `APDReadout`.
388 | mlp2_hidden_dim (int) : Number of weights (layer width) in second-tier
389 | MLP in `APDReadout`.
390 | gather_att_depth (int) : Num layers in 'gather_att' MLP in `GraphGather`.
391 | gather_att_dropout_p (float) : Dropout probability in 'gather_att' MLP in
392 | `GraphGather`.
393 | gather_att_hidden_dim (int) : Number of weights (layer width) in 'gather_att'
394 | MLP in `GraphGather`.
395 | gather_emb_depth (int) : Num layers in 'gather_emb' MLP in `GraphGather`.
396 | gather_emb_dropout_p (float) : Dropout probability in 'gather_emb' MLP in
397 | `GraphGather`.
398 | gather_emb_hidden_dim (int) : Number of weights (layer width) in 'gather_emb'
399 | MLP in `GraphGather`.
400 | gather_width (int) : Output size of `GraphGather` block.
401 | message_passes (int) : Number of message passing steps.
402 | msg_depth (int) : Num layers in 'emb_msg_nn' MLP (edge propagation
403 | step).
404 |         msg_dropout_p (float)         : Dropout probability in 'emb_msg_nn' MLP (edge
405 | propagation step).
406 | msg_hidden_dim (int) : Number of weights (layer width) in 'emb_msg_nn'
407 | MLP (edge propagation step).
408 | """
409 | hyperparameters = {
410 | "att_depth" : 4,
411 | "att_dropout_p" : 0.0,
412 | "att_hidden_dim" : 250,
413 | "edge_emb_depth" : 4,
414 | "edge_emb_dropout_p" : 0.0,
415 | "edge_emb_hidden_dim" : 250,
416 | "edge_emb_size" : 100,
417 | "mlp1_depth" : 4,
418 | "mlp1_dropout_p" : 0.0,
419 | "mlp1_hidden_dim" : 500,
420 | "mlp2_depth" : 4,
421 | "mlp2_dropout_p" : 0.0,
422 | "mlp2_hidden_dim" : 500,
423 | "gather_att_depth" : 4,
424 | "gather_att_dropout_p" : 0.0,
425 | "gather_att_hidden_dim": 250,
426 | "gather_emb_depth" : 4,
427 | "gather_emb_dropout_p" : 0.0,
428 | "gather_emb_hidden_dim": 250,
429 | "gather_width" : 100,
430 | "message_passes" : 3,
431 | "msg_depth" : 4,
432 | "msg_dropout_p" : 0.0,
433 | "msg_hidden_dim" : 250,
434 | }
435 |
436 | # make sure dataset dir ends in "/"
437 | if parameters["dataset_dir"][-1] != "/":
438 | print("* Adding '/' to end of `dataset_dir`.")
439 | parameters["dataset_dir"] += "/"
440 |
441 | # join dictionaries
442 | parameters.update(hyperparameters)
443 |
--------------------------------------------------------------------------------
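A minimal sketch (illustrative only, not a file in the repo) of how the model-specific defaults above are consumed: the `model` entry selects one `hyperparameters` dictionary, a trailing "/" is enforced on `dataset_dir`, and the two dictionaries are then merged with `parameters.update(hyperparameters)`. The values below are a small subset of the GGNN defaults shown above.

    # illustrative sketch of the merge performed at the end of defaults.py
    parameters = {"model": "GGNN", "dataset_dir": "data/pre-training/MOSES"}

    if parameters["model"] == "GGNN":
        hyperparameters = {"enn_depth": 4, "message_passes": 3, "message_size": 100}

    # make sure dataset dir ends in "/"
    if parameters["dataset_dir"][-1] != "/":
        parameters["dataset_dir"] += "/"

    parameters.update(hyperparameters)   # model hyperparameters join the job parameters
    print(parameters["message_passes"])  # -> 3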
/GraphINVENT_Protac/graphinvent/parameters/load.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for loading molecules from SMILES, as well as loading the model type.
3 | """
4 | # load general packages and functions
5 | import csv
6 | import rdkit
7 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier
8 |
9 |
10 | def molecules(path : str) -> rdkit.Chem.rdmolfiles.SmilesMolSupplier:
11 | """
12 |     Reads a SMILES file (full path/filename specified by `path`) and returns
13 |     a supplier that yields `rdkit.Mol` objects.
14 | """
15 | # check first line of SMILES file to see if contains header
16 | with open(path) as smi_file:
17 | first_line = smi_file.readline()
18 | has_header = bool("SMILES" in first_line)
19 | smi_file.close()
20 |
21 | # read file
22 | molecule_set = SmilesMolSupplier(path,
23 | sanitize=True,
24 | nameColumn=-1,
25 | titleLine=has_header)
26 | return molecule_set
27 |
28 | def which_model(input_csv_path : str) -> str:
29 | """
30 | Gets the type of model to use by reading it from CSV (in "input.csv").
31 |
32 | Args:
33 | ----
34 | input_csv_path (str) : The full path/filename to "input.csv" file
35 | containing parameters to overwrite from defaults.
36 |
37 | Returns:
38 | -------
39 | value (str) : Name of model to use.
40 | """
41 | with open(input_csv_path, "r") as csv_file:
42 |
43 | params_reader = csv.reader(csv_file, delimiter=";")
44 |
45 | for key, value in params_reader:
46 | if key == "model":
47 | return value # string describing model e.g. "GGNN"
48 |
49 | raise ValueError("Model type not specified.")
50 |
--------------------------------------------------------------------------------
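A minimal sketch of the semicolon-delimited `input.csv` that `which_model` above parses. In a real job the rows are written by `write_input_csv` in the submission scripts, so the keys shown here are only illustrative.

    import csv

    # write a toy input.csv with the ";"-delimited key/value rows that which_model() reads
    with open("input.csv", "w") as csv_file:
        writer = csv.writer(csv_file, delimiter=";")
        writer.writerow(["job_type", "train"])  # illustrative row
        writer.writerow(["model", "GGNN"])      # the row which_model() searches for

    # which_model("input.csv") would then return "GGNN"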
/GraphINVENT_Protac/graphinvent/training_stats.py:
--------------------------------------------------------------------------------
1 | """
2 | Create train.csv containing summary stats on training set
3 |
4 | To use script, run:
5 | python graphinvent/training_stats.py
6 | """
7 | import random
8 | import csv
9 |
10 | random.seed(10)
11 |
12 | # def write_input_csv(params_dict : dict, filename : str="params.csv") -> None:
13 | # """
14 | # Writes job parameters/hyperparameters in `params_dict` to CSV using the specified
15 | # `filename`.
16 | # """
17 | # print("in write input csv")
18 | # dict_path = params_dict["job_dir"] + filename
19 |
20 | # with open(dict_path, "w") as csv_file:
21 |
22 | # writer = csv.writer(csv_file, delimiter=";")
23 | # for key, value in params_dict.items():
24 | # writer.writerow([key, value])
25 |
26 | # params = {"compute_train_csv": True, "job_dir": "data/pre-training/MOSES/"}
27 | # write_input_csv(params_dict=params, filename="input.csv")
28 |
29 | from rdkit import Chem
30 |
31 | from DataProcesser import DataProcesser
32 | import util
33 | import torch
34 | from parameters.constants import constants
35 |
36 | from torch.utils.tensorboard import SummaryWriter
37 |
38 | # read the SMILES strings from the training file
39 | orig_smi_file_path = 'data/pre-training/MOSES/train.smi'
40 | list_of_smiles = []
41 | with open(orig_smi_file_path, 'r') as f:
42 | for line in f.readlines():
43 | words = line.split()
44 | list_of_smiles.append(words[0])
45 |
46 | # The first entry in list_of_smiles is 'SMILES', not a SMILES string
47 | list_of_smiles.pop(0)
48 |
49 | train_smiles = list_of_smiles[:92]
50 | test_smiles = list_of_smiles[92:109]
51 | valid_smiles = list_of_smiles[109:]
52 |
53 | with open('data/pre-training/protac_db_subset_50/train.smi', 'w+') as f:
54 | for smi in train_smiles:
55 | f.write("{}\n".format(smi))
56 | with open('data/pre-training/protac_db_subset_50/test.smi', 'w+') as f:
57 | for smi in test_smiles:
58 | f.write("{}\n".format(smi))
59 | with open('data/pre-training/protac_db_subset_50/valid.smi', 'w+') as f:
60 | for smi in valid_smiles:
61 | f.write("{}\n".format(smi))
62 |
63 |
64 | #convert smiles to molecular graphs
65 | mols = [Chem.MolFromSmiles(smi) for smi in list_of_smiles]
66 | processor = DataProcesser(path = 'data/pre-training/MOSES/train.smi', is_training_set=True, molset = mols)
67 | graphs = [processor.get_graph(mol) for mol in mols]
68 |
69 | processor.get_ts_properties(molecular_graphs=graphs, group_size=10000)
70 | print(processor.ts_properties)
71 |
72 | # write the training-set properties to train.csv
73 | # (uses `properties_to_csv` from util.py)
74 | writer = SummaryWriter()
75 | util.properties_to_csv(processor.ts_properties, 'data/pre-training/MOSES/train.csv', 'Training set', writer)
76 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/analyze_final_epoch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import rdkit
4 | from rdkit import Chem
5 | from rdkit.Chem.Draw import MolsToGridImage
6 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier
7 |
8 | smi_file = "/home/gridsan/dnori/GraphINVENT/output_MOSES_subset/example-job-name/job_0/generation/aggregated_generation/epoch_100.smi"
9 | train_smi = "/home/gridsan/dnori/GraphINVENT/data/pre-training/MOSES_subset/train.smi"
10 | #smi_file = "/home/gridsan/dnori/GraphINVENT/output_protac_db_subset_60/example-job-name/job_0/generation/aggregated_generation/epoch_GEN80.smi"
11 |
12 | # # load molecules from file
13 | # mols = SmilesMolSupplier(smi_file, sanitize=True, nameColumn=-1,titleLine=True)
14 |
15 | # n_samples = 40
16 | # mols_list = [mol for mol in mols]
17 | # mols_sampled = random.sample(mols_list, n_samples) # sample 100 random molecules to visualize
18 |
19 | # mols_per_row = int(math.sqrt(n_samples)) # make a square grid
20 |
21 | # png_filename=smi_file[:-3] + "png" # name of PNG file to create
22 | # labels=list(range(n_samples)) # label structures with a number
23 |
24 | # # draw the molecules (creates a PIL image)
25 | # img = MolsToGridImage(mols=mols_sampled,
26 | # molsPerRow=mols_per_row,
27 | # legends=[str(i) for i in labels])
28 |
29 | # img.save(png_filename)
30 |
31 | #calculate regeneration percentage
32 |
33 | with open(smi_file, 'r') as f1:
34 | gen_smi = f1.readlines()
35 |
36 | with open (train_smi, 'r') as f2:
37 | tr_smi = f2.readlines()
38 |
39 | # canonicalize all SMILES in tr_smi
40 | for i in range(len(tr_smi)):
41 |     mol = Chem.MolFromSmiles(tr_smi[i].strip())
42 |     canon_smi = Chem.MolToSmiles(mol) if mol is not None else tr_smi[i]
43 |     if canon_smi != tr_smi[i]:
44 |         tr_smi[i] = canon_smi
45 |
46 | total = 0
47 | num = 0
48 | for s in gen_smi:
49 |     mol = Chem.MolFromSmiles(s.strip())
50 |     if mol is None:
51 |         continue  # skip header or invalid lines
52 |     total += 1
53 |     canon_s = Chem.MolToSmiles(mol)
54 |     if canon_s not in tr_smi:
55 |         num += 1
56 |
57 | print(f"percentage of new molecules: {100 * num / total:.2f}%")
58 |
--------------------------------------------------------------------------------
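The membership test `canon_s not in tr_smi` above scans the whole training list for every generated molecule; a minimal sketch (same RDKit dependency, the function name `fraction_new` is only illustrative) of the same regeneration check using a set of canonical training SMILES, so each lookup is O(1):

    from rdkit import Chem

    def fraction_new(gen_smiles, train_smiles):
        """Fraction of generated SMILES whose canonical form is absent from the training set."""
        def canon(s):
            mol = Chem.MolFromSmiles(s.strip())
            return Chem.MolToSmiles(mol) if mol is not None else None

        train_canon = {c for c in (canon(s) for s in train_smiles) if c}  # set -> O(1) lookups
        gen_canon   = [c for c in (canon(s) for s in gen_smiles) if c]
        return sum(c not in train_canon for c in gen_canon) / len(gen_canon)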
/GraphINVENT_Protac/tools/atom_types.py:
--------------------------------------------------------------------------------
1 | """
2 | Gets the atom types present in a set of molecules.
3 |
4 | To use script, run:
5 | python atom_types.py --smi path/to/file.smi
6 | """
7 | import argparse
8 | import rdkit
9 | from utils import load_molecules
10 |
11 |
12 | # define the argument parser
13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
14 | add_help=False)
15 |
16 | # define the argument to use when reading SMILES from a file
17 | parser.add_argument("--smi",
18 | type=str,
19 | default="data/gdb13_1K/train.smi",
20 | help="SMILES file containing molecules to analyse.")
21 | args = parser.parse_args()
22 |
23 |
24 | def get_atom_types(smi_file : str) -> list:
25 | """
26 | Determines the atom types present in an input SMILES file.
27 |
28 | Args:
29 | ----
30 | smi_file (str) : Full path/filename to SMILES file.
31 | """
32 | molecules = load_molecules(path=smi_file)
33 |
34 | # create a list of all the atom types
35 | atom_types = list()
36 | for mol in molecules:
37 | for atom in mol.GetAtoms():
38 | atom_types.append(atom.GetAtomicNum())
39 |
40 | # remove duplicate atom types then sort by atomic number
41 | set_of_atom_types = set(atom_types)
42 | atom_types_sorted = list(set_of_atom_types)
43 | atom_types_sorted.sort()
44 |
45 | # return the symbols, for convenience
46 | return [rdkit.Chem.Atom(atom).GetSymbol() for atom in atom_types_sorted]
47 |
48 |
49 | if __name__ == "__main__":
50 | atom_types = get_atom_types(smi_file=args.smi)
51 | print("* Atom types present in input file:", atom_types, flush=True)
52 | print("Done.", flush=True)
53 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/combine_HDFs.py:
--------------------------------------------------------------------------------
1 | """
2 | Combines preprocessed HDF files. Useful when preprocessing large datasets, as
3 | one can split the `{split}.smi` into multiple files (and directories), preprocess
4 | them separately, and then combine using this script.
5 |
6 | To use script, modify the variables below to automatically create a list of
7 | paths **assuming** HDFs were created with the following directory structure:
8 | data/
9 | |-- {dataset}_1/
10 | |-- {dataset}_2/
11 | |-- {dataset}_3/
12 | |...
13 | |-- {dataset}_{n_dirs}/
14 |
15 | The variables are also used in setting the dimensions of the HDF datasets later on.
16 |
17 | If directories were not named as above, then simply replace `path_list` below
18 | with a list of the paths to all the HDFs to combine.
19 |
20 | Then, run:
21 | python combine_HDFs.py
22 | """
23 | import csv
24 | import numpy as np
25 | import h5py
26 | import torch
27 | from typing import Union
28 |
29 |
30 | def load_ts_properties_from_csv(csv_path : str) -> Union[dict, None]:
31 | """
32 | Loads CSV file containing training set properties and returns contents as a dictionary.
33 | """
34 | print("* Loading training set properties.", flush=True)
35 |
36 | # read dictionaries from csv
37 | try:
38 | with open(csv_path, "r") as csv_file:
39 | reader = csv.reader(csv_file, delimiter=";")
40 | csv_dict = dict(reader)
41 | except:
42 | return None
43 |
44 | # fix file types within dict in going from `csv_dict` --> `properties_dict`
45 | properties_dict = {}
46 | for key, value in csv_dict.items():
47 |
48 | # first determine if key is a tuple
49 | key = eval(key)
50 | if len(key) > 1:
51 | tuple_key = (str(key[0]), str(key[1]))
52 | else:
53 | tuple_key = key
54 |
55 | # then convert the values to the correct data type
56 | try:
57 | properties_dict[tuple_key] = eval(value)
58 | except (SyntaxError, NameError):
59 | properties_dict[tuple_key] = value
60 |
61 | # convert any `list`s to `torch.Tensor`s (for consistency)
62 | if type(properties_dict[tuple_key]) == list:
63 | properties_dict[tuple_key] = torch.Tensor(properties_dict[tuple_key])
64 |
65 | return properties_dict
66 |
67 | def write_ts_properties_to_csv(ts_properties_dict : dict) -> None:
68 | """
69 | Writes the training set properties in `ts_properties_dict` to a CSV file.
70 | """
71 | dict_path = f"data/{dataset}/{split}.csv"
72 |
73 | with open(dict_path, "w") as csv_file:
74 |
75 | csv_writer = csv.writer(csv_file, delimiter=";")
76 | for key, value in ts_properties_dict.items():
77 | if "validity_tensor" in key:
78 | continue # skip writing the validity tensor because it is really long
79 | elif type(value) == np.ndarray:
80 | csv_writer.writerow([key, list(value)])
81 | elif type(value) == torch.Tensor:
82 | try:
83 | csv_writer.writerow([key, float(value)])
84 | except ValueError:
85 | csv_writer.writerow([key, [float(i) for i in value]])
86 | else:
87 | csv_writer.writerow([key, value])
88 |
89 | def get_dims() -> dict:
90 | """
91 | Gets the dims corresponding to the three datasets in each preprocessed HDF
92 | file: "nodes", "edges", and "APDs".
93 | """
94 | dims = {}
95 | dims["nodes"] = [max_n_nodes, n_atom_types + n_formal_charges]
96 | dims["edges"] = [max_n_nodes, max_n_nodes, n_bond_types]
97 | dim_f_add = [max_n_nodes, n_atom_types, n_formal_charges, n_bond_types]
98 | dim_f_conn = [max_n_nodes, n_bond_types]
99 | dims["APDs"] = [np.prod(dim_f_add) + np.prod(dim_f_conn) + 1]
100 |
101 | return dims
102 |
103 | def get_total_n_subgraphs(paths : list) -> int:
104 | """
105 | Gets the total number of subgraphs saved in all the HDF files in the `paths`,
106 | where `paths` is a list of strings containing the path to each HDF file we want
107 | to combine.
108 | """
109 | total_n_subgraphs = 0
110 | for path in paths:
111 | print("path:", path)
112 | hdf_file = h5py.File(path, "r")
113 | nodes = hdf_file.get("nodes")
114 | n_subgraphs = nodes.shape[0]
115 | total_n_subgraphs += n_subgraphs
116 | hdf_file.close()
117 |
118 | return total_n_subgraphs
119 |
120 | def main(paths : list, training_set : bool) -> None:
121 | """
122 | Combine many small HDF files (their paths defined in `paths`) into one large HDF file.
123 | """
124 | total_n_subgraphs = get_total_n_subgraphs(paths)
125 | dims = get_dims()
126 |
127 | print(f"* Creating HDF file to contain {total_n_subgraphs} subgraphs")
128 | new_hdf_file = h5py.File(f"data/{dataset}/{split}.h5", "a")
129 | new_dataset_nodes = new_hdf_file.create_dataset("nodes",
130 | (total_n_subgraphs, *dims["nodes"]),
131 | dtype=np.dtype("int8"))
132 | new_dataset_edges = new_hdf_file.create_dataset("edges",
133 | (total_n_subgraphs, *dims["edges"]),
134 | dtype=np.dtype("int8"))
135 | new_dataset_APDs = new_hdf_file.create_dataset("APDs",
136 | (total_n_subgraphs, *dims["APDs"]),
137 | dtype=np.dtype("int8"))
138 |
139 | print("* Combining data from smaller HDFs into a new larger HDF.")
140 | init_index = 0
141 | for path in paths:
142 | print("path:", path)
143 | hdf_file = h5py.File(path, "r")
144 |
145 | nodes = hdf_file.get("nodes")
146 | edges = hdf_file.get("edges")
147 | APDs = hdf_file.get("APDs")
148 |
149 | n_subgraphs = nodes.shape[0]
150 |
151 | new_dataset_nodes[init_index:(init_index + n_subgraphs)] = nodes
152 | new_dataset_edges[init_index:(init_index + n_subgraphs)] = edges
153 | new_dataset_APDs[init_index:(init_index + n_subgraphs)] = APDs
154 |
155 | init_index += n_subgraphs
156 | hdf_file.close()
157 |
158 | new_hdf_file.close()
159 |
160 | if training_set:
161 | print(f"* Combining data from respective `{split}.csv` files into one.")
162 | csv_list = [f"{path[:-2]}csv" for path in paths]
163 |
164 | ts_properties_old = None
165 | csv_files_processed = 0
166 | for path in csv_list:
167 | ts_properties = load_ts_properties_from_csv(csv_path=path)
168 | ts_properties_new = {}
169 | if ts_properties_old and ts_properties:
170 | for key, value in ts_properties_old.items():
171 | if type(value) == float:
172 | ts_properties_new[key] = (
173 | value * csv_files_processed + ts_properties[key]
174 | )/(csv_files_processed + 1)
175 | else:
176 | new_list = []
177 | for i, value_i in enumerate(value):
178 | new_list.append(
179 | float(
180 | value_i * csv_files_processed + ts_properties[key][i]
181 | )/(csv_files_processed + 1)
182 | )
183 | ts_properties_new[key] = new_list
184 | else:
185 | ts_properties_new = ts_properties
186 | ts_properties_old = ts_properties_new
187 | csv_files_processed += 1
188 |
189 | write_ts_properties_to_csv(ts_properties_dict=ts_properties_new)
190 |
191 |
192 | if __name__ == "__main__":
193 | # combine the HDFs defined in `path_list`
194 |
195 | # set variables
196 | dataset = "ChEMBL"
197 | n_atom_types = 15 # number of atom types used in preprocessing the data
198 | n_formal_charges = 3 # number of formal charges used in preprocessing the data
199 | n_bond_types = 3 # number of bond types used in preprocessing the data
200 | max_n_nodes = 40 # maximum number of nodes in the data
201 |
202 | # combine the training files
203 | n_dirs = 12 # how many times was `{split}.smi` split?
204 | split = "train" # train, test, or valid
205 | path_list = [f"data/{dataset}_{i}/{split}.h5" for i in range(0, n_dirs)]
206 | main(path_list, training_set=True)
207 |
208 | # combine the test files
209 | n_dirs = 4 # how many times was `{split}.smi` split?
210 | split = "test" # train, test, or valid
211 | path_list = [f"data/{dataset}_{i}/{split}.h5" for i in range(0, n_dirs)]
212 | main(path_list, training_set=False)
213 |
214 | # combine the validation files
215 | n_dirs = 2 # how many times was `{split}.smi` split?
216 | split = "valid" # train, test, or valid
217 | path_list = [f"data/{dataset}_{i}/{split}.h5" for i in range(0, n_dirs)]
218 | main(path_list, training_set=False)
219 |
220 | print("Done.", flush=True)
221 |
--------------------------------------------------------------------------------
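As the module docstring notes, if the per-split HDFs do not follow the `{dataset}_{i}/` naming scheme, `path_list` can simply be replaced by an explicit list of paths in the `__main__` block. A minimal sketch with hypothetical settings and paths (the other module-level variables must still be set, since `get_dims` and the CSV helpers read them as globals):

    # hypothetical values -- substitute the real preprocessing settings and paths
    dataset          = "protac_db"
    split            = "train"
    n_atom_types     = 10
    n_formal_charges = 3
    n_bond_types     = 3
    max_n_nodes      = 100

    path_list = [
        "scratch/protac_part_a/train.h5",   # hypothetical path
        "scratch/protac_part_b/train.h5",   # hypothetical path
    ]
    main(path_list, training_set=True)      # writes data/protac_db/train.h5 (and train.csv)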
/GraphINVENT_Protac/tools/combine_generation_batches.py:
--------------------------------------------------------------------------------
1 | """
2 | Combines .likelihood, .smi, and .valid files across batches into one set of files
3 | per epoch (as written, only .smi files for the hard-coded epoch 'GEN140' are handled)
4 |
5 | Run:
6 | python combine_generation_batches.py
7 | """
8 |
9 | import os
10 |
11 | path_to_folder = "/home/gridsan/dnori/GraphINVENT/output_protac_db/example-job-name/job_0/generation/"
12 | extensions = ['likelihood','smi','valid']
13 |
14 |
15 | for filename in os.listdir(path_to_folder):
16 |     if 'GEN140' in filename:
17 |         try:
18 |             file_extension = filename[filename.index(".")+1:]
19 |         except ValueError:  # no "." in the name, i.e. it is a folder
20 |             file_extension = 'folder'
21 |         if file_extension == 'smi':
22 |             path = path_to_folder + filename
23 |             with open(path, 'r') as input_file:
24 |                 # extract the epoch number from the filename (used by the
25 |                 # commented-out per-epoch output path below)
26 |                 first_instance = filename.index("_")
27 |                 filename_sub = filename[first_instance+1:]
28 |                 second_instance = filename_sub.index('_')
29 |                 epoch_number = filename[first_instance+7:second_instance+1+first_instance]
30 |
31 |                 #with open(f"{path_to_folder}aggregated_generation/epoch_{epoch_number}.{file_extension}", 'a+') as f:
32 |                 new_path = path_to_folder + "aggregated_generation/epoch_GEN140_AGG.smi"
33 |                 with open(new_path, 'a+') as f:
34 |                     for line in input_file:
35 |                         if "SMILES" not in line:  # skip header lines
36 |                             f.write(line)
--------------------------------------------------------------------------------
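The commented-out `epoch_{epoch_number}` path above hints at aggregating every epoch rather than only the hard-coded 'GEN140'. A minimal sketch of that generalization, assuming batch files match a pattern like `*GEN<epoch>*.smi` (the regex and folder path here are hypothetical and would need to be adjusted to the actual filenames):

    import os
    import re

    path_to_folder = "/path/to/generation/"     # hypothetical location
    pattern = re.compile(r"GEN(\d+).*\.smi$")   # hypothetical filename pattern

    for filename in os.listdir(path_to_folder):
        match = pattern.search(filename)
        if match is None:
            continue
        epoch = match.group(1)
        out_path = os.path.join(path_to_folder, "aggregated_generation", f"epoch_GEN{epoch}_AGG.smi")
        with open(os.path.join(path_to_folder, filename)) as src, open(out_path, "a+") as dst:
            for line in src:
                if "SMILES" not in line:        # drop header lines
                    dst.write(line)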
/GraphINVENT_Protac/tools/formal_charges.py:
--------------------------------------------------------------------------------
1 | """
2 | Gets the formal charges present in a set of molecules.
3 |
4 | To use script, run:
5 | python formal_charges.py --smi path/to/file.smi
6 | """
7 | import argparse
8 | from utils import load_molecules
9 |
10 |
11 | # define the argument parser
12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
13 | add_help=False)
14 |
15 | # define the argument to use when reading SMILES from a file
16 | parser.add_argument("--smi",
17 | type=str,
18 | default="data/gdb13_1K/train.smi",
19 | help="SMILES file containing molecules to analyse.")
20 | args = parser.parse_args()
21 |
22 |
23 | def get_formal_charges(smi_file : str) -> list:
24 | """
25 | Determines the formal charges present in an input SMILES file.
26 |
27 | Args:
28 | ----
29 | smi_file (str) : Full path/filename to SMILES file.
30 | """
31 | molecules = load_molecules(path=smi_file)
32 |
33 | # create a list of all the formal charges
34 | formal_charges = list()
35 | for mol in molecules:
36 | for atom in mol.GetAtoms():
37 | formal_charges.append(atom.GetFormalCharge())
38 |
39 | # remove duplicate formal charges then sort
40 | set_of_formal_charges = set(formal_charges)
41 | formal_charges_sorted = list(set_of_formal_charges)
42 | formal_charges_sorted.sort()
43 |
44 | return formal_charges_sorted
45 |
46 |
47 | if __name__ == "__main__":
48 | formal_charges = get_formal_charges(smi_file=args.smi)
49 | print("* Formal charges present in input file:", formal_charges, flush=True)
50 | print("Done.", flush=True)
51 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/max_n_nodes.py:
--------------------------------------------------------------------------------
1 | """
2 | Gets the maximum number of nodes per molecule present in a set of molecules.
3 |
4 | To use script, run:
5 | python max_n_nodes.py --smi path/to/file.smi
6 | """
7 | import argparse
8 | from utils import load_molecules
9 |
10 |
11 | # define the argument parser
12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
13 | add_help=False)
14 |
15 | # define the argument to use when reading SMILES from a file
16 | parser.add_argument("--smi",
17 | type=str,
18 | default="data/gdb13_1K/train.smi",
19 | help="SMILES file containing molecules to analyse.")
20 | args = parser.parse_args()
21 |
22 |
23 | def get_max_n_atoms(smi_file : str) -> int:
24 | """
25 | Determines the maximum number of atoms per molecule in an input SMILES file.
26 |
27 | Args:
28 | ----
29 | smi_file (str) : Full path/filename to SMILES file.
30 | """
31 | molecules = load_molecules(path=smi_file)
32 |
33 | max_n_atoms = 0
34 | for mol in molecules:
35 | n_atoms = mol.GetNumAtoms()
36 |
37 | if n_atoms > max_n_atoms:
38 | max_n_atoms = n_atoms
39 |
40 | return max_n_atoms
41 |
42 |
43 | if __name__ == "__main__":
44 | max_n_atoms = get_max_n_atoms(smi_file=args.smi)
45 | print("* Max number of atoms in input file:", max_n_atoms, flush=True)
46 | print("Done.", flush=True)
47 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/split_filter_protac.py:
--------------------------------------------------------------------------------
1 | '''
2 | Filters data by max_n_nodes. Randomly splits into 80% train, 15% test, 5% valid
3 |
4 | python split_filter_protac.py --orig_smi path/to/file.smi --new_smi path/to/file.smi --threshold num
5 | '''
6 | import random
7 |
8 | import argparse
9 | from utils import load_molecules
10 | from rdkit import Chem
11 |
12 | # define the argument parser
13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
14 | add_help=False)
15 |
16 | # define the arguments for the input/output SMILES files and the size threshold
17 | parser.add_argument("--orig_smi",
18 | type=str,
19 | default="data/gdb13_1K/train.smi",
20 | help="path to original SMILES file.")
21 | parser.add_argument("--new_smi",
22 | type=str,
23 | default="data/gdb13_1K/train.smi",
24 | help="path to new SMILES file.")
25 | parser.add_argument("--threshold",
26 | type=int,
27 | default=100,
28 | help="All molecules with >n atoms will be filtered out.")
29 | args = parser.parse_args()
30 |
31 | def filter(threshold, original_smi_file, filtered_smi_file):
32 | #filter from original.smi to new smi file filtered by size
33 | #threshold num must be an integer less than 139
34 |
35 | molecules = load_molecules(path=original_smi_file)
36 |
37 | with open(filtered_smi_file, 'w+') as f:
38 | for mol in molecules:
39 | n_atoms = mol.GetNumAtoms()
40 | if n_atoms <= threshold:
41 | atoms = mol.GetAtoms()
42 |                 # collect the element symbols present in the molecule
43 |                 atom_symbols = {atom.GetSymbol() for atom in atoms}
44 |                 if 'I' not in atom_symbols and 'P' not in atom_symbols:
45 | smi = Chem.MolToSmiles(mol)
46 | f.write("{}\n".format(smi))
47 |
48 |
49 |
50 | def split(filtered_smi_file):
51 |
52 | orig_smiles = []
53 | with open(filtered_smi_file, 'r') as f:
54 | for line in f.readlines():
55 | words = line.split()
56 | orig_smiles.append(words[0])
57 |
58 | test_sz = int(.15*len(orig_smiles))
59 | valid_sz = int(.05*len(orig_smiles))
60 | train_sz = len(orig_smiles) - test_sz - valid_sz
61 |
62 | random.shuffle(orig_smiles)
63 | train_smiles = orig_smiles[:train_sz]
64 | test_smiles = orig_smiles[train_sz:len(orig_smiles)-valid_sz]
65 | valid_smiles = orig_smiles[len(orig_smiles)-valid_sz:]
66 |
67 | with open('data/pre-training/protac_db_subset_70/train.smi', 'w+') as f:
68 | for smi in train_smiles:
69 | f.write("{}\n".format(smi))
70 | with open('data/pre-training/protac_db_subset_70/test.smi', 'w+') as f:
71 | for smi in test_smiles:
72 | f.write("{}\n".format(smi))
73 | with open('data/pre-training/protac_db_subset_70/valid.smi', 'w+') as f:
74 | for smi in valid_smiles:
75 | f.write("{}\n".format(smi))
76 |
77 | if __name__ == "__main__":
78 | filter(threshold=args.threshold, original_smi_file = args.orig_smi, filtered_smi_file=args.new_smi)
79 | split(filtered_smi_file = args.new_smi)
80 | print("Done.", flush=True)
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/submit-split-preprocessing-supercloud.py:
--------------------------------------------------------------------------------
1 | """
2 | Example submission script for a GraphINVENT preprocessing job (before distribution-
3 | based training, not fine-tuning/optimization job), when the dataset is large and
4 | we want to split one large preprocessing job into multiple smaller preprocessing
5 | jobs, aggregating the final HDF files at the end. The HDF dataset created can be
6 | used to pre-train a model before a fine-tuning (via reinforcement learning) job.
7 |
8 | To run, you can first split the dataset as follows (do this within an interactive session):
9 | (graphinvent)$ python submit-split-preprocessing-supercloud.py --type split
10 |
11 | Then, submit the separate preprocessing jobs for the split dataset as follows:
12 | (graphinvent)$ python submit-split-preprocessing-supercloud.py --type submit
13 |
14 | When the above jobs have completed, aggregate the generated HDFs for each dataset split into the main dataset dir:
15 | (graphinvent)$ python submit-split-preprocessing-supercloud.py --type aggregate
16 |
17 | The above script also cleans up extra files.
18 |
19 | This script was modified to run on the MIT Supercloud.
20 | """
21 | # load general packages and functions
22 | import csv
23 | import argparse
24 | import sys
25 | import os
26 | import shutil
27 | from pathlib import Path
28 | import subprocess
29 | import time
30 | from math import ceil
31 | from typing import Union
32 | import numpy as np
33 | import h5py
34 | import torch
35 |
36 | # define the argument parser
37 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
38 | add_help=False)
39 |
40 | # define potential arguments for using this script
41 | parser.add_argument("--type",
42 | type=str,
43 | default="split",
44 | help="Acceptable values include 'split', 'submit', 'aggregate', and 'cleanup'.")
45 | args = parser.parse_args()
46 |
47 | # define what you want to do for the specified job(s)
48 | DATASET = "MOSES" # dataset name in "./data/pre-training/"
49 | JOB_TYPE = "preprocess" # "preprocess", "train", "generate", or "test"
50 | JOBDIR_START_IDX = 0 # where to start indexing job dirs
51 | N_JOBS = 1 # number of jobs to run per model
52 | RESTART = True # whether or not this is a restart job
53 | FORCE_OVERWRITE = True # overwrite job directories which already exist
54 | JOBNAME = "preprocessing" # used to create a sub directory
55 |
56 | # if running using LLsub, specify params below
57 | USE_LLSUB = True # use LLsub or not
58 | MEM_GB = 20 # required RAM in GB
59 |
60 | # for LLsub jobs, set number of CPUs per task
61 | if JOB_TYPE == "preprocess":
62 | CPUS_PER_TASK = 20
63 | DEVICE = "cpu"
64 | else:
65 | CPUS_PER_TASK = 10
66 | DEVICE = "cuda"
67 |
68 | # set paths here
69 | HOME = str(Path.home())
70 | PYTHON_PATH = f"{HOME}/.conda/envs/graphinvent/bin/python"
71 | GRAPHINVENT_PATH = "./graphinvent/"
72 | DATA_PATH = "./data/pre-training/"
73 |
74 | # define dataset-specific parameters
75 | params = {
76 | "atom_types" : ["C", "N", "O", "F", "S", "Cl", "Br"],
77 | "formal_charge": [-1, 0, +1],
78 | "max_n_nodes" : 27,
79 | "job_type" : JOB_TYPE,
80 | "restart" : RESTART,
81 | "model" : "GGNN",
82 | "sample_every" : 2,
83 | "init_lr" : 1e-4,
84 | "epochs" : 100,
85 | "batch_size" : 50,
86 | "block_size" : 1000,
87 | "device" : DEVICE,
88 | "n_samples" : 100,
89 |     # additional parameters can be defined here, if different from the "defaults"
90 | # for instance, for "generate" jobs, don't forget to specify "generation_epoch"
91 | # and "n_samples"
92 | }
93 |
94 |
95 | def submit() -> None:
96 | """
97 | Creates and submits submission script. Uses global variables defined at top
98 | of this file.
99 | """
100 | check_paths()
101 |
102 | # create an output directory
103 | dataset_output_path = f"{HOME}/GraphINVENT/output_{DATASET}"
104 | tensorboard_path = os.path.join(dataset_output_path, "tensorboard")
105 | if JOBNAME != "":
106 | dataset_output_path = os.path.join(dataset_output_path, JOBNAME)
107 | tensorboard_path = os.path.join(tensorboard_path, JOBNAME)
108 |
109 | os.makedirs(dataset_output_path, exist_ok=True)
110 | os.makedirs(tensorboard_path, exist_ok=True)
111 | print(f"* Creating dataset directory {dataset_output_path}/", flush=True)
112 |
113 | # submit `N_JOBS` separate jobs
114 | jobdir_end_idx = JOBDIR_START_IDX + N_JOBS
115 | for job_idx in range(JOBDIR_START_IDX, jobdir_end_idx):
116 |
117 | # specify and create the job subdirectory if it does not exist
118 | params["job_dir"] = f"{dataset_output_path}/job_{job_idx}/"
119 | params["tensorboard_dir"] = f"{tensorboard_path}/job_{job_idx}/"
120 |
121 |         # create the directory if it does not already exist; otherwise this raises
122 |         # an error, which is good because we might *not* want to overwrite data in
123 |         # our existing directories!
124 | os.makedirs(params["tensorboard_dir"], exist_ok=True)
125 | try:
126 | job_dir_exists_already = bool(
127 | JOB_TYPE in ["generate", "test"] or FORCE_OVERWRITE
128 | )
129 | os.makedirs(params["job_dir"], exist_ok=job_dir_exists_already)
130 | print(
131 | f"* Creating model subdirectory {dataset_output_path}/job_{job_idx}/",
132 | flush=True,
133 | )
134 | except FileExistsError:
135 | print(
136 | f"-- Model subdirectory {dataset_output_path}/job_{job_idx}/ already exists.",
137 | flush=True,
138 | )
139 | if not RESTART:
140 | continue
141 |
142 | # write the `input.csv` file
143 | write_input_csv(params_dict=params, filename="input.csv")
144 |
145 | # write `submit.sh` and submit
146 | if USE_LLSUB:
147 | print("* Writing submission script.", flush=True)
148 | write_submission_script(job_dir=params["job_dir"],
149 | job_idx=job_idx,
150 | job_type=params["job_type"],
151 | max_n_nodes=params["max_n_nodes"],
152 | cpu_per_task=CPUS_PER_TASK,
153 | python_bin_path=PYTHON_PATH)
154 |
155 | print("* Submitting batch job using LLsub.", flush=True)
156 | subprocess.run(["LLsub", params["job_dir"] + "submit.sh"],
157 | check=True)
158 | else:
159 | print("* Running job as a normal process.", flush=True)
160 | subprocess.run(["ls", f"{PYTHON_PATH}"], check=True)
161 | subprocess.run([f"{PYTHON_PATH}",
162 | f"{GRAPHINVENT_PATH}main.py",
163 | "--job-dir",
164 | params["job_dir"]],
165 | check=True)
166 |
167 | # sleep a few secs before submitting next job
168 | print("-- Sleeping 2 seconds.")
169 | time.sleep(2)
170 |
171 |
172 | def write_input_csv(params_dict : dict, filename : str="params.csv") -> None:
173 | """
174 | Writes job parameters/hyperparameters in `params_dict` to CSV using the specified
175 | `filename`.
176 | """
177 | dict_path = params_dict["job_dir"] + filename
178 |
179 | with open(dict_path, "w") as csv_file:
180 |
181 | writer = csv.writer(csv_file, delimiter=";")
182 | for key, value in params_dict.items():
183 | writer.writerow([key, value])
184 |
185 |
186 | def write_submission_script(job_dir : str, job_idx : int, job_type : str, max_n_nodes : int,
187 | cpu_per_task : int, python_bin_path : str) -> None:
188 | """
189 | Writes a submission script (`submit.sh`).
190 |
191 | Args:
192 | ----
193 | job_dir (str) : Job running directory.
194 | job_idx (int) : Job idx.
195 | job_type (str) : Type of job to run.
196 | max_n_nodes (int) : Maximum number of nodes in dataset.
197 | cpu_per_task (int) : How many CPUs to use per task.
198 | python_bin_path (str) : Path to Python binary to use.
199 | """
200 | submit_filename = job_dir + "submit.sh"
201 | with open(submit_filename, "w") as submit_file:
202 | submit_file.write("#!/bin/bash\n")
203 | submit_file.write(f"#SBATCH --job-name={job_type}{max_n_nodes}_{job_idx}\n")
204 | submit_file.write(f"#SBATCH --output={job_type}{max_n_nodes}_{job_idx}o\n")
205 | submit_file.write(f"#SBATCH --cpus-per-task={cpu_per_task}\n")
206 | if DEVICE == "cuda":
207 | submit_file.write("#SBATCH --gres=gpu:volta:1\n")
208 | submit_file.write("hostname\n")
209 | submit_file.write("export QT_QPA_PLATFORM='offscreen'\n")
210 | submit_file.write(f"{python_bin_path} {GRAPHINVENT_PATH}main.py --job-dir {job_dir}")
211 | submit_file.write(f" > {job_dir}output.o${{LLSUB_RANK}}\n")
212 |
213 |
214 | def check_paths() -> None:
215 | """
216 | Checks that paths to Python binary, data, and GraphINVENT are properly
217 | defined before running a job, and tells the user to define them if not.
218 | """
219 | for path in [PYTHON_PATH, GRAPHINVENT_PATH, DATA_PATH]:
220 | if "path/to/" in path:
221 | print("!!!")
222 | print("* Update the following paths in `submit.py` before running:")
223 | print("-- `PYTHON_PATH`\n-- `GRAPHINVENT_PATH`\n-- `DATA_PATH`")
224 | sys.exit(0)
225 |
226 | def split_file(filename : str, n_lines_per_split : int=100000) -> int:
227 |     """
228 |     Splits `filename` into smaller files of at most `n_lines_per_split` lines each.
229 |
230 |     Args:
231 |     ----
232 |     filename (str) : The filename.
233 |     n_lines_per_split (int) : Number of lines per file.
234 |
235 |     Returns:
236 |     -------
237 |     n_splits (int) : Number of splits.
238 |     """
239 |     output_base = filename[:-4]
240 |     extension = filename[-3:]
241 |
242 |     count = 0
243 |     at = 0
244 |     dest = None
245 |     with open(filename, "r") as input_file:
246 |         for line in input_file:
247 |             if count % n_lines_per_split == 0:
248 |                 if dest: dest.close()
249 |                 dest = open(f"{output_base}.{at}.{extension}", "w")
250 |                 at += 1
251 |             dest.write(line)
252 |             count += 1
253 |     if dest: dest.close()
254 |     n_splits = at
255 |     return n_splits
256 |
257 | def get_n_splits(filename : str, n_lines_per_split : int=100000) -> int:
258 |     """
259 |     Computes how many files `filename` is split into at `n_lines_per_split` lines per split.
260 |
261 |     Args:
262 |     ----
263 |     filename (str) : The filename.
264 |     n_lines_per_split (int) : Number of lines per file.
265 |
266 |     Returns:
267 |     -------
268 |     n_splits (int) : Number of splits.
269 |     """
270 |     input_file = open(filename, "r")
271 |     n_splits = ceil(len(input_file.readlines()) / n_lines_per_split)
272 |     input_file.close()
273 |     return n_splits
274 |
275 | def load_ts_properties_from_csv(csv_path : str) -> Union[dict, None]:
276 | """
277 | Loads CSV file containing training set properties and returns contents as a dictionary.
278 | """
279 | print("* Loading training set properties.", flush=True)
280 |
281 | # read dictionaries from csv
282 | try:
283 | with open(csv_path, "r") as csv_file:
284 | reader = csv.reader(csv_file, delimiter=";")
285 | csv_dict = dict(reader)
286 | except:
287 | return None
288 |
289 | # fix file types within dict in going from `csv_dict` --> `properties_dict`
290 | properties_dict = {}
291 | for key, value in csv_dict.items():
292 |
293 | # first determine if key is a tuple
294 | key = eval(key)
295 | if len(key) > 1:
296 | tuple_key = (str(key[0]), str(key[1]))
297 | else:
298 | tuple_key = key
299 |
300 | # then convert the values to the correct data type
301 | try:
302 | properties_dict[tuple_key] = eval(value)
303 | except (SyntaxError, NameError):
304 | properties_dict[tuple_key] = value
305 |
306 | # convert any `list`s to `torch.Tensor`s (for consistency)
307 | if type(properties_dict[tuple_key]) == list:
308 | properties_dict[tuple_key] = torch.Tensor(properties_dict[tuple_key])
309 |
310 | return properties_dict
311 |
312 | def write_ts_properties_to_csv(ts_properties_dict : dict, split : str) -> None:
313 | """
314 | Writes the training set properties in `ts_properties_dict` to a CSV file.
315 | """
316 | dict_path = f"data/pre-training/{dataset}/{split}.csv"
317 |
318 | with open(dict_path, "w") as csv_file:
319 |
320 | csv_writer = csv.writer(csv_file, delimiter=";")
321 | for key, value in ts_properties_dict.items():
322 | if "validity_tensor" in key:
323 | continue # skip writing the validity tensor because it is really long
324 | elif type(value) == np.ndarray:
325 | csv_writer.writerow([key, list(value)])
326 | elif type(value) == torch.Tensor:
327 | try:
328 | csv_writer.writerow([key, float(value)])
329 | except ValueError:
330 | csv_writer.writerow([key, [float(i) for i in value]])
331 | else:
332 | csv_writer.writerow([key, value])
333 |
334 | def get_dims() -> dict:
335 | """
336 | Gets the dims corresponding to the three datasets in each preprocessed HDF
337 | file: "nodes", "edges", and "APDs".
338 | """
339 | dims = {}
340 | dims["nodes"] = [max_n_nodes, n_atom_types + n_formal_charges]
341 | dims["edges"] = [max_n_nodes, max_n_nodes, n_bond_types]
342 | dim_f_add = [max_n_nodes, n_atom_types, n_formal_charges, n_bond_types]
343 | dim_f_conn = [max_n_nodes, n_bond_types]
344 | dims["APDs"] = [np.prod(dim_f_add) + np.prod(dim_f_conn) + 1]
345 |
346 | return dims
347 |
348 | def get_total_n_subgraphs(paths : list) -> int:
349 | """
350 | Gets the total number of subgraphs saved in all the HDF files in the `paths`,
351 | where `paths` is a list of strings containing the path to each HDF file we want
352 | to combine.
353 | """
354 | total_n_subgraphs = 0
355 | for path in paths:
356 | print("path:", path)
357 | hdf_file = h5py.File(path, "r")
358 | nodes = hdf_file.get("nodes")
359 | n_subgraphs = nodes.shape[0]
360 | total_n_subgraphs += n_subgraphs
361 | hdf_file.close()
362 |
363 | return total_n_subgraphs
364 |
365 | def combine_HDFs(paths : list, training_set : bool, split : str) -> None:
366 | """
367 | Combine many small HDF files (their paths defined in `paths`) into one large
368 | HDF file. Works assuming HDFs were created for the preprocessed dataset
369 | following the following directory structure:
370 | data/pre-training/
371 | |-- {dataset}_1/
372 | |-- {dataset}_2/
373 | |-- {dataset}_3/
374 | |...
375 | |-- {dataset}_{n_dirs}/
376 | """
377 | total_n_subgraphs = get_total_n_subgraphs(paths)
378 | dims = get_dims()
379 |
380 | print(f"* Creating HDF file to contain {total_n_subgraphs} subgraphs")
381 | new_hdf_file = h5py.File(f"data/pre-training/{dataset}/{split}.h5", "a")
382 | new_dataset_nodes = new_hdf_file.create_dataset("nodes",
383 | (total_n_subgraphs, *dims["nodes"]),
384 | dtype=np.dtype("int8"))
385 | new_dataset_edges = new_hdf_file.create_dataset("edges",
386 | (total_n_subgraphs, *dims["edges"]),
387 | dtype=np.dtype("int8"))
388 | new_dataset_APDs = new_hdf_file.create_dataset("APDs",
389 | (total_n_subgraphs, *dims["APDs"]),
390 | dtype=np.dtype("int8"))
391 |
392 | print("* Combining data from smaller HDFs into a new larger HDF.")
393 | init_index = 0
394 | for path in paths:
395 | print("path:", path)
396 | hdf_file = h5py.File(path, "r")
397 |
398 | nodes = hdf_file.get("nodes")
399 | edges = hdf_file.get("edges")
400 | APDs = hdf_file.get("APDs")
401 |
402 | n_subgraphs = nodes.shape[0]
403 |
404 | new_dataset_nodes[init_index:(init_index + n_subgraphs)] = nodes
405 | new_dataset_edges[init_index:(init_index + n_subgraphs)] = edges
406 | new_dataset_APDs[init_index:(init_index + n_subgraphs)] = APDs
407 |
408 | init_index += n_subgraphs
409 | hdf_file.close()
410 |
411 | new_hdf_file.close()
412 |
413 | if training_set:
414 | print(f"* Combining data from respective `{split}.csv` files into one.")
415 | csv_list = [f"{path[:-2]}csv" for path in paths]
416 |
417 | ts_properties_old = None
418 | csv_files_processed = 0
419 | for path in csv_list:
420 | ts_properties = load_ts_properties_from_csv(csv_path=path)
421 | ts_properties_new = {}
422 | if ts_properties_old and ts_properties:
423 | for key, value in ts_properties_old.items():
424 | if type(value) == float:
425 | ts_properties_new[key] = (
426 | value * csv_files_processed + ts_properties[key]
427 | )/(csv_files_processed + 1)
428 | else:
429 | new_list = []
430 | for i, value_i in enumerate(value):
431 | new_list.append(
432 | float(
433 | value_i * csv_files_processed + ts_properties[key][i]
434 | )/(csv_files_processed + 1)
435 | )
436 | ts_properties_new[key] = new_list
437 | else:
438 | ts_properties_new = ts_properties
439 | ts_properties_old = ts_properties_new
440 | csv_files_processed += 1
441 |
442 | write_ts_properties_to_csv(ts_properties_dict=ts_properties_new, split=split)
443 |
444 | if __name__ == "__main__":
445 | dataset = DATASET
446 |
447 | if args.type == "split":
448 | # --------- SPLIT THE DATASET ----------
449 | # 1) first, split the training set
450 | n_training_splits = split_file(filename=f"{DATA_PATH}{DATASET}/train.smi",
451 | n_lines_per_split=100000)
452 |
453 | # 2) then, split the test set (if necessary)
454 | n_test_splits = split_file(filename=f"{DATA_PATH}{DATASET}/test.smi",
455 | n_lines_per_split=100000)
456 |
457 | # 3) finally, split the validation set (if necessary)
458 | n_valid_splits = split_file(filename=f"{DATA_PATH}{DATASET}/valid.smi",
459 | n_lines_per_split=100000)
460 |
461 | elif args.type == "submit":
462 | # ---------- MOVE EACH SPLIT INTO ITS OWN DIRECTORY AND SUBMIT EACH AS SEPARATE JOB ----------
463 | # first get the number of splits for each train/test/valid split if each
464 | # file is split into files of max 100000 lines
465 | n_training_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/train.smi",
466 | n_lines_per_split=100000)
467 |
468 | for split_idx in range(n_training_splits):
469 | if not os.path.exists(f"{DATA_PATH}{dataset}_{split_idx}/"):
470 | os.mkdir(f"{DATA_PATH}{dataset}_{split_idx}/") # make the dir
471 |
472 | # moving train split into folder for given index
473 | try:
474 | os.rename(f"{DATA_PATH}{dataset}/train.{split_idx}.smi", f"{DATA_PATH}{dataset}_{split_idx}/train.smi") # move the file to the dir and rename
475 | except:
476 | pass
477 |
478 | # moving test split into folder for given index, if test split exists
479 | try:
480 | os.rename(f"{DATA_PATH}{dataset}/test.{split_idx}.smi", f"{DATA_PATH}{dataset}_{split_idx}/test.smi") # move the file to the dir and rename
481 | except:
482 | pass
483 |
484 | # moving valid split into folder for given index, if valid split exists
485 | try:
486 | os.rename(f"{DATA_PATH}{dataset}/valid.{split_idx}.smi", f"{DATA_PATH}{dataset}_{split_idx}/valid.smi") # move the file to the dir and rename
487 | except:
488 | pass
489 |
490 | DATASET = f"{dataset}_{split_idx}/"
491 | params["dataset_dir"] = f"{DATA_PATH}{DATASET}"
492 | submit()
493 |
494 |
495 | elif args.type == "aggregate":
496 | # first get the number of splits for each train/test/valid split if each
497 | # file is split into files of max 100000 lines
498 | n_training_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/train.smi",
499 | n_lines_per_split=100000)
500 | n_test_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/test.smi",
501 | n_lines_per_split=100000)
502 | n_valid_splits = get_n_splits(filename=f"{DATA_PATH}{DATASET}/valid.smi",
503 | n_lines_per_split=100000)
504 | # ---------- AGGREGATE THE RESULTS ----------
505 | # set variables
506 | n_atom_types = len(params["atom_types"]) # number of atom types used in preprocessing the data
507 | n_formal_charges = len(params["formal_charge"]) # number of formal charges used in preprocessing the data
508 | n_bond_types = 3 # number of bond types used in preprocessing the data
509 | max_n_nodes = params["max_n_nodes"] # maximum number of nodes in the data
510 |
511 | # # 1) combine the training files
512 | # path_list = [f"data/pre-training/{dataset}_{i}/train.h5" for i in range(0, n_training_splits)]
513 | # combine_HDFs(path_list, training_set=False, split="train")
514 |
515 | # 2) combine the test files
516 | path_list = [f"data/pre-training/{dataset}_{i}/test.h5" for i in range(0, n_test_splits)]
517 | combine_HDFs(path_list, training_set=False, split="test")
518 |
519 | # 3) combine the validation files
520 | path_list = [f"data/pre-training/{dataset}_{i}/valid.h5" for i in range(0, n_valid_splits)]
521 | combine_HDFs(path_list, training_set=False, split="valid")
522 |
523 | # ---------- DELETE TEMPORARY FILES ----------
524 | for split_idx in range(max(n_training_splits, n_test_splits, n_valid_splits)):
525 | shutil.rmtree(f"{DATA_PATH}{dataset}_{split_idx}/") # remove the dir and all files in it
526 | else:
527 | raise ValueError("Not a valid job type.")
528 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/tdc-create-dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Uses the Therapeutics Data Commons (TDC) to get datasets for goal-directed
3 | molecular optimization tasks and then filters the molecules based on number
4 | of heavy atoms and formal charge.
5 |
6 | See:
7 | * https://tdcommons.ai/
8 | * https://github.com/mims-harvard/TDC
9 |
10 | To use script, run:
11 | (graphinvent)$ python tdc-create-dataset.py --dataset MOSES
12 | """
13 | import os
14 | import argparse
15 | from pathlib import Path
16 | import shutil
17 | from tdc.generation import MolGen
18 | import rdkit
19 | from rdkit import Chem
20 |
21 | # define the argument parser
22 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
23 | add_help=False)
24 |
25 | # define the argument specifying which dataset to download
26 | parser.add_argument("--dataset",
27 | type=str,
28 | default="ChEMBL",
29 | help="Specifies the dataset to use for creating the data. Options "
30 | "are: 'ChEMBL', 'MOSES', or 'ZINC'.")
31 | args = parser.parse_args()
32 |
33 |
34 | def save_smiles(smi_file : str, smi_list : list) -> None:
35 | """Saves input list of SMILES to the specified file path."""
36 | smi_writer = rdkit.Chem.rdmolfiles.SmilesWriter(smi_file)
37 | for smi in smi_list:
38 | try:
39 | mol = rdkit.Chem.MolFromSmiles(smi[0])
40 | if mol.GetNumAtoms() < 81: # filter out molecules with >= 81 atoms
41 | save = True
42 | for atom in mol.GetAtoms():
43 | if atom.GetFormalCharge() not in [-1, 0, +1]: # filter out molecules with large formal charge
44 | save = False
45 | break
46 | if save:
47 | smi_writer.write(mol)
48 | except: # likely TypeError or AttributeError e.g. "smi[0]" is "nan"
49 | continue
50 | smi_writer.close()
51 |
52 |
53 | if __name__ == "__main__":
54 | print(f"* Loading {args.dataset} dataset using the TDC.")
55 | data = MolGen(name=args.dataset)
56 | split = data.get_split()
57 | HOME = str(Path.home())
58 | DATA_PATH = f"./data/{args.dataset}/"
59 | try:
60 | os.mkdir(DATA_PATH)
61 | print(f"-- Creating dataset at {DATA_PATH}")
62 | except FileExistsError:
63 | shutil.rmtree(DATA_PATH)
64 | os.mkdir(DATA_PATH)
65 | print(f"-- Removed old directory at {DATA_PATH}")
66 | print(f"-- Creating new dataset at {DATA_PATH}")
67 |
68 | print(f"* Re-saving {args.dataset} dataset in a format GraphINVENT can parse.")
69 | print("-- Saving training data...")
70 | save_smiles(smi_file=f"{DATA_PATH}train.smi", smi_list=split["train"].values)
71 | print("-- Saving testing data...")
72 | save_smiles(smi_file=f"{DATA_PATH}test.smi", smi_list=split["test"].values)
73 | print("-- Saving validation data...")
74 | save_smiles(smi_file=f"{DATA_PATH}valid.smi", smi_list=split["valid"].values)
75 |
76 | # # delete the raw downloaded files
77 | # dir_path = "./data/"
78 | # shutil.rmtree(dir_path)
79 | print("Done.", flush=True)
80 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellaneous functions.
3 | """
4 | import rdkit
5 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier
6 |
7 |
8 | def load_molecules(path : str) -> rdkit.Chem.rdmolfiles.SmilesMolSupplier:
9 | """
10 | Reads a SMILES file (full path/filename specified by `path`) and returns the
11 | `rdkit.Mol` object "supplier".
12 | """
13 | # check first line of SMILES file to see if contains header
14 | with open(path) as smi_file:
15 | first_line = smi_file.readline()
16 | has_header = bool("SMILES" in first_line)
17 | smi_file.close()
18 |
19 | # read file
20 | molecule_set = SmilesMolSupplier(path, sanitize=True, nameColumn=-1, titleLine=has_header)
21 |
22 | return molecule_set
23 |
--------------------------------------------------------------------------------
/GraphINVENT_Protac/tools/visualization.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import rdkit
4 | from rdkit.Chem.Draw import MolsToGridImage
5 | from rdkit.Chem.rdmolfiles import SmilesMolSupplier
6 |
7 | smi_file = "/home/gridsan/dnori/GraphINVENT/output_MOSES_subset/example-job-name/job_0/generation/step61_agent.smi"
8 |
9 | # load molecules from file
10 | mols = SmilesMolSupplier(smi_file, sanitize=True, nameColumn=-1, titleLine=True)
11 |
12 | n_samples = 8
13 | mols_list = [mol for mol in mols]
14 | mols_sampled = random.sample(mols_list, n_samples) # randomly sample n_samples molecules to visualize
15 |
16 | mols_per_row = int(math.sqrt(n_samples)) # make a square grid
17 |
18 | png_filename = smi_file[:-3] + "png" # name of the PNG file to create
19 |
20 | # label structures with their index in the sampled list
21 | labels = list(range(n_samples))
22 |
23 | # draw the molecules (creates a PIL image)
24 | img = MolsToGridImage(mols=mols_sampled,
25 | molsPerRow=mols_per_row,
26 | legends=[str(i) for i in labels])
27 |
28 | img.save(png_filename)
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Divya Nori
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Protac-Design
2 | ## Description
3 | This repo contains the code behind our workshop paper at the NeurIPS 2022 AI4Science Workshop ([Link to Paper](https://openreview.net/pdf?id=pGyp4o9gky0)). It is organized into the following notebooks:
4 |
5 | * [surrogate_model.ipynb](./surrogate_model.ipynb): Contains the code for processing the raw PROTAC data and training the DC50 surrogate model. Note that you will need to download the public PROTAC data from [PROTAC-DB](http://cadd.zju.edu.cn/protacdb/downloads) in order to reproduce the results.
6 | * [molecule_metrics.ipynb](./molecule_metrics.ipynb): Contains code for computing metrics on a set of generated molecules. Metrics include the percentage predicted active, the percentage of duplicate molecules, the percentage of molecules regenerated from the training set, the average number of atoms, chemical diversity, and drug-likeness (a minimal metrics sketch follows this list).
7 | * [binary_label_metrics.py](./binary_label_metrics.py): Contains useful functions for analyzing performance of binary classification models.
8 |
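As a rough illustration of a few of the metrics listed above, the sketch below uses RDKit to compute the average number of atoms, the mean QED drug-likeness, and a chemical-diversity estimate based on pairwise Tanimoto distances of Morgan fingerprints. It is a minimal sketch rather than the notebook's exact code; the input file name `generated.smi` and the fingerprint settings (radius 2, 2048 bits) are assumptions made for illustration.

```
# Minimal sketch (not the notebook's exact code) of a few generation metrics.
# Assumptions: "generated.smi" is a hypothetical SMILES file with no header;
# the fingerprint settings are illustrative choices.
from itertools import combinations
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, QED

mols = [m for m in Chem.SmilesMolSupplier("generated.smi", titleLine=False)
        if m is not None]

avg_atoms = sum(m.GetNumAtoms() for m in mols) / len(mols)  # average number of atoms
avg_qed = sum(QED.qed(m) for m in mols) / len(mols)         # average drug-likeness

# chemical diversity: mean pairwise Tanimoto distance between Morgan fingerprints
fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=2048) for m in mols]
dists = [1.0 - DataStructs.TanimotoSimilarity(a, b) for a, b in combinations(fps, 2)]
diversity = sum(dists) / len(dists)

print(f"avg atoms: {avg_atoms:.1f}  avg QED: {avg_qed:.3f}  diversity: {diversity:.3f}")
```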
9 | The repo also contains the following additional files:
10 | * [surrogate_model.pkl](./surrogate_model.pkl): Contains the pre-trained surrogate model for DC50 prediction.
11 | * [features.pkl](./features.pkl): Contains the list of features used to train the surrogate model; required to reproduce the reinforcement learning jobs that use the PROTAC scoring function (see the loading sketch below).
12 |
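For orientation, the sketch below shows one way these two pickles might be loaded and combined to score a molecule. It assumes the files were written with Python's `pickle` module, that `features.pkl` holds an ordered list of feature names, and that the surrogate model exposes a scikit-learn-style `predict_proba`; it is not the exact scoring code used in `graphinvent/ProtacScoringFunction.py`.

```
# Illustrative sketch only: assumes pickle-serialized objects and a
# scikit-learn-style classifier. `descriptor_dict` is a hypothetical
# mapping from feature name to value for one candidate PROTAC.
import pickle

with open("surrogate_model.pkl", "rb") as f:
    model = pickle.load(f)          # assumed: sklearn-style DC50 classifier
with open("features.pkl", "rb") as f:
    feature_names = pickle.load(f)  # assumed: ordered list of feature names

def score_protac(descriptor_dict):
    """Return the predicted probability of activity for one molecule."""
    x = [[descriptor_dict[name] for name in feature_names]]
    return model.predict_proba(x)[0, 1]
```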
13 |
14 | ## Instructions
15 | 1. Before running any of the notebooks, you will need to download the PROTAC data from the public [PROTAC-DB](http://cadd.zju.edu.cn/protacdb) database.
16 | 2. You will then need to create a conda environment containing the following main packages: `rdkit`, `pandas`, `scikit-learn`, `scipy`, `ipython`, and `optuna`. See the instructions in the next section for setting this up.
17 | 3. Open the notebooks on your favorite platform, and make sure to select the right kernel before executing.
18 |
19 | ## Environment
20 | To set up the environment for running the notebooks in this repo, run the following commands:
21 | ```
22 | conda create -n protacs-env -c conda-forge scikit-learn optuna rdkit
23 | conda activate protacs-env
24 | conda install pandas scipy
25 | ```
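After activating the environment, a quick import check like the one below confirms that the core packages resolve before you open the notebooks. Depending on your platform, Jupyter/`ipython` support may need to be installed separately (for example via `ipykernel`) so that `protacs-env` shows up as a selectable kernel.

```
# Quick sanity check that the main dependencies are importable
import optuna, pandas, rdkit, scipy, sklearn
print(rdkit.__version__, sklearn.__version__, pandas.__version__)
```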
26 |
27 |
28 | ## Citation
29 | Nori, Divya, et al. (2022). "De novo PROTAC design using graph-based deep generative models." NeurIPS 2022 AI4Science Workshop.
30 |
31 | ## Additional data
32 | Additional data, including saved GraphINVENT model states, generated structures, analysis scripts, and training data, are available on Zenodo [here](https://doi.org/10.5281/zenodo.7278277).
33 |
34 | ## Authors
35 | * Divya Nori
36 | * Rocío Mercado
37 |
--------------------------------------------------------------------------------
/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/features.pkl
--------------------------------------------------------------------------------
/surrogate_model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divnori/Protac-Design/1456f55400b5b24396f4e17ecf458948e51f315b/surrogate_model.pkl
--------------------------------------------------------------------------------