├── .gitignore ├── LICENSE ├── README.md ├── data └── negative_links_uspto.csv ├── datasets ├── GAE.py ├── __init__.py ├── reaction_graph.py └── seal.py ├── environment.yaml ├── figures └── method_overview.jpg ├── main.py ├── models ├── __init__.py ├── autoencoder.py └── dgcnn.py ├── predict_links.py ├── reaction_data ├── get_negative_reactions_info.py ├── get_reactions_csv.py └── get_smiles_node_degree.py ├── settings ├── __init__.py ├── optuna.py └── settings_reaction_prediction.py ├── tests ├── __init__.py └── test_dataset │ ├── __init__.py │ ├── test_overlaps_splits.py │ ├── test_overlaps_valid_fold.py │ ├── test_reproducible_dataset.py │ └── test_split_distribution.py ├── torch_trainer.py └── utils ├── __init__.py ├── evaluate_model.py ├── evaluate_predictions.py ├── metrics.py ├── negative_sampling.py └── reactions_info.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Data 132 | data/ 133 | 134 | # Logging 135 | results 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Molecular AI, AstraZeneca Sweden AB 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reaction Graph Link Prediction 2 | 3 | **[Installation](#installation)** 4 | | **[Data](#data)** 5 | | **[Usage](#usage)** 6 | | **[Contributors](#contributors)** 7 | | **[Citation](#citation)** 8 | | **[References](#references)** 9 | 10 | This repository contains end-to-end training and evaluation of the SEAL [[1]](https://proceedings.neurips.cc/paper_files/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf) and Graph Auto-Encoder [[2]](https://arxiv.org/abs/1611.07308) link prediction algorithms on a Chemical Reaction Knowledge Graph built on reactions from USPTO. This code has been used to generate the results in [[3]](https://chemrxiv.org/engage/chemrxiv/article-details/64e34fe400bbebf0e68bcfb8). 11 | 12 | In [[3]](https://chemrxiv.org/engage/chemrxiv/article-details/64e34fe400bbebf0e68bcfb8), a novel de novo design method is presented in which the link prediction is used for predicting novel pairs of reactants. The link prediction is then followed by product prediction using a transformer model, Chemformer, which predicts the products given the reactants. This repository covers the link prediction (reaction prediction) and for the subsequent product prediction we refer to the original [Chemformer](https://github.com/MolecularAI/Chemformer) repository. 13 | 14 | Link Prediction in this setting is equivalent to predicting novel reactions between reactant pairs. The code presented here is based on the implementation by Zhang et al. [[1]](https://proceedings.neurips.cc/paper_files/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf) of [SEAL](https://github.com/facebookresearch/SEAL_OGB/tree/main). 15 | 16 | ![plot](figures/method_overview.jpg) 17 | 18 | **Figure 1. Overview of the method.** (top) Step 1, link prediction in a Chemical Reaction Knowledge Graph (CRKG) using [SEAL](https://github.com/facebookresearch/SEAL_OGB/tree/main), and (bottom) Step 2, product prediction for highly predicted novel links using [Chemformer](https://github.com/MolecularAI/Chemformer). 19 | 20 | ## Installation 21 | 22 | After cloning the repository, the recommended way to install the environment is to use `conda`: 23 | 24 | ```bash 25 | $ conda env create -f environment.yaml 26 | ``` 27 | 28 | ## Data 29 | Download the USPTO reaction graph from [here](https://doi.org/10.5281/zenodo.10171188) and place it inside the ```data/``` folder. 30 | 31 | ## Usage 32 | Use this repository to train and evaluate our proposed model with, 33 | 34 | ```bash 35 | $ python main.py --graph_path [PATH-TO-GRAPH] --name [EXPERIMENT-NAME] 36 | ``` 37 | 38 | Optional settings can be provided as additional arguments. The training script generates the following files, 39 | - ```data/```: Processed data files. 40 | - ```results/```: Individual folders containing all relevant results from a GraphTrainer, including 41 | - Checkpoints of model and optimizer parameters, based on best validation AUC and best validation loss separately. 42 | - Log file of outputs from training, including the number of datapoints in train/valid/test split, number of network parameters, and more. 43 | - Pickle files of all results from training and testing separately. 44 | - Some preliminary plots. 45 | - Test metrics and test predictions in csv format. 46 | - A csv settings file of the hyperparameters used for training. 
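For example, a run that overrides a few of the optional arguments (all of them defined in ```main.py```; the graph file name below is only a placeholder) could look like,

```bash
$ python main.py --graph_path data/uspto_reaction_graph.gt --name seal_baseline \
    --n_epochs 20 --batch_size 256 --learning_rate 0.0005 \
    --num_hops 1 --valid_fold 1 --num_workers 6
```

Run ```python main.py --help``` for the full list of available options and their defaults.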
47 | 48 | ### Reproducibility 49 | Once a SEAL model has been trained the probability of novel links can be predicted as follows, 50 | ```bash 51 | $ python predict_links.py --model_dir_path [PATH-TO-TRAINED-SEAL] --save_path [SAVE-PATH] --graph_path [PATH-TO-GRAPH] --edges_path data/negative_links_uspto.csv 52 | ``` 53 | 54 | Exchange ```data/negative_links_uspto.csv``` with your potential links. 55 | 56 | ### Parallelization / Runtime 57 | Most optimally, run with GPU available. In addition, SEAL-based link prediction is parallelizable on CPUs. Negative links generation by default uses a node degree distribution-preserving sampling function (sample_degree_preserving_distribution) which can take a long time depending on graph size. However, it only needs to be run once for a given link-sampling seed after which it is stored in ```data/```. Alternatively, an approximating function (sample_distribution) can be used with quicker runtime. 58 | 59 | ### Codebase 60 | 61 | ```torch_trainer.py``` contains the main trainer class and is called by the ```main.py```, ```optimize.py``` and ```predict_links.py``` individually. 62 | 63 | The main script initializes and runs a GraphTrainer from the ```torch_trainer.py``` file. The training process utilizes the following modules: 64 | - ```datasets/reaction_graph.py```: Importing graph and setting up training/validation/test positive edges. 65 | - ```datasets/seal.py```: Dynamic dataloader for SEAL algorithm, including sub-graph extraction and Double Radius Node Labelling (DRNL). 66 | - ```datasets/GAE.py```: Dataloader for GAE algorithm. 67 | - ```models/dgcnn.py```: The Deep Graph Convolutional Neural Networks used for the prediction of the likelihood of a link between the source and target nodes in the given subgraph. 68 | - ```models/autoencoder.py```: Graph Autoencoder used for prediction of the likelihood of a link between the source and target nodes, implemented using Torch Geometric library. 69 | - ```utils/```: various related functions used throughout the project. 70 | 71 | ## License 72 | 73 | The software is licensed under the Apache 2.0 license (see [LICENSE](https://github.com/MolecularAI/reaction-graph-link-prediction/blob/main/LICENSE)), and is free and provided as-is. 74 | 75 | ## Contributors 76 | - [Emma Rydholm](https://github.com/emmaryd) 77 | - Tomas Bastys 78 | - [Emma Svensson](https://github.com/emmas96) 79 | 80 | ## Citation 81 | 82 | Please cite our work using the following reference. 83 | ```bibtex 84 | @article{rydholm2024expanding, 85 | author = {Rydholm, Emma and Bastys, Tomas and Svensson, Emma and Kannas, Christos and Engkvist, Ola and Kogej, Thierry}, 86 | title = {{Expanding the chemical space using a chemical reaction knowledge graph}}, 87 | journal = {Digital Discovery}, 88 | year = {2024}, 89 | pages = {-}, 90 | publisher = {RSC}, 91 | doi = {10.1039/D3DD00230F} 92 | } 93 | ``` 94 | 95 | ## Funding 96 | This work was partially supported by the Wallenberg Artificial Intelligence, Autonomous Systems, and Software Program (WASP), funded by the Knut and Alice Wallenberg Foundation. Additionally, this work was partially funded by the European Union's Horizon 2020 research and innovation program under the Marie Sklodowska-Curie Innovative Training Network European Industrial Doctorate grant agreement No. 956832 “Advanced machine learning for Innovative Drug Discovery”. 97 | 98 | ## References: 99 | [1] M. Zhang and Y. 
Chen, "Link prediction based on graph neural networks," Advances in neural information processing systems 31, 2018. 100 | 101 | [2] T. N. Kipf and M. Welling, "Variational Graph Auto-Encoders", Neural Information Processing Systems 2016. 102 | 103 | [3] E. Rydholm, T. Bastys, E. Svensson, C. Kannas, O. Engkvist and T. Kogej, "Expanding the chemical space using a Chemical Reaction Knowledge Graph," ChemRxiv. 2023 104 | 105 | [4] R. Irwin, S. Dimitriadis, J. He and E. Bjerrum, "Chemformer: a pre-trained transformer for computational chemistry," Machine Learning: Science and Technology. 2022, 31 Jan. 2022. 106 | 107 | ## Keywords 108 | Link prediction, chemical reactions, synthesis prediction, forward synthesis prediction, transformer, chemical space, de novo design, knowledge graph, reaction graph 109 | 110 | -------------------------------------------------------------------------------- /data/negative_links_uspto.csv: -------------------------------------------------------------------------------- 1 | ,Reactant smiles 1,Reactant smiles 2,Source,Target,y true 2 | 0,O=CCCc1cccc(Br)c1,c1cnc(N2CCNCC2)nc1,171057,104283,0 3 | 1,Brc1cncc(N2CCCNCC2)c1,CN(C)CCCCl,224583,106564,0 4 | 2,BrCCc1c[nH]c2ccccc12,c1csc(N2CCNCC2)n1,105709,218004,0 5 | 3,CCOC(C)(OCC)OCC,CN1C[C@@H](CN)C[C@H]2c3cccc4[nH]cc(c34)C[C@H]21,123371,350651,0 6 | 4,C1CCN(CCCN2CCNCC2)C1,CCOC(=O)c1cc(Cl)c2ccccc2n1,202616,277424,0 7 | 5,C=Cc1ccncc1,O=C1Cc2cc(C3CCNCC3)ccc2N1,123556,131706,0 8 | 6,N[C@H]1CCNC1,O=CCCc1cccc(Br)c1,109346,171057,0 9 | 7,CCCCN1CCNCC1,Cc1cc2nc(Cl)nc(Cl)c2s1,241930,129571,0 10 | 8,COc1cccc(CCN2CCNCC2)c1,ClC1=NCCN1,169214,185510,0 11 | 9,CC1CC1C(=O)O,CN1C[C@@H](CN)C[C@H]2c3cccc4[nH]cc(c34)C[C@H]21,122085,350651,0 12 | -------------------------------------------------------------------------------- /datasets/GAE.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch_geometric.data import Dataset 4 | 5 | 6 | class GeneralDataset(Dataset): 7 | def __init__(self, root, dataset, settings, split, **kwargs): 8 | self.data = dataset.data 9 | self.sampling_factor = settings["neg_pos_ratio"] 10 | self.split = split 11 | super(GeneralDataset, self).__init__(root) 12 | 13 | # Get positive and neative edges for given split 14 | self.pos_edge = self.data.split_edge[self.split]["pos"] 15 | self.neg_edge = self.data.split_edge[self.split]["neg"] 16 | 17 | # Creates a torch with all edges 18 | self.links = torch.cat([self.pos_edge, self.neg_edge], 1).t().tolist() 19 | self.labels = [1] * self.pos_edge.size(1) + [0] * self.neg_edge.size(1) 20 | 21 | edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int) 22 | 23 | @property 24 | def num_features(self): 25 | if self.data.x != None: 26 | return len(self.data.x[0]) 27 | else: 28 | return None 29 | 30 | def __len__(self): 31 | return len(self.links) 32 | 33 | def get(self, idx): 34 | 35 | return self.links[idx], self.labels[idx] 36 | 37 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/reaction-graph-link-prediction/118acb3b4f2d9afe5c34a1a132c91ad1b8c021d5/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/reaction_graph.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import logging 5 | import 
warnings 6 | import numpy as np 7 | 8 | import torch 9 | from torch_geometric.data import Data, InMemoryDataset 10 | from torch_geometric.utils import to_undirected 11 | 12 | from utils.negative_sampling import ( 13 | sample_random, 14 | correct_overlaps, 15 | sample_degree_preserving_distribution 16 | ) 17 | 18 | # Import Graph-tool, ignore warnings related to C++ code conversion 19 | with warnings.catch_warnings(): 20 | warnings.filterwarnings("ignore", category=RuntimeWarning) 21 | import graph_tool.all as gt 22 | 23 | 24 | class ReactionGraph(InMemoryDataset): 25 | def __init__(self, root, settings, seed_neg_sampling, **kwargs): 26 | self.settings = settings 27 | 28 | self.percent_edges = settings["train_fraction"] 29 | self.percent_edges_test = settings["train_fraction"] 30 | 31 | self.neg_pos_ratio = settings["neg_pos_ratio"] 32 | self.neg_pos_ratio_test = settings["neg_pos_ratio_test"] 33 | self.splitting = self.settings["splitting"] 34 | 35 | self.seed = seed_neg_sampling 36 | self.fold = None 37 | self.num_nodes = None 38 | 39 | super(ReactionGraph, self).__init__(root) 40 | self.data, self.slices = torch.load(self.processed_paths[0]) 41 | 42 | @property 43 | def raw_file_names(self): 44 | return [self.settings["graph_path"]] 45 | 46 | @property 47 | def processed_file_names(self): 48 | return [f"processed_data.pt"] 49 | 50 | def download(self): 51 | pass 52 | 53 | def process(self): 54 | # Import graph with graph-tool 55 | graph = gt.load_graph(self.settings["graph_path"]) 56 | self.num_nodes = graph.num_vertices() 57 | 58 | attribute = "fingerprint" if self.settings["use_attribute"] else None 59 | # Node matrix with optional attributes 60 | if attribute == "fingerprint": 61 | # Node attributes are fingerprints of size 1024 62 | attrs = graph.vertex_properties[attribute].get_2d_array( 63 | list(range(0, 1024)) 64 | ) 65 | X = torch.tensor(attrs.transpose()) 66 | else: 67 | X = None 68 | 69 | # Create edge adjacency list 70 | adj_list = graph.get_edges( 71 | eprops=[ 72 | graph.edge_properties["random fold"] 73 | ] 74 | ) 75 | 76 | A = torch.LongTensor(adj_list[:, :2].transpose()) 77 | 78 | # Create data object with edges and node attributes 79 | self.data = Data(x=X, edge_index=A) 80 | 81 | # Add folds to use when creating train/val/test splits 82 | self.data.random_folds = adj_list[:, 2] 83 | 84 | # Confirm upper triangular graph 85 | row, col = self.data.edge_index 86 | l = len(row) 87 | self.data.edge_index = None 88 | mask = row < col 89 | row, col = row[mask], col[mask] 90 | if l != len(row): 91 | logging.warning("The graph is not upper triangular.") 92 | sys.exit() 93 | 94 | # All positive edges in full graph 95 | self.data.pos_edge = torch.stack([row, col], dim=0).long() 96 | 97 | # Find how many negative edges to sample by 'sample_distribution' and 'sample_random' 98 | # respectively. 
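        # (Worked example, assuming neg_pos_ratio = neg_pos_ratio_test = 1.0 and roughly
        #  10% of the positive edges carrying the test fold label 10:
        #  n_neg_to_sample = (0.1 * 1.0 + 0.9 * 1.0) * n_pos = n_pos, i.e. one negative
        #  per positive edge. 'fraction_dist_neg' then splits this budget between the
        #  degree-preserving sampler and the purely random sampler used below.)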
99 | if self.splitting == "random": 100 | self.fold = self.data.random_folds 101 | 102 | fraction_test = np.sum(self.fold == 10) / len(self.fold) 103 | 104 | # n_neg_to_sample = int(((fraction_test * self.neg_pos_ratio_test * self.percent_edges_test) \ 105 | # + ((1 - fraction_test) * self.neg_pos_ratio)) \ 106 | # + ((1 - fraction_test) * self.neg_pos_ratio * self.percent_edges)) \ 107 | # * self.data.pos_edge.shape[1]) 108 | n_neg_to_sample = int( 109 | ( 110 | (fraction_test * self.neg_pos_ratio_test) 111 | + ((1 - fraction_test) * self.neg_pos_ratio) 112 | ) 113 | * self.data.pos_edge.shape[1] 114 | ) 115 | n_dist_negs_to_sample = int( 116 | n_neg_to_sample * self.settings["fraction_dist_neg"] 117 | ) 118 | n_rand_negs_to_sample = int( 119 | n_neg_to_sample * (1 - self.settings["fraction_dist_neg"]) 120 | ) 121 | 122 | # Sample from node distribution of the positive edges 123 | if n_dist_negs_to_sample > 0: 124 | i_dist, j_dist = sample_degree_preserving_distribution( 125 | "data/negative_degree-preserving_distribution_graph=" 126 | + os.path.splitext(os.path.basename(self.settings["graph_path"]))[0] 127 | + "_seed=" 128 | + str(self.seed) 129 | + ".pt", 130 | self.data.pos_edge, 131 | n_dist_negs_to_sample, 132 | self.num_nodes, 133 | self.data.pos_edge, 134 | seed=self.seed, 135 | ) 136 | else: 137 | i_dist, j_dist = torch.tensor([]), torch.tensor([]) 138 | 139 | self.data.neg_edge_dist = torch.stack((i_dist, j_dist), dim=0) 140 | 141 | # Sample randomly from all non-isolated nodes 142 | if n_rand_negs_to_sample > 0: 143 | i_rand, j_rand = sample_random( 144 | self.data.pos_edge, 145 | n_rand_negs_to_sample, 146 | self.num_nodes, 147 | self.data.pos_edge, 148 | seed=self.seed, 149 | ) 150 | 151 | # Correct overlap between 2 sets of negatives 152 | i_rand, j_rand = correct_overlaps( 153 | (i_dist, j_dist), (i_rand, j_rand), self.num_nodes, seed=self.seed 154 | ) 155 | else: 156 | i_rand, j_rand = torch.tensor([]), torch.tensor([]) 157 | self.data.neg_edge_rand = torch.stack((i_rand, j_rand), dim=0) 158 | 159 | # Store data objects 160 | torch.save(self.collate([self.data]), self.processed_paths[0]) 161 | 162 | def create_negative_set(self): 163 | """Creates the negative set from the randomly sampled and sampled from 164 | distribution negatives, according to the given percentage.""" 165 | 166 | n_neg_edges = ( 167 | self.data.neg_edge_rand.shape[1] + self.data.neg_edge_dist.shape[1] 168 | ) 169 | n_dist = self.data.neg_edge_dist.shape[1] 170 | 171 | np.random.seed(seed=self.seed) 172 | mask = n_dist > np.random.permutation(n_neg_edges) 173 | 174 | # Merge and mix negative edges and negative edges sampled from distribution to one tensor 175 | # stored as: 'self.data.neg_edge' 176 | neg_edge = np.zeros((n_neg_edges, 2)) 177 | neg_edge[mask] = self.data.neg_edge_dist.transpose(0, 1).detach().clone() 178 | neg_edge[~mask] = self.data.neg_edge_rand.transpose(0, 1).detach().clone() 179 | 180 | self.data.neg_edge = ( 181 | torch.tensor(neg_edge).transpose(1, 0).int().detach().clone() 182 | ) 183 | 184 | def process_splits(self): 185 | """Create the train, validation and test sets.""" 186 | valid_fold = self.settings["valid_fold"] 187 | 188 | if valid_fold not in self.fold: 189 | sys.exit( 190 | f'Invalid setting "valid_fold" = {valid_fold}. 
Choose from: {set(self.fold) - set([10])}' 191 | ) 192 | 193 | num_edges = self.data.pos_edge.shape[1] 194 | 195 | # Prepare edge split based on time or random folds 196 | # According to split labels: 197 | # 0 = 'train', 1 = 'valid', 2 = 'test', 20 = not included in any split (time) 198 | 199 | if self.splitting == "random": 200 | self.fold = self.data.random_folds[0] 201 | # Initialize all edges as to belong in train set 202 | edge_split_label = np.zeros(num_edges, dtype=int) 203 | # Test set 204 | edge_split_label[self.fold == 10] = 2 205 | # Validation set 206 | edge_split_label[self.fold == valid_fold] = 1 207 | 208 | # Create self.data.neg_edge 209 | self.create_negative_set() 210 | 211 | n_pos_all = int(self.data.pos_edge.shape[1]) 212 | n_pos = int(self.data.pos_edge.shape[1] * self.percent_edges) 213 | 214 | # Split edges between train / valid / test sets 215 | self.data.split_edge = {"train": {}, "valid": {}, "test": {}} 216 | lower_range = 0 217 | for key, split in {2: "test", 1: "valid", 0: "train"}.items(): 218 | logging.debug("Creating %s split.", split.upper()) 219 | 220 | # Create the positive set 221 | all_pos_edges_in_split = ( 222 | self.data.pos_edge[:, edge_split_label == key].detach().clone() 223 | ) 224 | # Extracting a subset of all positive edges in split 225 | n_in_split = all_pos_edges_in_split.shape[1] 226 | perm_split_percent = np.random.permutation(n_in_split) 227 | if split != "test": 228 | perm_pos = perm_split_percent[: int(n_in_split * self.percent_edges)] 229 | else: 230 | perm_pos = perm_split_percent[ 231 | : int(n_in_split * self.percent_edges_test) 232 | ] 233 | self.data.split_edge[split]["pos"] = ( 234 | all_pos_edges_in_split[:, perm_pos].detach().clone() 235 | ) 236 | 237 | # Create the negative set 238 | if split == "train": 239 | train_neg_set = [] 240 | for fold in range(n_folds): 241 | if fold != valid_fold: 242 | train_neg_set.append( 243 | self.data.neg_edge[ 244 | :, folds2index[fold][0] : folds2index[fold][1] 245 | ] 246 | ) 247 | self.data.split_edge[split]["neg"] = torch.cat(train_neg_set, dim=1) 248 | elif split == "valid": 249 | self.data.split_edge[split]["neg"] = self.data.neg_edge[ 250 | :, folds2index[valid_fold][0] : folds2index[valid_fold][1] 251 | ] 252 | 253 | elif split == "test": 254 | # Create the negative set index dictionary 255 | # n_neg_in_test = int(self.data.split_edge['test']['pos'].shape[1] * self.neg_pos_ratio_test) 256 | folds2index = { 257 | "test": (0, all_pos_edges_in_split.shape[1]) 258 | } # n_neg_in_test)} 259 | first_index = all_pos_edges_in_split.shape[1] # n_neg_in_test 260 | # Test fold has fixed ID 10, so take next largest ID (4 or 9) + 1 for 5 or 10 fold cross validation 261 | n_folds = sorted(set(self.fold.tolist()))[-2] + 1 262 | for _fold in range(n_folds): 263 | n_in_fold = len(edge_split_label[self.fold == _fold]) 264 | folds2index[_fold] = ( 265 | first_index, 266 | first_index 267 | + int(n_in_fold * self.neg_pos_ratio * self.percent_edges), 268 | ) 269 | first_index += int(n_in_fold * self.neg_pos_ratio) 270 | if ( 271 | folds2index[n_folds - 1][1] > self.data.neg_edge.shape[1] + 1 272 | ): # second should be +1 first, not sure why the +1 273 | logging.warning( 274 | "Last index of last fold for negative sampling > total number of negatives sampled." 
275 | ) 276 | 277 | self.data.split_edge[split]["neg"] = self.data.neg_edge[ 278 | :, 279 | folds2index["test"][0] : int( 280 | folds2index["test"][1] * self.percent_edges_test 281 | ), 282 | ] 283 | 284 | logging.debug( 285 | f"{split} dataset contains of %d positive and %d negative edges. \n Ratio = %f.", 286 | self.data.split_edge[f"{split}"]["pos"].shape[1], 287 | self.data.split_edge[f"{split}"]["neg"].shape[1], 288 | self.data.split_edge[f"{split}"]["neg"].shape[1] 289 | / self.data.split_edge[f"{split}"]["pos"].shape[1], 290 | ) 291 | 292 | # Add opposite direction of positive training edges to make it undirected 293 | self.data.split_edge["train"]["pos"] = to_undirected( 294 | self.data.split_edge["train"]["pos"] 295 | ) 296 | 297 | # Remove test and validation edges from edges list provided to SEAL model 298 | train_edge_boolean = [] 299 | for i in edge_split_label: 300 | if i in (1, 2): # If in valid or test split 301 | train_edge_boolean.append(False) 302 | else: # If in train or no split 303 | train_edge_boolean.append(True) 304 | 305 | all_edges_except_test_and_valid = ( 306 | self.data.pos_edge[:, train_edge_boolean].detach().clone() 307 | ) 308 | self.data.edge_index = all_edges_except_test_and_valid 309 | 310 | # Store data objects 311 | torch.save(self.collate([self.data]), self.processed_paths[0]) 312 | 313 | -------------------------------------------------------------------------------- /datasets/seal.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import numpy as np 4 | import scipy.sparse as ssp 5 | from scipy.sparse.csgraph import shortest_path 6 | 7 | import torch 8 | from torch_geometric.data import Data, Dataset 9 | 10 | 11 | class SEALDynamicDataset(Dataset): 12 | """Class for creating dataset used for link prediction with SEAL. 13 | This class constructs subgraphs for each target link. 
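    Subgraphs are extracted on the fly in get() for each requested link,
    rather than being precomputed and stored in memory.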
14 | """ 15 | 16 | def __init__(self, root, dataset, settings, split="train", **kwargs): 17 | self.data = dataset.data 18 | self.num_nodes = dataset.num_nodes 19 | self.num_hops = settings["num_hops"] 20 | self.node_label = settings["node_label"] 21 | self.ratio_per_hop = settings["ratio_per_hop"] 22 | self.max_nodes_per_hop = settings["max_nodes_per_hop"] 23 | self.split = split 24 | super(SEALDynamicDataset, self).__init__(root) 25 | 26 | # Get positive and negative edges for given split 27 | self.pos_edge = self.data.split_edge[self.split]["pos"] 28 | self.neg_edge = self.data.split_edge[self.split]["neg"] 29 | 30 | # Create a torch with positive and negative edges and one with labels 31 | self.links = torch.cat([self.pos_edge, self.neg_edge], 1).t().tolist() 32 | self.labels = [1] * self.pos_edge.size(1) + [0] * self.neg_edge.size(1) 33 | 34 | edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int) 35 | 36 | self.A = ssp.csr_matrix( 37 | ( 38 | edge_weight.numpy(), 39 | (self.data.edge_index[0].numpy(), self.data.edge_index[1].numpy()), 40 | ), 41 | shape=(self.num_nodes, self.num_nodes), 42 | ) 43 | 44 | def __len__(self): 45 | return len(self.links) 46 | 47 | def get(self, idx): 48 | """Retrieves a subgraph around the source and target nodes, given by idx.""" 49 | 50 | links, labels = self.links, self.labels 51 | 52 | tmp = k_hop_subgraph( 53 | links[idx], 54 | self.num_hops, 55 | self.A, 56 | self.ratio_per_hop, 57 | self.max_nodes_per_hop, 58 | node_features=self.data.x, 59 | y=labels[idx], 60 | ) 61 | 62 | data = construct_pyg_graph(*tmp, links[idx], self.node_label) 63 | 64 | return data 65 | 66 | 67 | ### UTILS ### 68 | def k_hop_subgraph( 69 | node_idx, 70 | num_hops, 71 | A, 72 | sample_ratio=1.0, 73 | max_nodes_per_hop=None, 74 | node_features=None, 75 | y=1, 76 | ): 77 | """Extract the k-hop enclosing subgraph around link node_idx=(src, dst) from A.""" 78 | 79 | dists = [0, 0] 80 | visited = set(node_idx) 81 | fringe = set(node_idx) 82 | for dist in range(1, num_hops + 1): 83 | # Get 1-hop neighbors not visited 84 | fringe = set(A[list(fringe)].indices) 85 | fringe = fringe - visited 86 | visited = visited.union(fringe) 87 | if sample_ratio < 1.0: 88 | fringe = random.sample(fringe, int(sample_ratio * len(fringe))) 89 | if max_nodes_per_hop is not None: 90 | if max_nodes_per_hop < len(fringe): 91 | fringe = random.sample(fringe, max_nodes_per_hop) 92 | if len(fringe) == 0: 93 | break 94 | node_idx = node_idx + list(fringe) 95 | dists = dists + [dist] * len(fringe) 96 | subgraph = A[node_idx, :][:, node_idx] 97 | 98 | # Remove target link between the subgraph. 
99 | subgraph[0, 1] = 0 100 | subgraph[1, 0] = 0 101 | 102 | if node_features is not None: 103 | node_features = node_features[node_idx] 104 | 105 | return node_idx, subgraph, node_features, y 106 | 107 | 108 | def construct_pyg_graph(node_ids, adj, node_features, y, link, node_label="drnl"): 109 | """Construct a pytorch_geometric graph from a scipy csr adjacency matrix.""" 110 | 111 | u, v, r = ssp.find(adj) 112 | num_nodes = adj.shape[0] 113 | 114 | node_ids = torch.LongTensor(node_ids) 115 | u, v = torch.LongTensor(u), torch.LongTensor(v) 116 | r = torch.LongTensor(r) 117 | edge_index = torch.stack([u, v], 0) 118 | edge_weight = r.to(torch.float) 119 | y = torch.tensor([y]) 120 | 121 | if node_label == "drnl": # DRNL 122 | z = drnl_node_labeling(adj, 0, 1) 123 | else: 124 | raise NotImplementedError(f"{node_label} is not a valid 'node_label' setting.") 125 | 126 | sub_data = Data( 127 | node_features, 128 | edge_index, 129 | edge_weight=edge_weight, 130 | y=y, 131 | z=z, 132 | node_id=node_ids, 133 | num_nodes=num_nodes, 134 | link=link, 135 | ) 136 | return sub_data 137 | 138 | 139 | def drnl_node_labeling(adj, src, dst): 140 | """Double Radius Node Labeling (DRNL).""" 141 | src, dst = (dst, src) if src > dst else (src, dst) 142 | 143 | idx = list(range(src)) + list(range(src + 1, adj.shape[0])) 144 | adj_wo_src = adj[idx, :][:, idx] 145 | 146 | idx = list(range(dst)) + list(range(dst + 1, adj.shape[0])) 147 | adj_wo_dst = adj[idx, :][:, idx] 148 | 149 | dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src) 150 | dist2src = np.insert(dist2src, dst, 0, axis=0) 151 | dist2src = torch.from_numpy(dist2src) 152 | 153 | dist2dst = shortest_path( 154 | adj_wo_src, directed=False, unweighted=True, indices=dst - 1 155 | ) 156 | dist2dst = np.insert(dist2dst, src, 0, axis=0) 157 | dist2dst = torch.from_numpy(dist2dst) 158 | 159 | dist = dist2src + dist2dst 160 | dist_over_2, dist_mod_2 = dist // 2, dist % 2 161 | 162 | z = 1 + torch.min(dist2src, dist2dst) 163 | z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1) 164 | z[src] = 1.0 165 | z[dst] = 1.0 166 | z[torch.isnan(z)] = 0.0 167 | 168 | return z.to(torch.long) 169 | 170 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: link-prediction 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - cudatoolkit=10.2.89 7 | - graph-tool=2.37 8 | - matplotlib=3.3.4 9 | - mkl=2024.0.0 10 | - optuna=2.10.0 11 | - pip=21.0.1 12 | - python=3.8.8 13 | - pytorch=1.8.0 14 | - rdkit=2020.09.5 15 | - seaborn=0.11.1 16 | - scikit-learn=0.24.1 17 | - pip: 18 | - --find-links https://data.pyg.org/whl/torch-1.8.0+cu102.html 19 | - numpy==1.20.1 20 | - pandas==1.1.5 21 | - torch-geometric==1.7.0 22 | - torch-scatter==2.0.6 23 | - torch-sparse==0.6.9 24 | 25 | -------------------------------------------------------------------------------- /figures/method_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/reaction-graph-link-prediction/118acb3b4f2d9afe5c34a1a132c91ad1b8c021d5/figures/method_overview.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import importlib 5 | 6 | from torch_trainer import GraphTrainer 7 | 8 | 
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 9 | 10 | # Parse argments for updating 11 | parser = argparse.ArgumentParser(description="SEAL") 12 | parser.add_argument("-n", "--name", type=str, default='no_name') 13 | parser.add_argument("-g", "--graph_path", type=str, default=None) 14 | parser.add_argument("--pre_trained_model_path", type=str, default=None) 15 | # Subgraphs in SEAL 16 | parser.add_argument("--num_hops", type=int, default=1) 17 | parser.add_argument("--ratio_per_hop", type=float, default=1.0) 18 | parser.add_argument("--max_nodes_per_hop", type=int, default=800) 19 | parser.add_argument("--node_label", type=str, default='drnl') 20 | # Datasplits 21 | parser.add_argument("--neg_pos_ratio_test", type=float, default=1.0) 22 | parser.add_argument("--neg_pos_ratio", type=float, default=1.0) 23 | parser.add_argument("--train_fraction", type=float, default=1.0) 24 | parser.add_argument("--splitting", type=str, default='random') 25 | parser.add_argument("--valid_fold", type=int, default=1) 26 | parser.add_argument("--fraction_dist_neg", type=float, default=0.5) 27 | parser.add_argument("--seed", type=int, default=100) 28 | parser.add_argument("--include_in_train", type=str, default=None) 29 | parser.add_argument("--mode", type=str, default='normal') 30 | # Dataloaders of size and number of threads 31 | parser.add_argument("-bs", "--batch_size", type=int, default=256) 32 | parser.add_argument("--num_workers", type=int, default=6) 33 | # NN hyperparameters 34 | parser.add_argument("--model", type=str, default='DGCNN') 35 | parser.add_argument("--n_epochs", type=int, default=20) 36 | parser.add_argument("-lr", "--learning_rate", type=float, default=0.0005) 37 | parser.add_argument("--decay", type=float, default=0.855) 38 | parser.add_argument("--dropout", type=float, default=0.517) 39 | parser.add_argument("--n_runs", type=int, default=1) 40 | # SEAL training hyperparameters 41 | parser.add_argument("--hidden_channels", type=int, default=128) 42 | parser.add_argument("--num_layers", type=int, default=6) 43 | parser.add_argument("--max_z", type=int, default=1000) 44 | parser.add_argument("--sortpool_k", type=float, default=879) 45 | parser.add_argument("--graph_norm", action="store_true") 46 | parser.add_argument("--batch_norm", action="store_true") 47 | # GAE training hyperparameters 48 | parser.add_argument("--variational", action="store_true") 49 | parser.add_argument("--linear", action="store_true") 50 | parser.add_argument("--out_channels", type=int, default=None) 51 | # Graph options 52 | parser.add_argument("--use_attribute", type=str, default='fingerprint') 53 | parser.add_argument("--use_embedding", action="store_true") 54 | # Classification parameters 55 | parser.add_argument("--p_threshold", type=float, default=0.9) 56 | parser.add_argument("--pos_weight_loss", type=float, default=1.0) 57 | 58 | args = parser.parse_args() 59 | settings = vars(args) 60 | 61 | if settings["model"] == "DGCNN" and "max_nodes_per_hop" not in settings: 62 | settings["max_nodes_per_hop"] = None 63 | 64 | # Determine path of graph based on above settings 65 | assert settings["graph_path"] is not None, "-g --graph_path not provided as input or in settings file" 66 | 67 | trainer = GraphTrainer(settings) 68 | _ = trainer.run(running_test=False, final_test=True) 69 | 70 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MolecularAI/reaction-graph-link-prediction/118acb3b4f2d9afe5c34a1a132c91ad1b8c021d5/models/__init__.py -------------------------------------------------------------------------------- /models/autoencoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch_geometric.nn import GCNConv 4 | import torch.nn.functional as F 5 | 6 | 7 | class GCNEncoder(torch.nn.Module): 8 | def __init__(self, in_channels, out_channels, seed, dropout): 9 | super(GCNEncoder, self).__init__() 10 | self.dropout = dropout 11 | # Fix random seed 12 | torch.manual_seed(seed) 13 | self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True) 14 | self.conv2 = GCNConv(2 * out_channels, out_channels, cached=True) 15 | 16 | def forward(self, x, edge_index): 17 | x = F.dropout(x, p=self.dropout, training=self.training) 18 | x = self.conv1(x, edge_index).relu() 19 | return self.conv2(x, edge_index) 20 | 21 | 22 | class VariationalGCNEncoder(torch.nn.Module): 23 | def __init__(self, in_channels, out_channels, seed, dropout): 24 | super(VariationalGCNEncoder, self).__init__() 25 | self.dropout = dropout 26 | # Fix random seed 27 | torch.manual_seed(seed) 28 | self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True) 29 | self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True) 30 | self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True) 31 | 32 | def forward(self, x, edge_index): 33 | x = F.dropout(x, p=self.dropout, training=self.training) 34 | x = self.conv1(x, edge_index).relu() 35 | return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index) 36 | 37 | 38 | class LinearEncoder(torch.nn.Module): 39 | def __init__(self, in_channels, out_channels, seed, dropout): 40 | super(LinearEncoder, self).__init__() 41 | self.dropout = dropout 42 | # Fix random seed 43 | torch.manual_seed(seed) 44 | self.conv = GCNConv(in_channels, out_channels, cached=True) 45 | 46 | def forward(self, x, edge_index): 47 | x = F.dropout(x, p=self.dropout, training=self.training) 48 | return self.conv(x, edge_index) 49 | 50 | 51 | class VariationalLinearEncoder(torch.nn.Module): 52 | def __init__(self, in_channels, out_channels, seed, dropout): 53 | super(VariationalLinearEncoder, self).__init__() 54 | self.dropout = dropout 55 | # Fix random seed 56 | torch.manual_seed(seed) 57 | self.conv_mu = GCNConv(in_channels, out_channels, cached=True) 58 | self.conv_logstd = GCNConv(in_channels, out_channels, cached=True) 59 | 60 | def forward(self, x, edge_index): 61 | x = F.dropout(x, p=self.dropout, training=self.training) 62 | return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index) 63 | 64 | -------------------------------------------------------------------------------- /models/dgcnn.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn import ModuleList, Linear, Conv1d, MaxPool1d, Embedding, BatchNorm1d 6 | from torch_geometric.nn import GCNConv, global_sort_pool, LayerNorm 7 | 8 | 9 | class DGCNN(torch.nn.Module): 10 | """An end-to-end deep learning architecture for graph classification, AAAI-18.""" 11 | 12 | def __init__( 13 | self, 14 | hidden_channels, 15 | num_layers, 16 | max_z, 17 | k=0.6, 18 | train_dataset=None, 19 | dynamic_train=False, 20 | GNN=GCNConv, 21 | use_feature=False, 22 | node_embedding=None, 23 | graph_norm=False, 24 | batch_norm=False, 25 | dropout=None, 26 | 
seed=42, 27 | ): 28 | super(DGCNN, self).__init__() 29 | 30 | self.graph_norm = graph_norm 31 | self.batch_norm = batch_norm 32 | self.dropout = dropout 33 | self.use_feature = use_feature 34 | self.node_embedding = node_embedding 35 | 36 | if k <= 1: # Transform percentile to number. 37 | if train_dataset is None: 38 | k = 30 39 | else: 40 | if dynamic_train: 41 | sampled_train = train_dataset[:1000] 42 | else: 43 | sampled_train = train_dataset 44 | num_nodes = sorted([g.num_nodes for g in sampled_train]) 45 | k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1] 46 | k = max(10, k) 47 | self.k = int(k) 48 | self.max_z = max_z 49 | 50 | # Fix random seed 51 | torch.manual_seed(seed) 52 | self.z_embedding = Embedding(self.max_z, hidden_channels) 53 | 54 | self.convs = ModuleList() 55 | initial_channels = hidden_channels 56 | if self.use_feature: 57 | initial_channels += train_dataset.num_features 58 | if self.node_embedding is not None: 59 | initial_channels += node_embedding.embedding_dim 60 | 61 | self.norms = ModuleList() 62 | self.convs.append(GNN(initial_channels, hidden_channels)) 63 | self.norms.append(LayerNorm(hidden_channels)) 64 | for _ in range(0, num_layers - 1): 65 | self.convs.append(GNN(hidden_channels, hidden_channels)) 66 | self.norms.append(LayerNorm(hidden_channels)) 67 | self.convs.append(GNN(hidden_channels, 1)) 68 | self.norms.append(LayerNorm(1)) 69 | 70 | conv1d_channels = [16, 32] 71 | total_latent_dim = hidden_channels * num_layers + 1 72 | conv1d_kws = [total_latent_dim, 5] 73 | self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]) 74 | self.bn1 = BatchNorm1d(conv1d_channels[0]) 75 | self.maxpool1d = MaxPool1d(2, 2) 76 | self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1) 77 | self.bn2 = BatchNorm1d(conv1d_channels[1]) 78 | dense_dim = int((self.k - 2) / 2 + 1) 79 | dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1] 80 | if dense_dim < 0: 81 | raise ValueError( 82 | "Negative dimension provided to NN. Increase sortpool_k or decrease hidden_channels and/or num_layers" 83 | ) 84 | self.lin1 = Linear(dense_dim, 128) 85 | self.lin2 = Linear(128, 1) 86 | 87 | def forward(self, data, use_feature, embedding, edge_weight=None): 88 | z = data.z 89 | edge_index = data.edge_index 90 | batch = data.batch 91 | x = data.x if use_feature else None 92 | node_id = data.node_id if embedding else None 93 | 94 | z_emb = self.z_embedding(z) 95 | if z_emb.ndim == 3: # in case z has multiple integer labels 96 | z_emb = z_emb.sum(dim=1) 97 | if self.use_feature and x is not None: 98 | x = torch.cat([z_emb, x.to(torch.float)], 1) 99 | else: 100 | x = z_emb 101 | if self.node_embedding is not None and node_id is not None: 102 | n_emb = self.node_embedding(node_id) 103 | x = torch.cat([x, n_emb], 1) 104 | xs = [x] 105 | 106 | for conv, norm in zip(self.convs, self.norms): 107 | x_tmp = conv(xs[-1], edge_index, edge_weight) 108 | if self.graph_norm: 109 | x_tmp = norm(x_tmp) 110 | xs += [torch.tanh(x_tmp)] 111 | x = torch.cat(xs[1:], dim=-1) 112 | 113 | # Global pooling. 114 | x = global_sort_pool(x, batch, self.k) 115 | x = x.unsqueeze(1) # [num_graphs, 1, k * hidden] 116 | x = self.conv1(x) 117 | x = F.relu(x) 118 | x = self.maxpool1d(x) 119 | x = self.conv2(x) 120 | x = F.relu(x) 121 | x = x.view(x.size(0), -1) # [num_graphs, dense_dim] 122 | 123 | # MLP. 
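        # (Two fully connected layers on the flattened SortPool/Conv1d features,
        #  producing a single logit per subgraph; a sigmoid / BCE-with-logits loss
        #  is presumably applied downstream in the trainer.)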
124 | if self.batch_norm: 125 | x = self.bn2(x) 126 | x = F.relu(self.lin1(x)) 127 | x = F.dropout(x, p=self.dropout, training=self.training) # p=0.5 in SEAL 128 | x = self.lin2(x) 129 | 130 | return x 131 | 132 | -------------------------------------------------------------------------------- /predict_links.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import argparse 5 | import pandas as pd 6 | 7 | from utils.evaluate_model import predict_links 8 | 9 | # Parse arguments 10 | parser = argparse.ArgumentParser(description="Evaluating trained model") 11 | 12 | parser.add_argument("-p", "--model_dir_path", type=str) 13 | parser.add_argument("-s", "--save_path", type=str, default="predictions_current_best.csv") 14 | parser.add_argument("-e", "--edges_path", type=str, default=None) 15 | parser.add_argument("-g", "--graph_path", type=str, default="") 16 | parser.add_argument("-t", "--train_edges_path", type=str, default=None) # filename from data/TRAINING_MODEL_NAME/processed/processed_data.pt expected 17 | parser.add_argument("-l", "--label", type=int, default=None) 18 | parser.add_argument("-w", "--num_workers", type=int, default=None) 19 | 20 | args = parser.parse_args() 21 | 22 | model_dir_path = str(args.model_dir_path) 23 | save_path = str(args.save_path) 24 | edges_path = str(args.edges_path) 25 | graph_path = str(args.graph_path) 26 | 27 | if args.edges_path is not None: 28 | df_edges = pd.read_csv(edges_path) 29 | # edges = torch.tensor([df_edges['Reactant index 1'], df_edges['Reactant index 2']]) 30 | 31 | source_col, target_col = "Source", "Target" #'Reactant index 1', 'Reactant index 2' 32 | df_edges = df_edges.drop_duplicates(subset=[source_col, target_col]) 33 | edges = torch.tensor([df_edges[source_col].values, df_edges[target_col].values]) 34 | if args.label is not None: 35 | y_true = len(df_edges) * [args.label] 36 | df_edges["y true"] = y_true 37 | else: 38 | y_true = df_edges["y true"].values 39 | 40 | # check for overlap between edges and edges used for training the model and remove 41 | # NOTE: does not include the positive edges from the trained model in the check 42 | if args.train_edges_path is not None: 43 | data, _ = torch.load(args.train_edges_path) 44 | train_model_edges = ( 45 | torch.cat((data.neg_edge_rand, data.neg_edge_dist), dim=1) 46 | .int() 47 | .t() 48 | .tolist() 49 | ) 50 | train_model_edges_reversed = [edge[::-1] for edge in train_model_edges] 51 | bidirectional_train_model_edges = train_model_edges + train_model_edges_reversed 52 | edges_list = edges.t().tolist() 53 | 54 | filtered_edges_list, filtered_y_true = [], [] 55 | for edge, y in zip(edges_list, y_true): 56 | if edge not in bidirectional_train_model_edges: 57 | filtered_edges_list.append(edge) 58 | filtered_y_true.append(y) 59 | edges = torch.tensor(filtered_edges_list).t() 60 | overlap_size = len(edges_list) - len(filtered_edges_list) 61 | print( 62 | f"{overlap_size} overlaps between edge list and train edges found and removed." 
63 | ) 64 | if len(filtered_edges_list) == 0: 65 | sys.exit("All edges overlap with train edges!") 66 | elif overlap_size > 0: 67 | df_edges = pd.DataFrame( 68 | data=edges.t().tolist(), columns=[source_col, target_col] 69 | ) 70 | df_edges["y true"] = filtered_y_true 71 | 72 | # Evaluate the model 73 | print("Prediction started!") 74 | y_prob, edges = predict_links(model_dir_path, edges=edges, graph_path=graph_path) 75 | print("Prediction done!") 76 | 77 | df_edges["y prob"] = y_prob 78 | df_edges["edge"] = edges 79 | 80 | df_edges.to_csv(save_path) 81 | 82 | -------------------------------------------------------------------------------- /reaction_data/get_negative_reactions_info.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import warnings 4 | import pandas as pd 5 | from datetime import datetime 6 | 7 | with warnings.catch_warnings(): 8 | warnings.filterwarnings("ignore", category=RuntimeWarning) 9 | import graph_tool.all as gt 10 | 11 | from datasets.reaction_graph import ReactionGraph 12 | from settings import settings_reaction_prediction as settings 13 | from utils.negative_sampling import ( 14 | one_against_all, 15 | one_against_most_reactive, 16 | ) 17 | from utils.reactions_info import get_index_to_smiles_dict 18 | 19 | # -------------------------------- Functions -------------------------------- 20 | 21 | 22 | def sample_negative_one_against_all( 23 | fixed_reactant_path, graph_path, neg_pos_ratio, most_reactive_only=False 24 | ): 25 | 26 | # Load Graph 27 | graph = gt.load_graph(graph_path) 28 | 29 | # Load settings 30 | s = settings.settings 31 | s["graph_path"] = graph_path 32 | s["train_fraction"] = 1 33 | s["neg_pos_ratio_test"] = 1 34 | s["neg_pos_ratio"] = neg_pos_ratio 35 | s["fraction_dist_neg"] = 0 36 | s["seed"] = 1 37 | s["name"] = f"get_info_{datetime.now()}" 38 | s["valid_fold"] = 4 39 | 40 | # Create ReactionGraph dataset 41 | eln = ReactionGraph( 42 | f"data/get_info_{datetime.now()}", settings.settings, seed_neg_sampling=0 43 | ) 44 | 45 | i_pos, j_pos = eln.data.pos_edge 46 | print("num links", len(i_pos)) 47 | print(max(i_pos), max(j_pos)) 48 | 49 | fixed_reactant_df = pd.read_csv(fixed_reactant_path) 50 | fixed_reactant_idx = fixed_reactant_df["node index in graph"].values 51 | 52 | if most_reactive_only: 53 | i_neg, j_neg = one_against_most_reactive( 54 | fixed_reactant_idx, (i_pos, j_pos), all_pos_edges=(i_pos, j_pos), cutoff=2 55 | ) 56 | else: 57 | i_neg, j_neg = one_against_all( 58 | fixed_reactant_idx, 59 | (i_pos, j_pos), 60 | all_pos_edges=(i_pos, j_pos), 61 | include_unconnected=False, 62 | ) 63 | 64 | # Create dictionaries, from gt index to neo4j index and from gt index to smiles 65 | index_smils_dict = get_index_to_smiles_dict(graph) 66 | 67 | for s, i in zip( 68 | fixed_reactant_df["SMILES"].values, 69 | fixed_reactant_df["node index in graph"].values, 70 | ): 71 | index_smils_dict[i] = s 72 | 73 | # Negative edges sampled at random 74 | neg_edge = eln.data.neg_edge_rand 75 | 76 | neg_edge_idx_1 = i_neg.to(int) 77 | neg_edge_idx_2 = j_neg.to(int) 78 | neg_edge_smiles_1 = [index_smils_dict[int(i)] for i in neg_edge_idx_1] 79 | neg_edge_smiles_2 = [index_smils_dict[int(i)] for i in neg_edge_idx_2] 80 | 81 | neg_reactions_df = pd.DataFrame( 82 | { 83 | "Source": neg_edge_idx_1, 84 | "Target": neg_edge_idx_2, 85 | "Reactant smiles 1": neg_edge_smiles_1, 86 | "Reactant smiles 2": neg_edge_smiles_2, 87 | "Type": ["all" for i in range(len(neg_edge_idx_1))], 88 | "y true": 0, 89 | } 90 | ) 
91 | return neg_reactions_df, graph 92 | 93 | 94 | def sample_negative_reactions(graph_path, neg_pos_ratio, fraction_dist_neg, seed=1000): 95 | 96 | # Load Graph 97 | graph = gt.load_graph(graph_path) 98 | 99 | # Load settings 100 | s = settings.settings 101 | s["graph_path"] = graph_path 102 | s["train_fraction"] = 1 103 | s["neg_pos_ratio_test"] = 1 104 | s["neg_pos_ratio"] = neg_pos_ratio 105 | s["fraction_dist_neg"] = fraction_dist_neg 106 | s["valid_fold"] = 0 107 | 108 | # Create ReactionGraph dataset 109 | eln_dataset = ReactionGraph( 110 | f"data/get_info_{datetime.now()}", settings.settings, seed_neg_sampling=seed 111 | ) 112 | 113 | # Create dictionaries, from gt index to neo4j index and from gt index to smiles 114 | index_smils_dict = get_index_to_smiles_dict(graph) 115 | 116 | # Negative edges sampled at random 117 | neg_edge_rand = eln_dataset.data.neg_edge_rand 118 | 119 | neg_edge_rand_idx_1 = neg_edge_rand[0, :].to(int) 120 | neg_edge_rand_idx_2 = neg_edge_rand[1, :].to(int) 121 | neg_edge_rand_smiles_1 = [index_smils_dict[int(i)] for i in neg_edge_rand_idx_1] 122 | neg_edge_rand_smiles_2 = [index_smils_dict[int(i)] for i in neg_edge_rand_idx_2] 123 | 124 | neg_edge_rand_smiles_df = pd.DataFrame( 125 | { 126 | "Reactant index 1": neg_edge_rand_idx_1, 127 | "Reactant index 2": neg_edge_rand_idx_2, 128 | "Reactant smiles 1": neg_edge_rand_smiles_1, 129 | "Reactant smiles 2": neg_edge_rand_smiles_2, 130 | "Type": ["random" for i in range(len(neg_edge_rand_idx_1))], 131 | } 132 | ) 133 | 134 | # Negative edges sampled from positive edges distribution 135 | neg_edge_dist = eln_dataset.data.neg_edge_dist 136 | neg_edge_dist_idx_1 = neg_edge_dist[0, :].to(int) 137 | neg_edge_dist_idx_2 = neg_edge_dist[1, :].to(int) 138 | neg_edge_dist_smiles_1 = [index_smils_dict[int(i)] for i in neg_edge_dist_idx_1] 139 | neg_edge_dist_smiles_2 = [index_smils_dict[int(i)] for i in neg_edge_dist_idx_2] 140 | 141 | neg_edge_dist_smiles_df = pd.DataFrame( 142 | { 143 | "Reactant index 1": neg_edge_dist_idx_1, 144 | "Reactant index 2": neg_edge_dist_idx_2, 145 | "Reactant smiles 1": neg_edge_dist_smiles_1, 146 | "Reactant smiles 2": neg_edge_dist_smiles_2, 147 | "Type": ["distribution" for i in range(len(neg_edge_dist_idx_1))], 148 | } 149 | ) 150 | 151 | # Append the two dataframes 152 | neg_reactions_df = neg_edge_rand_smiles_df.append( 153 | neg_edge_dist_smiles_df, ignore_index=True 154 | ) 155 | return neg_reactions_df, graph 156 | 157 | 158 | # -------------------------------- Main -------------------------------- 159 | 160 | 161 | def main(args): 162 | 163 | if args.sample_how == "fixed_reactant": 164 | neg_reactions_df, _ = sample_negative_one_against_all( 165 | args.fixed_reactant_path, 166 | args.graph_path, 167 | args.neg_pos_ratio, 168 | most_reactive_only=False, 169 | ) 170 | elif args.sample_how == "most_reactive": 171 | neg_reactions_df, _ = sample_negative_one_against_all( 172 | args.fixed_reactant_path, 173 | args.graph_path, 174 | args.neg_pos_ratio, 175 | most_reactive_only=True, 176 | ) 177 | else: 178 | neg_reactions_df, _ = sample_negative_reactions( 179 | args.graph_path, 180 | args.neg_pos_ratio, 181 | args.fraction_dist_neg, 182 | args.seed, 183 | ) 184 | 185 | neg_reactions_df.to_csv(args.save_file_path) 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = argparse.ArgumentParser() 190 | 191 | parser.add_argument("--graph_path", type=str) 192 | parser.add_argument("--graph_with_templates_path", type=str) 193 | parser.add_argument("--sample_how", type=str, 
default="normal") 194 | parser.add_argument("--fixed_reactant_path", type=str, default="") 195 | parser.add_argument("--neg_pos_ratio", type=int, default=0.5) 196 | parser.add_argument("--fraction_dist_neg", type=float) 197 | parser.add_argument("--seed", type=int, default=1000) 198 | parser.add_argument("--reactive_functions_count", type=str) 199 | parser.add_argument("--count_cut_off", type=int, default=1) 200 | parser.add_argument("--save_file_path", type=str) 201 | 202 | args = parser.parse_args() 203 | main(args) 204 | 205 | -------------------------------------------------------------------------------- /reaction_data/get_reactions_csv.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import warnings 4 | import numpy as np 5 | import pandas as pd 6 | 7 | with warnings.catch_warnings(): 8 | warnings.filterwarnings("ignore", category=RuntimeWarning) 9 | import graph_tool.all as gt 10 | 11 | from utils.reactions_info import get_reactants_and_product_index, get_index_to_smiles_dict 12 | 13 | 14 | def get_reactions_csv(args): 15 | graph = gt.load_graph(args.graph_bipartite) 16 | ( 17 | reactant_index, 18 | product_index, 19 | reaction_class, 20 | class_id, 21 | ) = get_reactants_and_product_index(graph) 22 | index_to_smiles = get_index_to_smiles_dict(graph) 23 | 24 | # Check for non-binary reactions 25 | for i, j in zip(reactant_index, product_index): 26 | assert len(i) == 2 27 | assert len(j) == 1 28 | 29 | # Edges 30 | edges = np.array(reactant_index)[:, 0], np.array(reactant_index)[:, 1] 31 | 32 | # Create dataframe with all reactions 33 | reactant_products_df = pd.DataFrame( 34 | { 35 | "Reactant index 1": edges[0], 36 | "Reactant index 2": edges[1], 37 | "Product index": np.array(product_index)[:, 0], 38 | "Reactant smiles 1": [index_to_smiles[i] for i in edges[0]], 39 | "Reactant smiles 2": [index_to_smiles[i] for i in edges[1]], 40 | "Product smiles": [ 41 | index_to_smiles[i] for i in np.array(product_index)[:, 0] 42 | ], 43 | } 44 | ) 45 | # Save dataframe 46 | reactant_products_df.to_csv(args.save_file_name, index=False) 47 | 48 | 49 | if __name__ == "__main__": 50 | # parse arguments 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("-g", "--graph_bipartite", type=str, default=None) 53 | parser.add_argument("-s", "--save_file_name", type=str, default=None) 54 | args = parser.parse_args() 55 | 56 | get_reactions_csv(args) 57 | 58 | -------------------------------------------------------------------------------- /reaction_data/get_smiles_node_degree.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | import pandas as pd 4 | 5 | with warnings.catch_warnings(): 6 | warnings.filterwarnings("ignore", category=RuntimeWarning) 7 | import graph_tool.all as gt 8 | 9 | 10 | def get_molecule_degree(graph_path, save_file): 11 | 12 | graph = gt.load_graph(graph_path) 13 | 14 | molecules = gt.find_vertex(graph, graph.vertex_properties["labels"], ":Molecule") 15 | molecules_smiles = [graph.vertex_properties["smiles"][m] for m in molecules] 16 | molecules_out_degree = graph.get_total_degrees(graph.get_vertices()) 17 | 18 | df = pd.DataFrame( 19 | {"smiles": molecules_smiles, "total degree": molecules_out_degree} 20 | ) 21 | df = df.sort_values(by="total degree", ascending=False) 22 | 23 | df.to_csv(save_file) 24 | 25 | 26 | save_file = "smiles_total_degree.csv" 27 | graph_path = "graphs/monopartite/molecule_with_product_graph_5Fold.gt" 28 | 29 | 
get_molecule_degree(graph_path, save_file) 30 | 31 | -------------------------------------------------------------------------------- /settings/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /settings/optuna.py: -------------------------------------------------------------------------------- 1 | {"num_hops": [1, 2], "sortpool_k": [1, 1000], "hidden_channels": [4, 8, 16, 32, 64, 128, 256], "num_layers": [1, 12], "learning_rate": [0.0001, 0.001, 0.01, 0.1], "dropout": [0.0, 0.99], "decay": [0.0, 0.99], "startup_trials": 50} -------------------------------------------------------------------------------- /settings/settings_reaction_prediction.py: -------------------------------------------------------------------------------- 1 | settings = { 2 | # Subgraphs in SEAL 3 | 'num_hops': 1, 4 | 'ratio_per_hop': 1.0, 5 | 'max_nodes_per_hop': 800,#None, 6 | 'node_label': 'drnl', 7 | # Datasplits 8 | 'seed': 100, 9 | 'mode': 'normal',#'increase_negatives','fixed', 10 | 'train_fraction': 1, 11 | 'splitting': 'random', 12 | 'valid_fold': 1, 13 | 'neg_pos_ratio': 1, # How many percent to sample from distribution 14 | 'neg_pos_ratio_test': 1, 15 | 'fraction_dist_neg': 1, # How many percent to sample from distribution 16 | # Dataloaders of size and number of threads 17 | 'batch_size': 256, 18 | 'num_workers': 6, 19 | # GNN hyperparameters 20 | 'model': 'DGCNN', 21 | 'hidden_channels': 128, 22 | 'num_layers': 6, 23 | 'max_z': 1000, 24 | 'sortpool_k': 879, 25 | 'graph_norm': False, 26 | 'batch_norm': False, 27 | 'dropout': 0.517, 28 | 'pre_trained_model_path': None, 29 | # SEAL training hyperparameters 30 | 'n_epochs': 20, 31 | 'learning_rate': 0.0005, 32 | 'decay': 0.855, 33 | 'n_runs': 1, 34 | ##### Graph options ##### 35 | 'use_attribute': 'fingerprint', 36 | 'use_embedding': False, 37 | ##### Evaluation ##### 38 | 'p_threshold': 0.9, 39 | 'pos_weight_loss': 1, 40 | } 41 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/reaction-graph-link-prediction/118acb3b4f2d9afe5c34a1a132c91ad1b8c021d5/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/reaction-graph-link-prediction/118acb3b4f2d9afe5c34a1a132c91ad1b8c021d5/tests/test_dataset/__init__.py -------------------------------------------------------------------------------- /tests/test_dataset/test_overlaps_splits.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import warnings 5 | 6 | from torch_trainer import GraphTrainer 7 | 8 | with warnings.catch_warnings(): 9 | warnings.filterwarnings("ignore", category=RuntimeWarning) 10 | import graph_tool.all as gt 11 | 12 | 13 | def test_split_overlaps(): 14 | splittings = ["random"] 15 | fractions = [0] 16 | seeds = [1, 2] 17 | 18 | for sp in splittings: 19 | for f in fractions: 20 | for se in seeds: 21 | print("split", sp, "fraction", f, "seed", se) 22 | 23 | settings = { 24 | "name": f"test_split_overlap_{sp}_{f}_{f}", 25 | "graph_path": "", # TODO add graph path 26 | "train_fraction": 1, 27 | "splitting": sp, 28 | 
"valid_fold": 1, 29 | "percent_edges": 2, 30 | "fraction_dist_neg": f, 31 | "neg_pos_ratio": 1, 32 | "neg_pos_ratio_test": 1, 33 | "use_attribute": False, 34 | "p_threshold": 0.5, 35 | "mode": "normal", 36 | "n_runs": 1, 37 | "seed": se, 38 | "pos_weight_loss": 1, 39 | } 40 | 41 | os.system(f"rm -rf data/{settings['name']}") 42 | trainer = GraphTrainer(settings) 43 | eln = trainer.initialize_data(se) 44 | 45 | train_pos = eln.data.split_edge["train"]["pos"] 46 | train_neg = eln.data.split_edge["train"]["neg"] 47 | valid_pos = eln.data.split_edge["valid"]["pos"] 48 | valid_neg = eln.data.split_edge["valid"]["neg"] 49 | test_pos = eln.data.split_edge["test"]["pos"] 50 | test_neg = eln.data.split_edge["test"]["neg"] 51 | 52 | all_edges = torch.cat( 53 | (train_pos, train_neg, valid_pos, valid_neg, test_pos, test_neg), 54 | dim=1, 55 | ) 56 | all_train = torch.cat((train_pos, train_neg), dim=1) 57 | all_test = torch.cat((test_pos, test_neg), dim=1) 58 | all_valid = torch.cat((valid_pos, valid_neg), dim=1) 59 | all_neg = torch.cat((train_neg, valid_neg, test_neg), dim=1) 60 | all_pos = torch.cat((train_pos, valid_pos, test_pos), dim=1) 61 | 62 | assert all_edges.shape == all_edges.unique(dim=0).shape 63 | assert all_train.shape == all_train.unique(dim=0).shape 64 | assert all_valid.shape == all_valid.unique(dim=0).shape 65 | assert all_test.shape == all_test.unique(dim=0).shape 66 | assert all_neg.shape == all_neg.unique(dim=0).shape 67 | assert all_pos.shape == all_pos.unique(dim=0).shape 68 | 69 | -------------------------------------------------------------------------------- /tests/test_dataset/test_overlaps_valid_fold.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import warnings 5 | 6 | from torch_trainer import GraphTrainer 7 | 8 | with warnings.catch_warnings(): 9 | warnings.filterwarnings("ignore", category=RuntimeWarning) 10 | import graph_tool.all as gt 11 | 12 | def test_folds(): 13 | splittings = ["random"] # , 'time'] 14 | fractions = [0.1, 0.5, 1] # [0, 0.5, 1] 15 | valid_fold = [1, 2, 3, 4, 5] 16 | 17 | for sp in splittings: 18 | for f in fractions: 19 | train_pos_vfs = [] 20 | train_neg_vfs = [] 21 | valid_pos_vfs = [] 22 | valid_neg_vfs = [] 23 | test_pos_vfs = [] 24 | test_neg_vfs = [] 25 | 26 | for vf in valid_fold: 27 | print("split", sp, "fraction", f, "valid_fold", vf) 28 | 29 | settings = { 30 | "name": f"test_split_overlap_{sp}_{f}_{vf}", 31 | "graph_path": "", # TODO add graph path 32 | "train_fraction": f, 33 | "splitting": sp, 34 | "valid_fold": vf, 35 | "percent_edges": 2, 36 | "fraction_dist_neg": 0, 37 | "neg_pos_ratio": f, 38 | "neg_pos_ratio_test": 1, 39 | "use_attribute": False, 40 | "p_threshold": 0.5, 41 | "mode": "normal", 42 | "n_runs": 1, 43 | "seed": 1, 44 | "pos_weight_loss": 1, 45 | } 46 | 47 | os.system(f"rm -rf data/{settings['name']}") 48 | trainer = GraphTrainer(settings) 49 | eln = trainer.initialize_data(settings["seed"]) 50 | 51 | train_pos_vfs.append(eln.data.split_edge["train"]["pos"]) 52 | train_neg_vfs.append(eln.data.split_edge["train"]["neg"]) 53 | valid_pos_vfs.append(eln.data.split_edge["valid"]["pos"]) 54 | valid_neg_vfs.append(eln.data.split_edge["valid"]["neg"]) 55 | test_pos_vfs.append(eln.data.split_edge["test"]["pos"]) 56 | test_neg_vfs.append(eln.data.split_edge["test"]["neg"]) 57 | 58 | all_pos = torch.cat( 59 | ( 60 | eln.data.split_edge["train"]["pos"], 61 | eln.data.split_edge["valid"]["pos"], 62 | eln.data.split_edge["test"]["pos"], 63 | ), 64 | 
dim=1, 65 | ) 66 | all_neg = torch.cat( 67 | ( 68 | eln.data.split_edge["train"]["neg"], 69 | eln.data.split_edge["valid"]["neg"], 70 | eln.data.split_edge["test"]["neg"], 71 | ), 72 | dim=1, 73 | ) 74 | all_edges = torch.cat((all_pos, all_neg), dim=1) 75 | 76 | assert all_pos.unique(dim=1).shape == all_pos.shape, ( 77 | all_pos.unique(dim=1).shape, 78 | all_pos.shape, 79 | ) 80 | assert all_neg.unique(dim=1).shape == all_neg.shape, ( 81 | all_neg.unique(dim=1).shape, 82 | all_neg.shape, 83 | ) 84 | assert all_edges.unique(dim=1).shape == all_edges.shape, ( 85 | all_edges.unique(dim=1).shape, 86 | all_edges.shape, 87 | ) 88 | 89 | assert ( 90 | torch.cat(valid_pos_vfs, dim=1).unique(dim=1).shape 91 | == torch.cat(valid_pos_vfs, dim=1).shape 92 | ) 93 | assert torch.cat(test_pos_vfs, dim=1).unique(dim=1).shape[1] == int( 94 | torch.cat(test_pos_vfs, dim=1).shape[1] / len(valid_fold) 95 | ) 96 | assert ( 97 | torch.cat(valid_neg_vfs, dim=1).unique(dim=1).shape 98 | == torch.cat(valid_neg_vfs, dim=1).shape 99 | ) 100 | assert torch.cat(test_neg_vfs, dim=1).unique(dim=1).shape[1] == int( 101 | torch.cat(test_neg_vfs, dim=1).shape[1] / len(valid_fold) 102 | ) 103 | 104 | -------------------------------------------------------------------------------- /tests/test_dataset/test_reproducible_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_trainer import GraphTrainer 5 | 6 | 7 | @pytest.fixture 8 | def settings(): 9 | settings_dict = { 10 | "name": "test", 11 | "graph_path": "", # TODO add graph path 12 | "train_fraction": 1, 13 | "splitting": "random", 14 | "mode": "normal", 15 | "valid_fold": 1, 16 | "neg_pos_ratio": 1, 17 | "neg_pos_ratio_test": 1, 18 | "fraction_dist_neg": 0, 19 | "use_attribute": True, 20 | "p_threshold": 0.5, 21 | "n_runs": 1, 22 | "seed": 1, 23 | "pos_weight_loss": 1, 24 | } 25 | 26 | return settings_dict 27 | 28 | 29 | def test_variable_seed_sampling(settings): 30 | 31 | s = settings 32 | # s.sampling_factor = 1 33 | # s.fraction_dost_neg = 0.5 34 | print("settings", type(s), s) 35 | 36 | all_rand_neg_edges = [] 37 | all_dist_neg_edges = [] 38 | for i in range(4): 39 | trainer = GraphTrainer(settings) 40 | eln = trainer.initialize_data(i) 41 | 42 | all_rand_neg_edges.append(eln.data.neg_edge_rand) 43 | all_dist_neg_edges.append(eln.data.neg_edge_dist) 44 | 45 | all_rand_neg_edges = torch.cat(all_rand_neg_edges, dim=1) 46 | all_dist_neg_edges = torch.cat(all_dist_neg_edges, dim=1) 47 | 48 | if not eln.data.neg_edge_rand.shape[1] == 0: 49 | assert eln.data.neg_edge_rand.shape < all_rand_neg_edges.unique(dim=1).shape 50 | 51 | if not eln.data.neg_edge_dist.shape[1] == 0: 52 | assert eln.data.neg_edge_dist.shape < all_dist_neg_edges.unique(dim=1).shape 53 | 54 | 55 | def test_fixed_seed_sampling(settings): 56 | 57 | # s = settings 58 | # s.sampling_factor = 1 59 | # s.fraction_dost_neg = 0.5 60 | 61 | all_rand_neg_edges = [] 62 | all_dist_neg_edges = [] 63 | for i in range(4): 64 | trainer = GraphTrainer(settings) 65 | eln = trainer.initialize_data(1) 66 | 67 | all_rand_neg_edges.append(eln.data.neg_edge_rand) 68 | all_dist_neg_edges.append(eln.data.neg_edge_dist) 69 | 70 | all_rand_neg_edges = torch.cat(all_rand_neg_edges, dim=1) 71 | all_dist_neg_edges = torch.cat(all_dist_neg_edges, dim=1) 72 | 73 | if not eln.data.neg_edge_rand.shape[1] == 0: 74 | assert eln.data.neg_edge_rand.shape == all_rand_neg_edges.unique(dim=1).shape 75 | 76 | if not eln.data.neg_edge_dist.shape[1] == 0: 
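        # With a fixed sampling seed, every run should draw the identical negative
        # edges, so concatenating the four runs adds no new unique columns and the
        # shapes stay equal.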
77 | assert eln.data.neg_edge_dist.shape == all_dist_neg_edges.unique(dim=1).shape 78 | 79 | 80 | def test_seed_init_model(): 81 | pass 82 | 83 | -------------------------------------------------------------------------------- /tests/test_dataset/test_split_distribution.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_trainer import GraphTrainer 5 | 6 | 7 | @pytest.fixture 8 | def settings(): 9 | settings_dict = { 10 | "name": "test", 11 | "graph_path": "", # TODO add graph path 12 | "train_fraction": 1, 13 | "splitting": "random", 14 | "mode": "normal", 15 | "valid_fold": 1, 16 | "neg_pos_ratio": 1, 17 | "neg_pos_ratio_test": 1, 18 | "fraction_dist_neg": 0.5, 19 | "use_attribute": True, 20 | "p_threshold": 0.5, 21 | "n_runs": 1, 22 | "seed": 1, 23 | "pos_weight_loss": 1, 24 | } 25 | 26 | return settings_dict 27 | 28 | 29 | def test_negative_sampling_distributions(settings): 30 | 31 | all_rand_neg_edges = [] 32 | all_dist_neg_edges = [] 33 | 34 | for fraction in (0, 0.5, 1): 35 | settings["fraction_dist_neg"] = fraction 36 | 37 | trainer = GraphTrainer(settings) 38 | eln = trainer.initialize_data(42) 39 | 40 | all_rand_neg_edges = eln.data.neg_edge_rand 41 | all_dist_neg_edges = eln.data.neg_edge_dist 42 | 43 | train_neg = eln.data.split_edge["train"]["neg"] 44 | valid_neg = eln.data.split_edge["valid"]["neg"] 45 | test_neg = eln.data.split_edge["test"]["neg"] 46 | 47 | train_and_rand = torch.cat((train_neg, all_rand_neg_edges), dim=1) 48 | unique_train_and_rand = train_and_rand.unique(dim=1) 49 | 50 | valid_and_rand = torch.cat((valid_neg, all_rand_neg_edges), dim=1) 51 | unique_valid_and_rand = valid_and_rand.unique(dim=1) 52 | 53 | test_and_rand = torch.cat((test_neg, all_rand_neg_edges), dim=1) 54 | unique_test_and_rand = test_and_rand.unique(dim=1) 55 | 56 | train_and_dist = torch.cat((train_neg, all_dist_neg_edges), dim=1) 57 | unique_train_and_dist = train_and_dist.unique(dim=1) 58 | 59 | valid_and_dist = torch.cat((valid_neg, all_dist_neg_edges), dim=1) 60 | unique_valid_and_dist = valid_and_dist.unique(dim=1) 61 | 62 | test_and_dist = torch.cat((test_neg, all_dist_neg_edges), dim=1) 63 | unique_test_and_dist = test_and_dist.unique(dim=1) 64 | 65 | if fraction == 0: 66 | assert len(unique_train_and_rand[0]) == len(all_rand_neg_edges[0]) 67 | assert len(unique_valid_and_rand[0]) == len(all_rand_neg_edges[0]) 68 | assert len(unique_test_and_rand[0]) == len(all_rand_neg_edges[0]) 69 | assert len(unique_train_and_dist[0]) == len(all_dist_neg_edges[0]) + len( 70 | train_neg[0] 71 | ) 72 | assert len(unique_valid_and_dist[0]) == len(all_dist_neg_edges[0]) + len( 73 | valid_neg[0] 74 | ) 75 | assert len(unique_test_and_dist[0]) == len(all_dist_neg_edges[0]) + len( 76 | test_neg[0] 77 | ) 78 | 79 | if fraction == 1: 80 | assert len(unique_train_and_dist[0]) == len(all_dist_neg_edges[0]) 81 | assert len(unique_valid_and_dist[0]) == len(all_dist_neg_edges[0]) 82 | assert len(unique_test_and_dist[0]) == len(all_dist_neg_edges[0]) 83 | assert len(unique_train_and_rand[0]) == len(all_rand_neg_edges[0]) + len( 84 | train_neg[0] 85 | ) 86 | assert len(unique_valid_and_rand[0]) == len(all_rand_neg_edges[0]) + len( 87 | valid_neg[0] 88 | ) 89 | assert len(unique_test_and_rand[0]) == len(all_rand_neg_edges[0]) + len( 90 | test_neg[0] 91 | ) 92 | 93 | if fraction == 0.5: 94 | assert ( 95 | len(unique_train_and_dist[0]) 96 | > len(all_dist_neg_edges[0]) + (len(train_neg[0]) / 2) * 0.9 97 | ) 98 | assert ( 99 | 
len(unique_valid_and_dist[0]) 100 | > len(all_dist_neg_edges[0]) + (len(valid_neg[0]) / 2) * 0.9 101 | ) 102 | assert ( 103 | len(unique_test_and_dist[0]) 104 | > len(all_dist_neg_edges[0]) + (len(test_neg[0]) / 2) * 0.9 105 | ) 106 | assert ( 107 | len(unique_train_and_rand[0]) 108 | < len(all_rand_neg_edges[0]) + (len(train_neg[0]) / 2) * 1.1 109 | ) 110 | assert ( 111 | len(unique_valid_and_rand[0]) 112 | < len(all_rand_neg_edges[0]) + (len(valid_neg[0]) / 2) * 1.1 113 | ) 114 | assert ( 115 | len(unique_test_and_rand[0]) 116 | < len(all_rand_neg_edges[0]) + (len(test_neg[0]) / 2) * 1.1 117 | ) 118 | -------------------------------------------------------------------------------- /torch_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as osp 4 | import time 5 | import logging 6 | import warnings 7 | from collections import namedtuple 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import seaborn as sns 12 | from tqdm import tqdm 13 | import matplotlib.pyplot as plt 14 | 15 | from sklearn.calibration import calibration_curve 16 | from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve 17 | from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score 18 | from scipy.sparse import SparseEfficiencyWarning 19 | 20 | import torch 21 | from torch.nn import BCEWithLogitsLoss, Embedding 22 | from torch_geometric.data import DataLoader 23 | from torch_geometric.nn import GAE, VGAE 24 | from torch_geometric.utils import to_undirected 25 | 26 | from datasets.reaction_graph import ReactionGraph 27 | from datasets.seal import SEALDynamicDataset 28 | from datasets.GAE import GeneralDataset 29 | from models.dgcnn import DGCNN 30 | from models.autoencoder import ( 31 | GCNEncoder, 32 | VariationalGCNEncoder, 33 | LinearEncoder, 34 | VariationalLinearEncoder, 35 | ) 36 | from utils.metrics import mean_average_precision, hitsK 37 | 38 | warnings.simplefilter("ignore", SparseEfficiencyWarning) 39 | 40 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | 43 | class GraphTrainer: 44 | """Class for handling training SEAL or GAE model for link prediction in reaction graph.""" 45 | 46 | def __init__(self, settings): 47 | super(GraphTrainer, self).__init__() 48 | 49 | self.settings = settings 50 | self.n_runs = None 51 | self.run_it = 1 52 | self.running_best = { 53 | "Loss": {"Run": None, "Epoch": None, "Score": 100.0}, 54 | "AUC": {"Run": None, "Epoch": None, "Score": 0.0}, 55 | "last_epoch": {"Run": None, "Epoch": None}, 56 | } 57 | # Evaluation metrics 58 | self.metrics = { 59 | "score": {"AUC": roc_auc_score, "AP": average_precision_score}, 60 | # Metrics requiring a threshhold 61 | "prediction": { 62 | "Accuracy": accuracy_score, 63 | "Recall": recall_score, 64 | "Precision": precision_score, 65 | "F1": f1_score, 66 | }, 67 | } 68 | # Place to store running and test scores 69 | self.scores = { 70 | "running": pd.DataFrame({}), # run, epoch, score, metric, split 71 | "test": pd.DataFrame({}), 72 | } # run, epoch, best on valid loss/auc score, metric 73 | self.predictions = pd.DataFrame({}) 74 | # Store all related results 75 | trainer_id = time.strftime("%Y%m%d%H%M%S") 76 | self.res_dir = osp.join(f"results/{settings['name']}_{trainer_id}") 77 | 78 | if not osp.exists(self.res_dir): 79 | os.makedirs(self.res_dir) 80 | else: 81 | logging.warning("Warning: results will overwrite old files.") 82 | 83 | # Logging settings 84 | FORMAT = "%(asctime)s : 
%(levelname)s:%(module)s : %(funcName)s : %(message)s" 85 | DATEFMT = "%d/%b/%Y %H:%M:%S" 86 | logging.basicConfig( 87 | filename=osp.join(self.res_dir, "torch_trainer.log"), 88 | level=logging.DEBUG, 89 | format=FORMAT, 90 | datefmt=DATEFMT, 91 | ) 92 | logging.info("Log file for trainer: %s_%s", settings["name"], trainer_id) 93 | logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING) 94 | logging.getLogger("matplotlib.axes._base").setLevel(logging.WARNING) 95 | console = logging.StreamHandler() 96 | console.setLevel(logging.WARNING) 97 | formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") 98 | console.setFormatter(formatter) 99 | logging.getLogger("").addHandler(console) 100 | 101 | # set random seed for reproducibility 102 | torch.manual_seed(self.settings["seed"]) 103 | np.random.seed(self.settings["seed"]) 104 | 105 | # Save settings 106 | df_settings = pd.DataFrame.from_dict( 107 | self.settings, orient="index", columns=["Settings"] 108 | ) 109 | df_settings.to_csv(osp.join(self.res_dir, "settings.csv")) 110 | 111 | def initialize_data(self, run_it): 112 | """Creates the an instance of ReactionGraph dataset class, considering given settings.""" 113 | 114 | logging.info( 115 | "mode is: %s, valid %s, seed %s", 116 | self.settings["mode"], 117 | self.settings["valid_fold"], 118 | self.settings["seed"], 119 | ) 120 | if ( 121 | self.settings["mode"] == "cross_validation" 122 | or self.settings["mode"] == "cross_validation_5" 123 | ): 124 | if self.settings["splitting"] == "random": 125 | self.settings["valid_fold"] = run_it - 1 126 | else: 127 | self.settings["valid_fold"] = run_it 128 | # for cross-validation only positives will change within the split while negatives will not 129 | seed_neg_sampling = self.settings["seed"] 130 | 131 | elif self.settings["mode"] == "normal": 132 | # for run_it = 1 sampling will match the one when cross_valiation is done 133 | seed_neg_sampling = self.settings["seed"] + run_it - 1 134 | 135 | logging.info( 136 | "Initializing run %d. # %d fold is used as validation set. \ 137 | Neg sampling seed = %d.", 138 | run_it, 139 | self.settings["valid_fold"], 140 | seed_neg_sampling, 141 | ) 142 | 143 | is_data = osp.isdir(f'data/{self.settings["name"]}') 144 | if is_data: 145 | os.system(f"rm -rf data/{self.settings['name']}") 146 | logging.info( 147 | "ReactionGraph is initialized. Saved in 'data/%s'.", 148 | self.settings["name"], 149 | ) 150 | 151 | reaction_graph = ReactionGraph( 152 | f'data/{self.settings["name"]}', self.settings, seed_neg_sampling 153 | ) 154 | self.settings["num_nodes"] = reaction_graph.num_nodes 155 | self.settings["features"] = reaction_graph.data.x 156 | reaction_graph.process_splits() 157 | 158 | return reaction_graph 159 | 160 | def make_data_splits(self, reaction_graph): 161 | """Creating the datasets and dataloaders for train, validation and test splits.""" 162 | 163 | # Check for malicious settings 164 | if self.settings["valid_fold"] > 9 and self.settings["splitting"] == "random": 165 | logging.error( 166 | 'Validation fold should be in range (0,9) when \ 167 | splitting == "random".' 168 | ) 169 | elif ( 170 | self.settings["valid_fold"] > 9 171 | or self.settings["valid_fold"] < 1 172 | and self.settings["splitting"] == "time" 173 | ): 174 | logging.error( 175 | 'Validation fold should be in range (1,9) when splitting == "time".' 
176 | ) 177 | 178 | if ( 179 | self.settings["fraction_dist_neg"] > 1 180 | or self.settings["fraction_dist_neg"] < 0 181 | ): 182 | logging.error("'fraction_dist_neg' (%f) must be between 0 and 1.") 183 | 184 | # Train on train and validation data 185 | if "include_in_train" in self.settings.keys(): 186 | include_in_train = self.settings["include_in_train"] 187 | 188 | train_pos_edge = ( 189 | reaction_graph.data.split_edge["train"]["pos"].detach().clone() 190 | ) 191 | train_neg_edge = ( 192 | reaction_graph.data.split_edge["train"]["neg"].detach().clone() 193 | ) 194 | valid_pos_edge = ( 195 | reaction_graph.data.split_edge["valid"]["pos"].detach().clone() 196 | ) 197 | valid_neg_edge = ( 198 | reaction_graph.data.split_edge["valid"]["neg"].detach().clone() 199 | ) 200 | test_pos_edge = ( 201 | reaction_graph.data.split_edge["test"]["pos"].detach().clone() 202 | ) 203 | test_neg_edge = ( 204 | reaction_graph.data.split_edge["test"]["neg"].detach().clone() 205 | ) 206 | else: 207 | self.settings["include_in_train"] = "train" 208 | include_in_train = self.settings["include_in_train"] 209 | 210 | if include_in_train == "valid": 211 | logging.info( 212 | "Training is done on train and valid set. \ 213 | \n Test set is used for validation and testing." 214 | ) 215 | # Concat train and valid and assign to train 216 | valid_pos_edge = to_undirected(valid_pos_edge) 217 | reaction_graph.data.split_edge["train"]["pos"] = torch.cat( 218 | (train_pos_edge, valid_pos_edge), dim=1 219 | ) 220 | reaction_graph.data.split_edge["train"]["neg"] = torch.cat( 221 | (train_neg_edge, valid_neg_edge), dim=1 222 | ) 223 | # Assign test to valid, test and valid is then the same 224 | reaction_graph.data.split_edge["valid"]["pos"] = test_pos_edge 225 | reaction_graph.data.split_edge["valid"]["neg"] = test_neg_edge 226 | logging.info("Adding validation edges to train, setting validation to test") 227 | 228 | # Train on train, validation and test data 229 | elif include_in_train == "test": 230 | logging.info( 231 | 'Training is done on all edges. No validation set or test set is used. \ 232 | \n Model is saved at epoch "n_epochs".' 
233 | ) 234 | # Concat train, valid and test and assign to train 235 | valid_pos_edge = to_undirected(valid_pos_edge) 236 | test_pos_edge = to_undirected(test_pos_edge) 237 | reaction_graph.data.split_edge["train"]["pos"] = torch.cat( 238 | (train_pos_edge, valid_pos_edge, test_pos_edge), dim=1 239 | ) 240 | reaction_graph.data.split_edge["train"]["neg"] = torch.cat( 241 | (train_neg_edge, valid_neg_edge, test_neg_edge), dim=1 242 | ) 243 | # Assign dummy valid set 244 | reaction_graph.data.split_edge["valid"]["pos"] = torch.tensor([[], []]) 245 | reaction_graph.data.split_edge["valid"]["neg"] = torch.tensor([[], []]) 246 | # Assign dummy test set 247 | reaction_graph.data.split_edge["test"]["pos"] = torch.tensor([[], []]) 248 | reaction_graph.data.split_edge["test"]["neg"] = torch.tensor([[], []]) 249 | logging.info( 250 | "Adding validation and test edges to train, removing validation and test" 251 | ) 252 | 253 | # Split edges and create the dataloaders 254 | splits = ["train", "valid", "test"] 255 | datasets = {} 256 | dataloaders = {} 257 | 258 | for split in splits: 259 | if self.settings["model"] == "DGCNN": 260 | datasets[split] = SEALDynamicDataset( 261 | root="data/SEAL", 262 | dataset=reaction_graph, 263 | settings=self.settings, 264 | split=split, 265 | ) 266 | 267 | dataloaders[split] = DataLoader( 268 | datasets[split], 269 | batch_size=self.settings["batch_size"], 270 | shuffle=(split == "train"), 271 | num_workers=self.settings["num_workers"], 272 | ) 273 | else: 274 | datasets[split] = GeneralDataset( 275 | root="data/general", 276 | dataset=reaction_graph, 277 | settings=self.settings, 278 | split=split, 279 | ) 280 | 281 | if split == "train": 282 | logging.info( 283 | "%s dataset contains %d (%d x2 positive + %d negative) edges.", 284 | split.upper(), 285 | len(datasets[split]), 286 | datasets[split].pos_edge.shape[1] // 2, 287 | datasets[split].neg_edge.shape[1], 288 | ) 289 | else: 290 | logging.info( 291 | "%s dataset contains %d (%d + %d negative) edges.", 292 | split.upper(), 293 | len(datasets[split]), 294 | datasets[split].pos_edge.shape[1], 295 | datasets[split].neg_edge.shape[1], 296 | ) 297 | 298 | return datasets, dataloaders 299 | 300 | def initialize_model(self, datasets): 301 | """Initializeing the DGCNN model, optimizer, learning rate scheduler 302 | and any embeddings. 
303 | """ 304 | # Node embeddings 305 | if self.settings["use_embedding"]: 306 | emb = Embedding( 307 | self.settings["num_nodes"], self.settings["hidden_channels"] 308 | ).to(DEVICE) 309 | else: 310 | emb = None 311 | 312 | # Initialize classifier model 313 | torch.manual_seed(self.settings["seed"]) 314 | if self.settings["model"] == "DGCNN": 315 | model = DGCNN( 316 | hidden_channels=self.settings["hidden_channels"], 317 | num_layers=self.settings["num_layers"], 318 | max_z=self.settings["max_z"], 319 | k=self.settings["sortpool_k"], 320 | train_dataset=datasets["train"], 321 | dynamic_train=False, 322 | use_feature=self.settings["use_attribute"], 323 | node_embedding=emb, 324 | graph_norm=self.settings["graph_norm"], 325 | batch_norm=self.settings["batch_norm"], 326 | dropout=self.settings["dropout"], 327 | seed=self.settings["seed"], 328 | ).to(DEVICE) 329 | elif self.settings["model"] == "GAE": 330 | if self.settings["use_attribute"] == True: 331 | num_features = datasets["train"].num_features 332 | else: 333 | num_features = self.settings["num_nodes"] 334 | # num_features = 1 335 | if self.settings["variational"] == False: 336 | if self.settings["linear"] == False: 337 | model = GAE( 338 | GCNEncoder( 339 | in_channels=num_features, 340 | out_channels=self.settings["out_channels"], 341 | seed=self.settings["seed"], 342 | dropout=self.settings["dropout"], 343 | ) 344 | ).to(DEVICE) 345 | else: 346 | model = GAE( 347 | LinearEncoder( 348 | in_channels=num_features, 349 | out_channels=self.settings["out_channels"], 350 | seed=self.settings["seed"], 351 | dropout=self.settings["dropout"], 352 | ) 353 | ).to(DEVICE) 354 | else: 355 | if self.settings["linear"] == False: 356 | model = VGAE( 357 | VariationalGCNEncoder( 358 | in_channels=num_features, 359 | out_channels=self.settings["out_channels"], 360 | seed=self.settings["seed"], 361 | dropout=self.settings["dropout"], 362 | ) 363 | ).to(DEVICE) 364 | else: 365 | model = VGAE( 366 | VariationalLinearEncoder( 367 | in_channels=num_features, 368 | out_channels=self.settings["out_channels"], 369 | seed=self.settings["seed"], 370 | dropout=self.settings["dropout"], 371 | ) 372 | ).to(DEVICE) 373 | 374 | if self.settings["learning_rate"] < 0.01: 375 | logging.info( 376 | "Learning rate %f smaller than suggested 0.01 for GAE", 377 | self.settings["learning_rate"], 378 | ) 379 | else: 380 | logging.error("Model %s not valid option.", self.settings["model"]) 381 | 382 | parameters = list(model.parameters()) 383 | if self.settings["use_embedding"]: 384 | torch.nn.init.xavier_uniform_(emb.weight) 385 | parameters += list(emb.parameters()) 386 | optimizer = torch.optim.Adam( 387 | params=parameters, lr=self.settings["learning_rate"] 388 | ) 389 | 390 | # If provided, load pretrained model 391 | if self.settings["pre_trained_model_path"] is not None: 392 | # Update in how models / optimizers are saved and loaded 393 | if osp.isfile(self.settings["pre_trained_model_path"]): 394 | # try: # after update model saved as dicts 395 | checkpoint = torch.load(self.settings["pre_trained_model_path"]) 396 | model.load_state_dict(checkpoint["model_state_dict"]) 397 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 398 | # except: # before update saved separately 399 | elif osp.isfile( 400 | osp.join( 401 | self.settings["pre_trained_model_path"], 402 | "best_AUC_model_checkpoint.pth", 403 | ) 404 | ): 405 | model.load_state_dict( 406 | torch.load( 407 | osp.join( 408 | self.settings["pre_trained_model_path"], 409 | 
"best_AUC_model_checkpoint.pth", 410 | ) 411 | ), 412 | strict=False 413 | ) 414 | else: 415 | logging.error("Cannot load pre-trained_model. Check path in settings.") 416 | 417 | lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( 418 | optimizer=optimizer, gamma=self.settings["decay"] 419 | ) 420 | 421 | if self.run_it == 1: 422 | total_params = sum(p.numel() for param in parameters for p in param) 423 | logging.info("Total number of parameters is %d", total_params) 424 | 425 | return model, optimizer, lr_scheduler, emb 426 | 427 | def update_scores(self, evaluation, split, epoch, run_it, best_on): 428 | """Gets scores: loss and evaluation metrics given an evaluation object 429 | and appends to dataframe where this informaion is stored. 430 | """ 431 | 432 | scores_df = pd.DataFrame({}) 433 | Score = namedtuple("Score", "metric score split run_it epoch") 434 | 435 | y_prob = torch.sigmoid(evaluation.y_prob) 436 | y_true = evaluation.y_true 437 | links = evaluation.links 438 | 439 | # Add loss 440 | score = Score("Loss", evaluation.loss, split, run_it, epoch) 441 | scores_df = add_score(scores_df, score) 442 | 443 | # Add metrics 444 | for metric, func in self.metrics["score"].items(): 445 | score = Score(metric, func(y_true, y_prob), split, run_it, epoch) 446 | scores_df = add_score(scores_df, score) 447 | 448 | if self.settings["p_threshold"] == "roc": 449 | fpr, tpr, thresholds = roc_curve(y_true, y_prob) 450 | max_roc_cutoff = thresholds[(tpr - fpr).argmax()] 451 | y_pred = (y_prob > max_roc_cutoff).int() 452 | else: 453 | y_pred = (y_prob > self.settings["p_threshold"]).int() 454 | for metric, func in self.metrics["prediction"].items(): 455 | if metric == "Precision": 456 | score = Score( 457 | metric, func(y_true, y_pred, zero_division=0), split, run_it, epoch 458 | ) 459 | else: 460 | score = Score(metric, func(y_true, y_pred), split, run_it, epoch) 461 | 462 | scores_df = add_score(scores_df, score) 463 | 464 | TP = float(sum(y_pred[y_true == 1]) / sum(y_true == 1)) 465 | FP = float(sum(y_pred[y_true == 0]) / sum(y_true == 0)) 466 | TN = 1 - FP 467 | FN = 1 - TP 468 | rates = {"TPR": TP, "FPR": FP, "TNR": TN, "FNR": FN} 469 | for metric, rate in rates.items(): 470 | score = Score(metric, rate, split, run_it, epoch) 471 | scores_df = add_score(scores_df, score) 472 | 473 | if split == "train" or split == "valid": 474 | self.scores["running"] = self.scores["running"].append( 475 | scores_df, ignore_index=True 476 | ) 477 | else: 478 | # HITS @ k 479 | if len(y_pred) > 100: 480 | for k in [20, 50, 100]: 481 | score = Score( 482 | f"HITS@{k}", hitsK(y_true, y_pred, k), split, run_it, epoch 483 | ) 484 | scores_df = add_score(scores_df, score) 485 | else: 486 | logging.warning("Too few datapoints for calculating HITS @ k") 487 | 488 | # MAP 489 | mean_ap, _, _ = mean_average_precision(y_true, y_pred, links) 490 | score = Score("MAP", mean_ap, split, run_it, epoch) 491 | scores_df = add_score(scores_df, score) 492 | 493 | scores_df = scores_df[["Run", "Epoch", "Metric", "Score", "Split"]] 494 | scores_df["Best on"] = best_on 495 | 496 | self.scores["test"] = self.scores["test"].append( 497 | scores_df, ignore_index=True 498 | ) 499 | 500 | def train(self, model, dataloaders, datasets, loss_func, optimizer, emb): 501 | """Training function.""" 502 | 503 | Evaluation = namedtuple("Evaluation", "loss auc y_true y_prob links") 504 | model.train() 505 | total_loss = 0 506 | y_score, y_true = [], [] 507 | 508 | if self.settings["model"] == "DGCNN": 509 | for data in dataloaders["train"]: 
510 | data = data.to(DEVICE) 511 | optimizer.zero_grad() 512 | logits = model( 513 | data, use_feature=self.settings["use_attribute"], embedding=emb 514 | ) 515 | y = data.y.to(torch.float) 516 | loss = loss_func(logits.view(-1).to(torch.float), y) 517 | 518 | loss.backward() 519 | optimizer.step() 520 | 521 | total_loss += loss.item() * data.num_graphs 522 | y_score.append(logits.view(-1).detach().cpu()) 523 | y_true.append(y.detach().cpu()) 524 | 525 | elif self.settings["model"] == "GAE": 526 | optimizer.zero_grad() 527 | if self.settings["use_attribute"]: 528 | x = self.settings["features"].to(torch.float).to(DEVICE) 529 | else: 530 | x = torch.eye(self.settings["num_nodes"]).to(DEVICE) 531 | # x = torch.ones(self.settings['num_nodes'], 1).to(DEVICE) # should not be used because result depends on scalar value set 532 | pos_neg_edges = torch.cat( 533 | [datasets["train"].pos_edge, datasets["train"].neg_edge], 1 534 | ).to(DEVICE) 535 | pos_edges = datasets["train"].pos_edge.to(DEVICE) 536 | neg_edges = datasets["train"].neg_edge.to(torch.long).to(DEVICE) 537 | z = model.encode(x, datasets["train"].pos_edge.to(DEVICE)) 538 | logits = model.decode(z, pos_neg_edges) 539 | loss = model.recon_loss( 540 | z, pos_edge_index=pos_edges, neg_edge_index=neg_edges 541 | ) 542 | if self.settings["variational"] == True: 543 | loss = loss + (1 / self.settings["num_nodes"]) * model.kl_loss() 544 | 545 | y_true.append(torch.FloatTensor(datasets["train"].labels)) 546 | 547 | loss.backward() 548 | optimizer.step() 549 | 550 | total_loss = loss.item() 551 | y_score.append(logits.view(-1).detach().cpu()) 552 | 553 | y_true, y_score = torch.cat(y_true), torch.cat(y_score) 554 | 555 | total_loss /= len(y_true) 556 | auc = self.metrics["score"]["AUC"](y_true, y_score) 557 | 558 | return Evaluation( 559 | total_loss, auc, y_true.detach().clone(), y_score.detach().clone(), None 560 | ) 561 | 562 | @torch.no_grad() 563 | def evaluate(self, model, split, dataloaders, datasets, loss_func, emb): 564 | """Evaluation function""" 565 | 566 | Evaluation = namedtuple("Evaluation", "loss auc y_true y_prob links") 567 | model.eval() 568 | total_loss = 0 569 | y_score, y_true = [], [] 570 | links = [] 571 | if self.settings["model"] == "DGCNN": 572 | for data in dataloaders[split]: 573 | data = data.to(DEVICE) 574 | logits = model( 575 | data, use_feature=self.settings["use_attribute"], embedding=emb 576 | ) 577 | loss = loss_func(logits.view(-1), data.y.to(torch.float)) 578 | total_loss += loss.item() * data.num_graphs 579 | 580 | y_score.append(logits.view(-1).cpu()) 581 | y_true.append(data.y.view(-1).cpu().to(torch.float)) 582 | links.extend(data.link) 583 | elif self.settings["model"] == "GAE": 584 | if self.settings["use_attribute"]: 585 | x = self.settings["features"].to(torch.float).to(DEVICE) 586 | else: 587 | x = torch.eye(self.settings["num_nodes"]).to(DEVICE) 588 | # x = torch.ones(self.settings['num_nodes'], 1).to(DEVICE) # should not be used because result depends on scalar value set 589 | pos_neg_edges = torch.cat( 590 | [datasets[split].pos_edge, datasets[split].neg_edge], 1 591 | ).to(DEVICE) 592 | pos_edges = datasets[split].pos_edge.to(DEVICE) 593 | neg_edges = datasets[split].neg_edge.to(torch.long).to(DEVICE) 594 | z = model.encode(x, datasets["train"].pos_edge.to(DEVICE)) 595 | logits = model.decode(z, pos_neg_edges) 596 | loss = model.recon_loss( 597 | z, pos_edge_index=pos_edges, neg_edge_index=neg_edges 598 | ) 599 | if self.settings["variational"] == True: 600 | loss = loss + (1 / 
self.settings["num_nodes"]) * model.kl_loss() 601 | 602 | y_true.append(torch.FloatTensor(datasets[split].labels)) 603 | 604 | total_loss = loss.item() 605 | y_score.append(logits.view(-1).cpu()) 606 | links.extend(datasets[split].links) 607 | 608 | y_prob = torch.sigmoid(torch.cat(y_score)) 609 | y_true = torch.cat(y_true) 610 | 611 | total_loss /= len(y_true) 612 | auc = self.metrics["score"]["AUC"](y_true, y_prob) 613 | 614 | return Evaluation( 615 | total_loss, auc, y_true.detach().clone(), y_prob.detach().clone(), links 616 | ) 617 | 618 | @torch.no_grad() 619 | def predict(self, model, datasets, dataloader, split, emb): 620 | model.eval() 621 | y_score = [] 622 | links = [] 623 | if self.settings["model"] == "DGCNN": 624 | for data in dataloader[split]: 625 | data = data.to(DEVICE) 626 | logits = model( 627 | data, use_feature=self.settings["use_attribute"], embedding=emb 628 | ) 629 | y_score.append(logits.view(-1).cpu()) 630 | links.extend(data.link) 631 | elif self.settings["model"] == "GAE": 632 | if self.settings["use_attribute"]: 633 | x = self.settings["features"].to(torch.float).to(DEVICE) 634 | else: 635 | x = torch.eye(self.settings["num_nodes"]).to(DEVICE) 636 | pos_neg_edges = torch.cat( 637 | [datasets[split].pos_edge, datasets[split].neg_edge], 1 638 | ).to(DEVICE) 639 | pos_edges = datasets[split].pos_edge.to(DEVICE) 640 | neg_edges = datasets[split].neg_edge.to(torch.long).to(DEVICE) 641 | z = model.encode(x, datasets["train"].pos_edge.to(DEVICE)) 642 | logits = model.decode(z, pos_neg_edges) 643 | links.extend(datasets[split].links) 644 | y_score.append(logits.view(-1).cpu()) 645 | 646 | y_prob = torch.sigmoid(torch.cat(y_score)) 647 | 648 | return y_prob, links 649 | 650 | def run(self, running_test=True, final_test=False): 651 | """Full training process.""" 652 | start = time.time() 653 | 654 | logging.info( 655 | "Starting training process on %s, results will be saved in %s.", 656 | DEVICE, 657 | self.res_dir, 658 | ) 659 | 660 | # Settings depending on training mode 661 | assert self.settings["mode"] in [ 662 | "normal", 663 | "cross_validation", 664 | "cross_validation_5", 665 | ], "'mode' setting invalid. Use: 'normal', 'cross_validation', 'cross_validation_5'" 666 | 667 | assert self.settings["splitting"] in [ 668 | "time", 669 | "random", 670 | ], "'splitting' setting invalid. 
Chose from: 'time', 'random'" 671 | 672 | if self.settings["mode"] == "cross_validation": 673 | if self.settings["splitting"] == "random": 674 | self.n_runs = 10 675 | logging.debug("Random split and cross validation.") 676 | elif self.settings["splitting"] == "time": 677 | self.n_runs = 9 678 | logging.debug("Time split and cross validation.") 679 | elif self.settings["mode"] == "cross_validation_5": 680 | if self.settings["splitting"] == "random": 681 | self.n_runs = 5 682 | logging.debug("Random split and cross validation.") 683 | elif self.settings["splitting"] == "time": 684 | self.n_runs = 4 685 | logging.debug("Time split and cross validation.") 686 | elif self.settings["mode"] == "normal": 687 | self.n_runs = self.settings["n_runs"] 688 | else: 689 | logging.error("Training mode %s not a valid option.", self.settings["mode"]) 690 | 691 | average_valid_auc = 0 692 | best_test_auc = 0 693 | runs_range = range(1, self.n_runs + 1) 694 | # Main training loop 695 | for run_it in tqdm(runs_range): 696 | self.run_it = run_it 697 | reaction_graph = self.initialize_data(self.run_it) 698 | datasets, dataloaders = self.make_data_splits(reaction_graph) 699 | model, optimizer, lr_scheduler, emb = self.initialize_model(datasets) 700 | loss_func = BCEWithLogitsLoss( 701 | pos_weight=torch.tensor(self.settings["pos_weight_loss"]) 702 | ) 703 | 704 | if self.settings["model"] == "DGCNN": 705 | logging.debug( 706 | "DGCNN use k=%f for a sortpool_k of %f.", 707 | model.k, 708 | self.settings["sortpool_k"], 709 | ) 710 | 711 | best_in_run = { 712 | "Loss": {"Epoch": None, "Score": 100.0}, 713 | "AUC": {"Epoch": None, "Score": 0.0}, 714 | } 715 | 716 | for epoch in tqdm(range(1, self.settings["n_epochs"] + 1)): 717 | 718 | evaluation_train = self.train( 719 | model, dataloaders, datasets, loss_func, optimizer, emb 720 | ) 721 | self.update_scores(evaluation_train, "train", epoch, run_it, None) 722 | 723 | if ( 724 | len(datasets["valid"]) != 0 725 | ): # True when training is done on all edges 726 | evaluation_valid = self.evaluate( 727 | model, "valid", dataloaders, datasets, loss_func, emb 728 | ) 729 | self.update_scores(evaluation_valid, "valid", epoch, run_it, None) 730 | 731 | logging.info( 732 | "Epoch: %02d, Train Loss: %.4f, Valid Loss: %.4f, Valid AUC: %.4f", 733 | epoch, 734 | evaluation_train.loss, 735 | evaluation_valid.loss, 736 | evaluation_valid.auc, 737 | ) 738 | 739 | if epoch == 1: 740 | df_valid_set = pd.DataFrame( 741 | { 742 | "Source": [e[0] for e in evaluation_valid.links], 743 | "Target": [e[1] for e in evaluation_valid.links], 744 | "y true": evaluation_valid.y_true, 745 | } 746 | ) 747 | # Save validation set for each fold 748 | if self.settings['mode'] == 'cross_validation_5': 749 | df_valid_set.to_csv( 750 | osp.join(self.res_dir, "validation_set_valid-fold=" + str(self.settings['valid_fold']) + ".csv") 751 | ) 752 | 753 | elif run_it == 1: 754 | df_valid_set.to_csv( 755 | osp.join(self.res_dir, "validation_set.csv") 756 | ) 757 | 758 | # Save checkpoints if validation performance has improved 759 | if ( 760 | self.settings["include_in_train"] != "test" 761 | and self.settings["include_in_train"] != "valid" 762 | ): 763 | for metric, score in ( 764 | ("Loss", evaluation_valid.loss), 765 | ("AUC", evaluation_valid.auc), 766 | ): 767 | 768 | higher_score = score > best_in_run[metric]["Score"] 769 | better = ( 770 | not higher_score if metric == "Loss" else higher_score 771 | ) 772 | 773 | if ( 774 | better 775 | ): # and (self.settings['include_in_train'] != 'test' and 
self.settings['include_in_train'] != 'valid'): 776 | 777 | best_in_run[metric] = {"Epoch": epoch, "Score": score} 778 | 779 | save_torch_model( 780 | model, 781 | optimizer, 782 | emb, 783 | epoch, 784 | self.res_dir 785 | + f"/tmp_best_{metric}_model_checkpoint.pth", 786 | ) 787 | 788 | if self.settings["mode"] == "cross_validation_5": 789 | save_torch_model( 790 | model, 791 | optimizer, 792 | emb, 793 | epoch, 794 | self.res_dir 795 | + f"/best_{metric}_model_checkpoint_iteration={str(run_it)}.pth", 796 | ) 797 | 798 | higher_score = ( 799 | score > self.running_best[metric]["Score"] 800 | ) 801 | global_better = ( 802 | not higher_score 803 | if metric == "Loss" 804 | else higher_score 805 | ) 806 | 807 | if global_better: 808 | self.running_best[metric] = { 809 | "Run": run_it, 810 | "Epoch": epoch, 811 | "Score": score, 812 | } 813 | save_torch_model( 814 | model, 815 | optimizer, 816 | emb, 817 | epoch, 818 | self.res_dir 819 | + f"/best_{metric}_model_checkpoint.pth", 820 | ) 821 | 822 | # save model after last epoch for first run 823 | if run_it == 1: 824 | save_torch_model( 825 | model, 826 | optimizer, 827 | emb, 828 | epoch, 829 | self.res_dir + "/best_last_epoch_model_checkpoint.pth", 830 | ) 831 | lr_scheduler.step() 832 | average_valid_auc += best_in_run["AUC"]["Score"] 833 | 834 | if final_test and self.settings["include_in_train"] != "test": 835 | metric = "final_model" 836 | evaluation_final_test = self.evaluate( 837 | model, "test", dataloaders, datasets, loss_func, emb 838 | ) 839 | self.update_scores( 840 | evaluation_final_test, "final epoch test", epoch, run_it, metric 841 | ) 842 | df_preds = pd.DataFrame( 843 | { 844 | "Source": [e[0] for e in evaluation_final_test.links], 845 | "Target": [e[1] for e in evaluation_final_test.links], 846 | "y true": evaluation_final_test.y_true, 847 | "y pred": evaluation_final_test.y_prob, 848 | "run": run_it, 849 | "model": "final epoch", 850 | "based on": None, 851 | } 852 | ) 853 | self.predictions = self.predictions.append(df_preds) 854 | 855 | if running_test and ( 856 | self.settings["include_in_train"] != "test" 857 | and self.settings["include_in_train"] != "valid" 858 | ): 859 | # Test current runs best models 860 | for metric in self.running_best: 861 | metric_model_path = ( 862 | self.res_dir + f"/tmp_best_{metric}_model_checkpoint.pth" 863 | ) 864 | if osp.isfile(metric_model_path): 865 | model_checkpoint_path = metric_model_path 866 | checkpoint = torch.load(model_checkpoint_path) 867 | model.load_state_dict(checkpoint["model_state_dict"]) 868 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 869 | model_epoch = checkpoint["epoch"] 870 | if self.settings["use_embedding"]: 871 | emb.load_state_dict(checkpoint["embedding_state_dict"]) 872 | 873 | evaluation_test = self.evaluate( 874 | model, "test", dataloaders, datasets, loss_func, emb 875 | ) 876 | self.update_scores( 877 | evaluation_test, 878 | "best model test", 879 | model_epoch, 880 | run_it, 881 | metric, 882 | ) 883 | 884 | df_preds = pd.DataFrame( 885 | { 886 | "Source": [e[0] for e in evaluation_test.links], 887 | "Target": [e[1] for e in evaluation_test.links], 888 | "y true": evaluation_test.y_true, 889 | "y pred": evaluation_test.y_prob, 890 | "run": run_it, 891 | "model": "highest metric", 892 | "based on": [metric for y in evaluation_test.y_prob], 893 | } 894 | ) 895 | self.predictions = self.predictions.append(df_preds) 896 | 897 | for metric in best_in_run: 898 | logging.info( 899 | f"Best validation %s: %f, at epoch %d.", 900 | metric, 
901 |                     best_in_run[metric]["Score"],
902 |                     best_in_run[metric]["Epoch"],
903 |                 )
904 |             logging.info(
905 |                 "Finished run %d / %d of %d epochs.", run_it, self.n_runs, epoch
906 |             )
907 | 
908 |         average_valid_auc /= len(runs_range)
909 | 
910 |         if (running_test or final_test) and len(datasets["test"]) != 0:
911 |             self.scores["test"].to_csv(osp.join(self.res_dir, "test_scores.csv"))
912 | 
913 |         self.scores["running"].to_csv(osp.join(self.res_dir, "running_scores.csv"))
914 |         self.predictions.to_csv(osp.join(self.res_dir, "test_predictions.csv"))
915 |         plot_results(
916 |             final_test=final_test, running_test=running_test, path=self.res_dir, settings=self.settings
917 |         )
918 | 
919 |         logging.info("Finished all runs.")
920 |         logging.info(
921 |             "Overall best validation loss: %f", self.running_best["Loss"]["Score"]
922 |         )
923 |         logging.info(
924 |             "Overall best validation AUC: %f", self.running_best["AUC"]["Score"]
925 |         )
926 | 
927 |         for metric in self.running_best:
928 |             tmp_checkpoint = osp.join(
929 |                 self.res_dir, f"tmp_best_{metric}_model_checkpoint.pth"
930 |             )
931 |             print(tmp_checkpoint)
932 |             if osp.isfile(tmp_checkpoint):
933 |                 os.remove(tmp_checkpoint)
934 | 
935 |         end = time.time()
936 |         m, s = divmod(end - start, 60)
937 |         h, m = divmod(m, 60)
938 |         logging.info("Took h:m:s %d:%d:%d", h, m, s)
939 | 
940 |         return average_valid_auc
941 | 
942 | 
943 | # ------------------------- Functions -------------------------
944 | 
945 | 
946 | def add_score(scores_df, score):
947 |     """Helper function for updating scores."""
948 | 
949 |     tmp = pd.DataFrame({})
950 |     tmp["Run"] = [score.run_it]
951 |     tmp["Epoch"] = [score.epoch]
952 |     tmp["Score"] = [score.score]
953 |     tmp["Metric"] = [score.metric]
954 |     tmp["Split"] = [score.split]
955 | 
956 |     if scores_df is None:
957 |         scores_df = pd.DataFrame({})
958 | 
959 |     return scores_df.append(tmp, ignore_index=True)
960 | 
961 | 
962 | def save_torch_model(model, optimizer, emb, epoch, save_as):
963 |     """Helper function for saving model, optimizer and if provided embeddings."""
964 | 
965 |     if emb:
966 |         torch.save(
967 |             {
968 |                 "model_state_dict": model.state_dict(),
969 |                 "optimizer_state_dict": optimizer.state_dict(),
970 |                 "embedding_state_dict": emb.state_dict(),
971 |                 "epoch": epoch,
972 |             },
973 |             save_as,
974 |         )
975 |     else:
976 |         torch.save(
977 |             {
978 |                 "model_state_dict": model.state_dict(),
979 |                 "optimizer_state_dict": optimizer.state_dict(),
980 |                 "epoch": epoch,
981 |             },
982 |             save_as,
983 |         )
984 | 
985 | 
986 | def plot_calibration_curve(y_true, y_pred, label="Model"):
987 |     """Plot the calibration curve for the model."""
988 | 
989 |     fig = plt.figure(figsize=(10, 10))
990 |     ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
991 |     ax2 = plt.subplot2grid((3, 1), (2, 0))
992 | 
993 |     ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
994 | 
995 |     fraction_positives, mean_predicted_value = calibration_curve(
996 |         y_true=y_true, y_prob=y_pred, n_bins=20
997 |     )
998 |     ax1.plot(mean_predicted_value, fraction_positives, "s-", label=label)
999 | 
1000 |     ax2.hist(y_pred, range=(0, 1), bins=100, label=label, histtype="step", lw=2)
1001 | 
1002 |     ax1.set_ylabel("Fraction of positives")
1003 |     ax1.set_ylim([-0.05, 1.05])
1004 |     ax1.legend(loc="lower right")
1005 |     ax1.set_title("Calibration plots (reliability curve)")
1006 | 
1007 |     ax2.set_xlabel("Mean predicted value")
1008 |     ax2.set_ylabel("Count")
1009 |     ax2.legend(loc="upper center", ncol=2)
1010 | 
1011 |     plt.tight_layout()
1012 |     plt.show()
1013 | 
1014 |     return fig, fraction_positives,
mean_predicted_value 1015 | 1016 | 1017 | def plot_results(final_test, running_test, path=None, settings=None): 1018 | """Plot the results in various ways. 1019 | Args: 1020 | final_test (boolean): If True plot result on test set. 1021 | running_test (boolean): Was running_test used during runs. 1022 | path (string): Pathway to result directory. 1023 | """ 1024 | 1025 | scores = pd.read_csv(osp.join(path, "running_scores.csv")) 1026 | # Plot Loss 1027 | loss_scores = scores[scores["Metric"] == "Loss"] 1028 | fig = sns.relplot(data=loss_scores, x="Epoch", y="Score", hue="Split", kind="line") 1029 | fig.set(ylabel="Loss") 1030 | fig.savefig(osp.join(path, "loss.png")) 1031 | plt.close() 1032 | 1033 | # Plot Metrics 1034 | columns = ["AUC", "AP", "Accuracy", "F1", "Recall", "Precision"] 1035 | fig = sns.relplot( 1036 | data=scores, 1037 | x="Epoch", 1038 | y="Score", 1039 | hue="Split", 1040 | kind="line", 1041 | col="Metric", 1042 | col_wrap=3, 1043 | col_order=columns, 1044 | ) 1045 | fig.savefig(osp.join(path, "metrics.png")) 1046 | plt.close() 1047 | 1048 | if final_test: 1049 | # ROC curve 1050 | test_df = pd.read_csv(osp.join(path, "test_predictions.csv"), index_col=0) 1051 | 1052 | plt.figure() 1053 | 1054 | if running_test and ( 1055 | settings["include_in_train"] != "test" 1056 | and settings["include_in_train"] != "valid" 1057 | ): 1058 | test_df = test_df[test_df["model"] == "highest metric"] 1059 | test_df = test_df[test_df["based on"] == "AUC"] 1060 | y_true = test_df["y true"] 1061 | y_pred = test_df["y pred"] 1062 | fpr, tpr, _ = roc_curve(y_true, y_pred) 1063 | auc = roc_auc_score(y_true, y_pred) 1064 | 1065 | plt.plot(fpr, tpr, label=f"AUC={auc:.4f}") 1066 | plt.plot([0, 1], [0, 1], color="black", linestyle="--") 1067 | plt.xlabel("False Positive Rate") 1068 | plt.ylabel("True Positive Rate") 1069 | plt.legend() 1070 | plt.title("Receiver Operating Characteristic") 1071 | plt.savefig(osp.join(path, "roc.png")) 1072 | plt.close() 1073 | 1074 | # Distribution of predictions 1075 | predictions_true = test_df[test_df["y true"] == 1]["y pred"] 1076 | predictions_false = test_df[test_df["y true"] == 0]["y pred"] 1077 | 1078 | plt.figure() 1079 | plt.hist( 1080 | predictions_true, 1081 | histtype="step", 1082 | bins=50, 1083 | alpha=0.6, 1084 | color="green", 1085 | linewidth=2, 1086 | label="Positive Class", 1087 | ) 1088 | plt.hist( 1089 | predictions_false, 1090 | histtype="step", 1091 | alpha=0.6, 1092 | bins=50, 1093 | color="red", 1094 | linewidth=2, 1095 | label="Negative Class", 1096 | ) 1097 | 1098 | plt.xlabel("Prediction") 1099 | plt.ylabel("Count") 1100 | plt.title("Predicted probability for each class.") 1101 | plt.legend() 1102 | plt.savefig(osp.join(path, "dist_prediction.png")) 1103 | plt.close() 1104 | 1105 | # Make calibration curve plot 1106 | fig, _, _ = plot_calibration_curve(y_true, y_pred, label="SEAL Model") 1107 | fig.savefig(osp.join(path, "calibration_curve.png")) 1108 | 1109 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/reaction-graph-link-prediction/118acb3b4f2d9afe5c34a1a132c91ad1b8c021d5/utils/__init__.py -------------------------------------------------------------------------------- /utils/evaluate_model.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import pandas as pd 4 | import os.path as osp 5 | 6 | from 
torch_trainer import GraphTrainer 7 | 8 | 9 | def predict_links(model_dir, edges, graph_path="", num_workers=None): 10 | """Reload pre-trained model and test with specified links. 11 | Args: 12 | model_dir (str): pathway to directory where the trained model is saved 13 | edges (tensor): tensor with dim (2, n) of edges that will be tested 14 | """ 15 | 16 | # model dir path provided 17 | if osp.isfile(f"{model_dir}/settings.csv"): 18 | settings = pd.read_csv(f"{model_dir}/settings.csv") 19 | # model path provided 20 | elif osp.isfile(f'{"/".join(model_dir.split("/")[:-1])}/settings.csv'): 21 | settings = pd.read_csv(f'{"/".join(model_dir.split("/")[:-1])}/settings.csv') 22 | else: 23 | print("Wrong model_dir path provided:", model_dir) 24 | sys.exit() 25 | converted = [] 26 | for i in settings["Settings"]: 27 | try: 28 | i_ = eval(i) 29 | except: 30 | i_ = i 31 | converted.append(i_) 32 | settings["value"] = converted 33 | settings = dict(zip(settings["Unnamed: 0"], settings["value"])) 34 | 35 | # Set path to pre-trained model and name 36 | settings["pre_trained_model_path"] = model_dir 37 | settings["name"] = f"evaluate_model/{model_dir}" 38 | # Reset parameter so that test is not set to empty when doing make_data_split 39 | settings["include_in_train"] = None 40 | 41 | if num_workers != None: 42 | settings["num_workers"] = num_workers 43 | if graph_path: 44 | settings["graph_path"] = graph_path 45 | 46 | if not "seed" in settings.keys(): 47 | settings["seed"] = 1 48 | 49 | trainer = GraphTrainer(settings) 50 | eln_dataset = trainer.initialize_data(1) # settings['seed']) 51 | 52 | if edges is not None: 53 | # assign evenly as positive and negative edges to test set. 54 | # The labels has no importance in themselves here, but the order does. 55 | n_edges = edges.size(1) 56 | eln_dataset.data.split_edge["test"]["pos"] = edges[:, : int(n_edges / 2)] 57 | eln_dataset.data.split_edge["test"]["neg"] = edges[:, int(n_edges / 2) :] 58 | 59 | datasets, dataloaders = trainer.make_data_splits(eln_dataset) 60 | model, _, _, _ = trainer.initialize_model(datasets) 61 | 62 | y_prob, links = trainer.predict( 63 | model, datasets, dataloaders, "test", None 64 | ) 65 | 66 | return y_prob, links 67 | 68 | -------------------------------------------------------------------------------- /utils/evaluate_predictions.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | from sklearn.metrics import roc_auc_score, roc_curve 8 | 9 | import warnings 10 | 11 | with warnings.catch_warnings(): 12 | warnings.filterwarnings("ignore", category=RuntimeWarning) 13 | import graph_tool.all as gt 14 | 15 | 16 | # get edges as edge list 17 | def get_edges(edges_series): 18 | """From a panda series convert containing: '[id1, id2]' convert from edge to edge list""" 19 | 20 | edges = [] 21 | for edge in edges_series: 22 | tmp = edge.strip("][") 23 | tmp = list(map(int, tmp.split(","))) 24 | edges.append(tmp) 25 | 26 | return edges 27 | 28 | 29 | # get edges as edge list 30 | def get_edges(edges_series): 31 | 32 | edges = [] 33 | for edge in edges_series: 34 | tmp = edge.strip("][") 35 | tmp = list(map(int, tmp.split(","))) 36 | edges.append(tmp) 37 | return edges 38 | 39 | 40 | def create_id_csv(graph, file_path): 41 | """Creates a csv file where each gt id is mapped to its neo4j id""" 42 | 43 | neo4j_id = [] 44 | gt_id = [] 45 | 46 | for v in graph.vertices(): 47 | 
gt_id.append(int(v)) 48 | neo4j_id.append(int(graph.vertex_properties["_graphml_vertex_id"][v][1:])) 49 | 50 | nodes_labels_dict = {"renamed_ID": gt_id, "neo4j_ID": neo4j_id} 51 | nodes_labels_dict = pd.DataFrame(nodes_labels_dict) 52 | 53 | if file_path: 54 | nodes_labels_dict.to_csv(file_path, index=False) 55 | print("Saved!") 56 | 57 | return nodes_labels_dict 58 | 59 | 60 | # --------- Find Optimal Threshold --------- 61 | 62 | 63 | def find_optimal_threshold(roc_curve, plot=True, title=None, save_name=None): 64 | """Find optimal threshold based on the roc curve.""" 65 | 66 | fpr, tpr, threshold = roc_curve 67 | if plot: 68 | fig, ax = plt.subplots(1, figsize=(6, 6)) 69 | ax.spines["right"].set_visible(False) 70 | ax.spines["top"].set_visible(False) 71 | ax.plot(threshold, 1 - fpr, label="1 - FPR", zorder=1) 72 | ax.plot(threshold, tpr, label="TPR", zorder=2) 73 | 74 | where = np.where(np.round(1 - fpr, 2) == np.round(tpr, 2)) 75 | intersect = where[0][int(len(where[0]) / 2)] 76 | ax.scatter( 77 | threshold[intersect], 78 | tpr[intersect], 79 | c="black", 80 | marker="*", 81 | s=100, 82 | label="Optimal Threshold", 83 | zorder=3, 84 | ) 85 | 86 | plt.xlabel("Threshold") 87 | plt.title(f"Optimal Threshold: {title}") 88 | plt.xlim(0, 1) 89 | plt.legend() 90 | 91 | if save_name: 92 | plt.savefig( 93 | f"figures/cumulative_prediction/{save_name}_1.png", 94 | dpi=300, 95 | ) 96 | plt.show() 97 | else: 98 | plt.show() 99 | 100 | print("Optimal thereshold is around:", threshold[intersect]) 101 | 102 | if plot: 103 | fig, ax = plt.subplots(1, figsize=(7.5, 6)) 104 | ax.spines["right"].set_visible(False) 105 | ax.spines["top"].set_visible(False) 106 | plt.scatter( 107 | fpr, 108 | tpr, 109 | c=np.round(threshold, 3), 110 | s=6, 111 | cmap=sns.color_palette("Spectral", as_cmap=True), 112 | ) # )'cool') 113 | plt.clim(0, 1) 114 | cbar = plt.colorbar() 115 | cbar.ax.set_ylabel("Threshold", rotation=-90, va="bottom") 116 | 117 | plt.scatter( 118 | fpr[intersect], 119 | tpr[intersect], 120 | c="black", 121 | marker="*", 122 | s=100, 123 | label="Optimal Threshold", 124 | ) 125 | 126 | plt.title(f"ROC Curve: {title}") 127 | plt.ylabel("TPR") 128 | plt.xlabel("FPR") 129 | plt.legend() 130 | 131 | if save_name: 132 | plt.savefig( 133 | f"figures/cumulative_prediction/{save_name}_2.png", 134 | dpi=300, 135 | ) 136 | else: 137 | plt.show() 138 | 139 | return threshold[intersect] 140 | 141 | 142 | # --------- Percentage of Predictions below Prediction against Predictions --------- 143 | 144 | 145 | def plot_cumulative_predictions( 146 | predictions_dict, threshold, title=None, colors=None, save_name=None 147 | ): 148 | 149 | fig, ax = plt.subplots(1, figsize=(8, 8)) 150 | ax.spines["right"].set_visible(False) 151 | ax.spines["top"].set_visible(False) 152 | 153 | if colors == None: 154 | colors = [f"C{i}" for i in range(len(predictions_dict))] 155 | print(colors) 156 | 157 | i = 0 158 | for name, predictions in predictions_dict.items(): 159 | preds = np.sort(predictions) 160 | sum_preds = [np.sum(preds < p) / len(preds) for p in preds] 161 | plt.plot(preds, sum_preds, linewidth=3, label=name, c=colors[i]) 162 | i += 1 163 | 164 | if threshold: 165 | plt.plot( 166 | [threshold, threshold], 167 | [0, 1], 168 | linewidth=2, 169 | linestyle="dashed", 170 | c="grey", 171 | label=f"Optimal Threshold: {np.round(threshold,2)}", 172 | ) 173 | 174 | plt.xlabel("Prediction, p") 175 | plt.ylabel(" % predictions < p") 176 | plt.title(f"{title}") 177 | plt.legend() 178 | 179 | if save_name: 180 | plt.savefig( 181 | 
f"figures/cumulative_prediction/{save_name}.png", 182 | dpi=300, 183 | ) 184 | else: 185 | plt.show() 186 | 187 | 188 | def plot_metrics(prediction_df): 189 | results = {"Random": {}, "Distributed": {}, "Structured": {}} 190 | # ROC curve 191 | q = int(len(prediction_df) / 4) 192 | ranges = {} 193 | ranges["Random"] = list(range(2 * q)) 194 | r = list(range(q)) 195 | r.extend(list(range(2 * q, 3 * q))) 196 | ranges["Distributed"] = r 197 | r = list(range(q)) 198 | r.extend(list(range(3 * q, 4 * q))) 199 | ranges["Structured"] = r 200 | 201 | colors = {"Random": "C1", "Distributed": "C2", "Structured": "C3"} 202 | plt.figure() 203 | for sampling in ["Random", "Distributed", "Structured"]: 204 | y_true = prediction_df["True"][ranges[sampling]] 205 | y_pred = prediction_df["Best AUC Preds"][ranges[sampling]] 206 | fpr, tpr, roc_thresholds = roc_curve(y_true, y_pred) 207 | results[sampling]["AUC"] = roc_auc_score(y_true, y_pred) 208 | # if sampling != 'Structured': 209 | plt.plot( 210 | fpr, 211 | tpr, 212 | color=colors[sampling], 213 | label=f"{sampling}, AUC={results[sampling]['AUC']:.4f}", 214 | ) 215 | y_pred_neg = torch.tensor(y_pred[y_true == 0].tolist()) 216 | y_pred_pos = torch.tensor(y_pred[y_true == 1].tolist()) 217 | # HITS @ k 218 | for k in [20, 50, 100]: 219 | kth_score_in_negative_edges = torch.topk(y_pred_neg, k)[0][-1] 220 | results[sampling][f"HITS@{k}"] = float( 221 | torch.sum(y_pred_pos > kth_score_in_negative_edges).cpu() 222 | ) / len(y_pred_pos) 223 | 224 | display(pd.DataFrame(results)) 225 | plt.plot([0, 1], [0, 1], color="black", linestyle="--") 226 | plt.xlabel("False Positive Rate") 227 | plt.ylabel("True Positive Rate") 228 | plt.legend() 229 | plt.title("Receiver Operating Characteristic") 230 | # plt.savefig(osp.join(path, f'roc.png')) 231 | plt.show() 232 | 233 | 234 | def plot_prediction_distribution(prediction_df): 235 | results = {"Random": {}, "Distributed": {}, "Structured": {}} 236 | # ROC curve 237 | q = int(len(prediction_df) / 4) 238 | ranges = {} 239 | ranges["Random"] = list(range(2 * q)) 240 | r = list(range(q)) 241 | r.extend(list(range(2 * q, 3 * q))) 242 | ranges["Distributed"] = r 243 | r = list(range(q)) 244 | r.extend(list(range(3 * q, 4 * q))) 245 | ranges["Structured"] = r 246 | 247 | colors = {"Random": "C1", "Distributed": "C2", "Structured": "C3"} 248 | 249 | # Distribution of predictions 250 | predictions_true = prediction_df[prediction_df["True"] == 1]["Best AUC Preds"] 251 | predictions_false = prediction_df[prediction_df["True"] == 0]["Best AUC Preds"] 252 | r = int(len(predictions_false) / 3) 253 | plt.figure() 254 | plt.hist( 255 | predictions_true, 256 | histtype="step", 257 | bins=50, 258 | alpha=0.6, 259 | color="C0", 260 | linewidth=2, 261 | label="Positive Class", 262 | ) 263 | plt.hist( 264 | predictions_false[:r], 265 | histtype="step", 266 | alpha=0.6, 267 | bins=50, 268 | color=colors["Random"], 269 | linewidth=2, 270 | label="Random Negative Class", 271 | ) 272 | plt.hist( 273 | predictions_false[r : 2 * r], 274 | histtype="step", 275 | alpha=0.6, 276 | bins=50, 277 | color=colors["Distributed"], 278 | linewidth=2, 279 | label="Distributed Negative Class", 280 | ) 281 | plt.hist( 282 | predictions_false[2 * r :], 283 | histtype="step", 284 | alpha=0.6, 285 | bins=50, 286 | color=colors["Structured"], 287 | linewidth=2, 288 | label="Structured Negative Class", 289 | ) 290 | plt.xlabel("Prediction") 291 | plt.ylabel("Count") 292 | plt.title("Predicted probability for each class.") 293 | plt.legend() 294 | # 
plt.savefig(osp.join(path, f'dist_prediction.png'))
295 |     plt.show()
296 | 
297 |     results_df = pd.DataFrame.from_dict(results, orient="index")
298 |     # results_df.to_csv(osp.join(path, f'test_metrics.csv'))
299 | 
300 | 
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import pandas as pd
4 | 
5 | 
6 | def hitsK(y_true, y_pred, k):
7 |     """Calculate the Hits@K. Based on definition in:
8 |     M. Ali et al., "Bringing Light Into the Dark: A Large-scale Evaluation of Knowledge Graph
9 |     Embedding Models under a Unified Framework"
10 |     doi: 10.1109/TPAMI.2021.3124805.
11 | 
12 |     Args:
13 |         y_true (1D tensor): Tensor with true labels.
14 |         y_pred (1D tensor): Tensor with predictions.
15 |             Must be index-aligned with y_true.
16 |         k (int): Calculate hits based on the top k predictions.
17 | 
18 |     Return:
19 |         hits_k_score (float): Hits@K score.
20 |     """
21 | 
22 |     y_pred_neg = torch.tensor(y_pred[y_true == 0].tolist())
23 |     y_pred_pos = torch.tensor(y_pred[y_true == 1].tolist())
24 | 
25 |     top_k_prob, _ = torch.topk(y_pred_neg, k)
26 |     k_prob = float(top_k_prob[-1])
27 | 
28 |     hits_k_score = float(torch.sum(y_pred_pos > k_prob) / len(y_pred_pos))
29 | 
30 |     return hits_k_score
31 | 
32 | 
33 | def mean_average_precision(y_true, y_pred, edges):
34 |     """Calculate the mean of the average precision score. Based on definition in:
35 |     M. Ali et al., "Bringing Light Into the Dark: A Large-scale Evaluation of Knowledge Graph
36 |     Embedding Models under a Unified Framework"
37 |     doi: 10.1109/TPAMI.2021.3124805.
38 | 
39 |     Args:
40 |         y_true (iterable): Ground truth for each edge.
41 |         y_pred (iterable): Prediction for each edge.
42 |         edges (2D tensor or tuple with 2 1D tensors): Test edges.
43 | 
44 |     Return:
45 |         mean_ap (float): Mean average precision score.
46 |         n_valid_nodes (int): Number of nodes included in the mean.
47 |         len(set_nodes): Total number of unique nodes in the edges.
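
    Example (illustrative sketch only; the tensors and values below are made up,
    not part of the original test data):

        edges = torch.tensor([[0, 0, 1, 1], [2, 3, 2, 3]])
        mean_ap, n_valid, n_total = mean_average_precision(
            y_true=[1, 0, 1, 0], y_pred=[0.9, 0.2, 0.7, 0.4], edges=edges
        )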
48 | """ 49 | if len(edges) == 2: 50 | node_1, node_2 = edges[0], edges[1] 51 | else: 52 | node_1 = [e[0] for e in edges] 53 | node_2 = [e[1] for e in edges] 54 | 55 | df = pd.DataFrame( 56 | {"y true": y_true, "y pred": y_pred, "node 1": node_1, "node 2": node_2} 57 | ) 58 | 59 | df = df.sort_values(by="y pred", ascending=False, ignore_index=True) 60 | 61 | all_nodes = [int(n) for n in node_1] 62 | all_nodes.extend([int(n) for n in node_2]) 63 | 64 | set_nodes = [int(n) for n in set(all_nodes)] 65 | mean_ap = 0 66 | ap = 0 67 | count_included_nodes = 0 68 | for n in set_nodes: 69 | df_tmp = df[df["node 1"] == n].append(df[df["node 2"] == n]).drop_duplicates() 70 | df_tmp = df_tmp.sort_index() 71 | 72 | y_true = df_tmp["y true"].values 73 | if 0 in y_true and 1 in y_true: 74 | index_true = df_tmp[df_tmp["y true"] == 1].index 75 | for i in index_true: 76 | ap = 0 77 | df_tmp_k = df_tmp.loc[:i] 78 | ap += len(df_tmp_k[df_tmp_k["y true"] == 1].values) / len( 79 | df_tmp_k["y true"].values 80 | ) 81 | mean_ap += ap 82 | else: 83 | count_included_nodes += 1 84 | 85 | n_valid_nodes = len(set_nodes) - count_included_nodes 86 | if n_valid_nodes > 0: 87 | mean_ap /= n_valid_nodes 88 | else: 89 | mean_ap = None 90 | 91 | return mean_ap, n_valid_nodes, len(set_nodes) 92 | 93 | -------------------------------------------------------------------------------- /utils/negative_sampling.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import logging 5 | import warnings 6 | import numpy as np 7 | from tqdm import tqdm 8 | from sys import maxsize 9 | from os.path import isfile 10 | from collections import Counter 11 | 12 | with warnings.catch_warnings(): 13 | warnings.filterwarnings("ignore", category=RuntimeWarning) 14 | import graph_tool.all as gt 15 | 16 | 17 | def correct_overlaps(pos_edge_index, neg_edge_index, num_nodes, seed): 18 | """Checks for overlap between pos_edge_index and neg_edge_index and replaces these, 19 | while considering the edges as undirected. 20 | 21 | Args: 22 | pos_edge_index (2D tensor or tuple with 2 1D tesor): Used to compare neg_edge_index to 23 | neg_edge_index (2D tensor or tuple with 2 1D tesor): Used to check if any overlap with 24 | pos_edge_index 25 | num_nodes (int): Number of nodes in the graph. 26 | seed (int): Randoms seed used by torch. 27 | 28 | Returns: 29 | neg_edge_index_1 (tensor): neg_edge_index_1, same as input 30 | neg_edge_index_2 (tensor): Updated version of negative_edge_index so no overlaps 31 | exists with pos_edge_index. 32 | """ 33 | 34 | torch.manual_seed(seed) 35 | pos_edge_index_1, pos_edge_index_2 = pos_edge_index 36 | neg_edge_index_1, neg_edge_index_2 = neg_edge_index 37 | 38 | all_nodes = torch.arange(num_nodes) 39 | 40 | # Creating idx_1 (unique identifier) for both edge directions 41 | idx_1 = torch.cat( 42 | ( 43 | pos_edge_index_1 * num_nodes + pos_edge_index_2, 44 | pos_edge_index_2 * num_nodes + pos_edge_index_1, 45 | all_nodes * num_nodes + all_nodes, 46 | ) 47 | ) 48 | 49 | idx_2 = neg_edge_index_1 * num_nodes + neg_edge_index_2 50 | 51 | mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool) 52 | rest = mask.nonzero(as_tuple=False).view(-1) 53 | 54 | logging.warning( 55 | "%d overlaps have been found. 
\n Comparing: %d edges with %d negative edges.", 56 | rest.numel(), 57 | pos_edge_index_1.size(0), 58 | neg_edge_index_1.size(0), 59 | ) 60 | 61 | while rest.numel() > 0: # pragma: no cover 62 | tmp = torch.randint(num_nodes, (rest.numel(),), dtype=torch.long) 63 | idx_2 = neg_edge_index_1[rest] * num_nodes + tmp 64 | mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool) 65 | neg_edge_index_2[rest] = tmp 66 | rest = rest[mask.nonzero(as_tuple=False).view(-1)] 67 | 68 | return neg_edge_index_1, neg_edge_index_2 69 | 70 | 71 | def remove_overlaps(corrupted_edge_index, pos_edge_index): 72 | """Checks for overlap between corrupted_edge_index and pos_edge_index and removes these, 73 | while considering the edges as undirected. 74 | 75 | Args: 76 | pos_edge_index (2D tensor or tuple with 2 1D tesor): Used to compare neg_edge_index to 77 | corrupted_edge_index (2D tensor or tuple with 2 1D tesor): Used to check if any overlap with 78 | pos_edge_index 79 | 80 | Returns: 81 | neg_edge_index_1 (tensor): Updated version of corrupted_edge_index with any overlaps 82 | with pos_edge_index removed. 83 | neg_edge_index_2 (tensor): Updated version of corrupted_edge_index with any overlaps 84 | with pos_edge_index removed. 85 | """ 86 | 87 | pos_edge_index_1, pos_edge_index_2 = pos_edge_index 88 | neg_edge_index_1, neg_edge_index_2 = corrupted_edge_index 89 | 90 | all_nodes = [int(i) for i in pos_edge_index_1] 91 | all_nodes.extend([int(i) for i in pos_edge_index_2]) 92 | all_nodes = torch.tensor(list(set(all_nodes))) 93 | 94 | num_nodes = max(all_nodes) 95 | 96 | # Creating idx_1 (unique identifier) for both edge dirctions 97 | idx_1 = torch.cat( 98 | ( 99 | pos_edge_index_1 * num_nodes + pos_edge_index_2, 100 | pos_edge_index_2 * num_nodes + pos_edge_index_1, 101 | all_nodes * num_nodes + all_nodes, 102 | ) 103 | ) 104 | 105 | idx_2 = neg_edge_index_1 * num_nodes + neg_edge_index_2 106 | 107 | mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool) 108 | 109 | neg_edge_index_1 = neg_edge_index_1[~mask] 110 | neg_edge_index_2 = neg_edge_index_2[~mask] 111 | 112 | return neg_edge_index_1, neg_edge_index_2 113 | 114 | 115 | # ---------------------------- Negative Sampling Methods ---------------------------- 116 | def sample_analogs(edge_index, n, num_nodes, all_pos_edges, seed): 117 | """Sample negative edges from the distribution of nodes in the positive edges 118 | where positive set node degrees are preserved 119 | 120 | Args: 121 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 122 | n (int): How many edges to sample: n * number of edges in edge_index negative edges. 123 | num_nodes (int): Number of nodes in the graph. 124 | all_pos_edges: All edges in the graph. 125 | seed (int): Random seed for torch. 126 | 127 | Return: 128 | i_neg (1D tensors): Sampled negative edge sources. 129 | j_neg (1D tensors): Sampled negative edge targets. 
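
    Example (illustrative only; assumes the pre-computed analogs file under data/
    exists and that edge_index / all_pos_edges are 2 x N edge tensors):

        i_neg, j_neg = sample_analogs(
            edge_index, n=edge_index.size(1), num_nodes=num_nodes,
            all_pos_edges=all_pos_edges, seed=1,
        )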
130 | """ 131 | 132 | torch.manual_seed(seed) 133 | 134 | # check if degree distribution already exists for a given seed 135 | negative_analogs_file = ( 136 | "data/negatives=analogs_USPTO_1_fingerprints_single_NameRxn3.2.pt" 137 | ) 138 | 139 | if isfile(negative_analogs_file): 140 | neg_edges = torch.load(negative_analogs_file) 141 | logging.info( 142 | "Loading pre-computed negative links of 2-nodes away nodes file %s", 143 | negative_analogs_file, 144 | ) 145 | else: 146 | sys.exit() 147 | 148 | neg_edges = neg_edges[:, torch.randperm(neg_edges.shape[1])[:n]] 149 | 150 | if neg_edges.shape[1] < n: 151 | logging.warning( 152 | "To few unique sampled negative edges by sample_distribution. \ 153 | Fill with random edges." 154 | ) 155 | # Fill with randomly sampled edges 156 | n_rand = n - neg_edges.shape[1] 157 | i_neg_rand, j_neg_rand = sample_random( 158 | edge_index, n_rand, int(1.5 * num_nodes), all_pos_edges, seed=seed 159 | ) 160 | ij_rand = torch.stack((i_neg_rand, j_neg_rand), dim=0) 161 | # cat and remove repetetive edges - does not replace the removed ones! 162 | neg_edges = torch.cat((neg_edges, ij_rand), dim=1).unique(dim=1) 163 | 164 | i_neg, j_neg = neg_edges[:, :n] 165 | 166 | return i_neg, j_neg 167 | 168 | 169 | def sample_degree_preserving_distribution( 170 | negative_degree_preserving_distribution_file, 171 | edge_index, 172 | n, 173 | num_nodes, 174 | all_pos_edges, 175 | seed, 176 | ): 177 | """Sample negative edges from the distribution of nodes in the positive edges 178 | where positive set node degrees are preserved 179 | 180 | Args: 181 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 182 | n (int): How many edges to sample: n * number of edges in edge_index negative edges. 183 | num_nodes (int): Number of nodes in the graph. 184 | all_pos_edges: All edges in the graph. 185 | seed (int): Random seed for torch. 186 | 187 | Return: 188 | i_neg (1D tensors): Sampled negative edge sources. 189 | j_neg (1D tensors): Sampled negative edge targets. 190 | """ 191 | 192 | torch.manual_seed(seed) 193 | 194 | # check if degree distribution already exists 195 | if isfile(negative_degree_preserving_distribution_file): 196 | neg_edges = torch.load(negative_degree_preserving_distribution_file) 197 | logging.info( 198 | "Loading pre-computed degree-preserving distribution negative links with a fixed seed %d from %s'.", 199 | seed, 200 | negative_degree_preserving_distribution_file, 201 | ) 202 | else: 203 | print("Generating negative links preserving degree distribution. 
Take a seat.") 204 | 205 | i_pos, j_pos = edge_index 206 | ij = torch.cat((i_pos, j_pos)) 207 | all_nodes_in_pos_edges = ij.tolist() 208 | 209 | # loop descending by node popularity to decrease chance of insufficient unique link pair nodes available for popular nodes 210 | # NOTE: this leads overesimated predicted duration of sampling by tqdm bar 211 | counter = Counter(all_nodes_in_pos_edges).most_common() 212 | 213 | source_nodes, target_nodes = [], [] 214 | i = 0 215 | pbar = tqdm(total=len(counter)) 216 | while i < len(counter): 217 | source = counter[i][0] 218 | degree = counter[i][1] 219 | target = maxsize 220 | 221 | # get positive head or tail partners of node in question 222 | j_partners = j_pos[i_pos == source] 223 | i_partners = i_pos[j_pos == source] 224 | 225 | # make list of forbidden partner nodes which includes partner nodes of node in question as well as itself 226 | nodes_for_exclusion = set( 227 | torch.cat((i_partners, j_partners, torch.tensor([source]))).tolist() 228 | ) 229 | 230 | while degree > 0: 231 | all_nodes_in_pos_edges.remove(source) 232 | 233 | # add target nodes already used for source node in question 234 | nodes_for_exclusion.add(target) 235 | 236 | # make list for sampling which does not include any forbidden nodes 237 | nodes_for_sampling = [ 238 | item 239 | for item in all_nodes_in_pos_edges 240 | if item not in nodes_for_exclusion 241 | ] 242 | 243 | if len(nodes_for_sampling) > 0: 244 | target = nodes_for_sampling[ 245 | torch.randint(0, len(nodes_for_sampling), size=(1,)) 246 | ] 247 | all_nodes_in_pos_edges.remove(target) 248 | source_nodes.append(source) 249 | target_nodes.append(target) 250 | 251 | # reduce counter for target node found 252 | k = 0 253 | while k < len(counter): 254 | if counter[k][0] == target: 255 | counter[k] = tuple([counter[k][0], counter[k][1] - 1]) 256 | break 257 | k += 1 258 | degree -= 1 259 | i += 1 260 | pbar.update(1) 261 | pbar.close() 262 | 263 | i_neg = torch.tensor(source_nodes) 264 | j_neg = torch.tensor(target_nodes) 265 | 266 | i_neg, j_neg = correct_overlaps(all_pos_edges, (i_neg, j_neg), num_nodes, seed) 267 | neg_edges = torch.stack((i_neg, j_neg), dim=0) 268 | 269 | old_size = len(neg_edges[0]) 270 | neg_edges = neg_edges.unique(dim=1) 271 | logging.info( 272 | "%d (%f) duplicate edges from distribution were removed.", 273 | old_size - len(neg_edges[0]), 274 | (old_size - len(neg_edges[0])) / old_size, 275 | ) 276 | torch.save(neg_edges, negative_degree_preserving_distribution_file) 277 | 278 | neg_edges = neg_edges[:, torch.randperm(neg_edges.shape[1])[:n]] 279 | 280 | if neg_edges.shape[1] < n: 281 | logging.warning( 282 | "To few unique sampled negative edges by sample_distribution. \ 283 | Fill with random edges." 284 | ) 285 | # Fill with randomly sampled edges 286 | n_rand = n - neg_edges.shape[1] 287 | i_neg_rand, j_neg_rand = sample_random( 288 | edge_index, n_rand, int(1.5 * num_nodes), all_pos_edges, seed=seed 289 | ) 290 | ij_rand = torch.stack((i_neg_rand, j_neg_rand), dim=0) 291 | # cat and remove repetetive edges - does not replace the removed ones! 292 | neg_edges = torch.cat((neg_edges, ij_rand), dim=1).unique(dim=1) 293 | 294 | i_neg, j_neg = neg_edges[:, :n] 295 | 296 | return i_neg, j_neg 297 | 298 | 299 | def sample_distribution(edge_index, n, num_nodes, all_pos_edges, seed): 300 | """Sample negative edges from the distribution of nodes in the positive edges. 301 | 302 | Args: 303 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 
304 | n (int): How many edges to sample: n * number of edges in edge_index negative edges. 305 | num_nodes (int): Number of nodes in the graph. 306 | all_pos_edges: All edges in the graph. 307 | seed (int): Random seed for torch. 308 | 309 | Return: 310 | i_neg (1D tensors): Sampled negative edge sources. 311 | j_neg (1D tensors): Sampled negative edge targets. 312 | """ 313 | 314 | torch.manual_seed(seed) 315 | 316 | i_pos, j_pos = edge_index 317 | ij = torch.cat((i_pos, j_pos)) 318 | 319 | source, target = [], [] 320 | x = int(np.ceil(n / len(i_pos))) 321 | for _ in range(x + 1): 322 | ij = ij[torch.randperm(len(ij))] 323 | source.append(ij[0 : len(i_pos)]) 324 | target.append(ij[len(i_pos) :]) 325 | 326 | i_neg = torch.cat(source) 327 | j_neg = torch.cat(target) 328 | 329 | i_neg, j_neg = correct_overlaps(all_pos_edges, (i_neg, j_neg), num_nodes, seed) 330 | neg_edges = torch.stack((i_neg, j_neg), dim=0) 331 | old_size = len(neg_edges[0]) 332 | neg_edges = neg_edges.unique(dim=1) 333 | logging.info( 334 | "%d (%f) duplicate edges from distribution were removed.", 335 | old_size - len(neg_edges[0]), 336 | (old_size - len(neg_edges[0])) / old_size, 337 | ) 338 | neg_edges = neg_edges[:, torch.randperm(neg_edges.shape[1])[:n]] 339 | 340 | if neg_edges.shape[1] < n: 341 | logging.warning( 342 | "To few unique sampled negative edges by sample_distribution. \ 343 | Fill with random edges." 344 | ) 345 | # Fill with randomly sampled edges 346 | n_rand = n - neg_edges.shape[1] 347 | i_neg_rand, j_neg_rand = sample_random( 348 | edge_index, n_rand, int(1.5 * num_nodes), all_pos_edges, seed=seed 349 | ) 350 | ij_rand = torch.stack((i_neg_rand, j_neg_rand), dim=0) 351 | # cat and remove repetetive edges - does not replace the removed ones! 352 | neg_edges = torch.cat((neg_edges, ij_rand), dim=1).unique(dim=1) 353 | 354 | i_neg, j_neg = neg_edges[:, :n] 355 | 356 | return i_neg, j_neg 357 | 358 | 359 | def sample_random(edge_index, n, num_nodes, all_pos_edges, seed): 360 | """Sample negative edges, both the source and target, at random. 361 | 362 | Args: 363 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 364 | n (int): How many edges to sample: n * number of edges in edge_index negative edges. 365 | num_nodes (int): Number of nodes in the graph. 366 | all_pos_edges: All edges in the graph. 367 | seed (int): Random seed for torch. 368 | 369 | Return: 370 | i_neg (1D tensors): Sampled negative edge sources. 371 | j_neg (1D tensors): Sampled negative edge targets. 
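
    Example (illustrative only; the variable names are placeholders):

        i_neg, j_neg = sample_random(
            train_edges, n=train_edges.size(1), num_nodes=num_nodes,
            all_pos_edges=all_pos_edges, seed=0,
        )
        neg_edge_index = torch.stack((i_neg, j_neg), dim=0)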
372 | """ 373 | 374 | torch.manual_seed(seed) 375 | 376 | all_nodes = torch.cat((edge_index[0], edge_index[1])).flatten().unique() 377 | 378 | i_index = torch.randint(len(all_nodes), (int(1.5 * n),), dtype=torch.long) 379 | j_index = torch.randint(len(all_nodes), (int(1.5 * n),), dtype=torch.long) 380 | 381 | i = all_nodes[i_index] 382 | j = all_nodes[j_index] 383 | 384 | i, j = correct_overlaps(all_pos_edges, (i, j), num_nodes, seed) 385 | 386 | neg_edges = torch.stack((i, j), dim=0) 387 | 388 | neg_edges = neg_edges.unique(dim=1) 389 | 390 | neg_edges = neg_edges[:, torch.randperm(neg_edges.shape[1])[:n]] 391 | 392 | if neg_edges.shape[1] < n: 393 | logging.warning("To few unique sampled negative edges by sample_random.") 394 | i, j = neg_edges 395 | else: 396 | i, j = neg_edges[:, :n] 397 | 398 | return i, j 399 | 400 | 401 | def one_against_all(nodes, edge_index, all_pos_edges, include_unconnected=False): 402 | """Sample negative edges, by keeping a fixed target node and sampling all possible source nodes. 403 | 404 | Args: 405 | nodes (iterable): Fixed reactant 406 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 407 | all_pos_edges (2D tensor or tuple with 2 1D tesor): All edges in the graph. 408 | 409 | Return: 410 | neg_edges (2D tensors): Negative edges. 411 | """ 412 | all_nodes_in_edges = torch.cat((edge_index[0], edge_index[1])).flatten().unique() 413 | if include_unconnected: 414 | all_nodes = torch.arange(1, max(all_nodes_in_edges)) 415 | else: 416 | all_nodes = all_nodes_in_edges 417 | print("num unique nodes", len(all_nodes.unique())) 418 | i_neg = torch.tensor([]) 419 | j_neg = torch.tensor([]) 420 | for node in nodes: 421 | node = int(node) 422 | i_neg_node = all_nodes[all_nodes != node] 423 | j_neg_node = torch.tensor([node for _ in i_neg_node.flatten()]) 424 | 425 | if node in all_nodes: 426 | i_neg_node, j_neg_node = remove_overlaps( 427 | (i_neg_node, j_neg_node), all_pos_edges 428 | ) 429 | else: 430 | print(f"Node {node} not in reactants.") 431 | 432 | print("num unique nodes", len(all_nodes.unique())) 433 | i_neg = torch.cat((i_neg, i_neg_node)) 434 | j_neg = torch.cat((j_neg, j_neg_node)) 435 | 436 | neg_edges = torch.stack((i_neg, j_neg), dim=0) 437 | 438 | return neg_edges 439 | 440 | 441 | def one_against_most_reactive(nodes, edge_index, all_pos_edges, cutoff=2): 442 | """Sample negative edges, by keeping a fixed target node and sampling all possible source nodes. 443 | 444 | Args: 445 | nodes (iterable): Fixed reactant 446 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 447 | all_pos_edges (2D tensor or tuple with 2 1D tesor): All edges in the graph. 448 | 449 | Return: 450 | neg_edges (2D tensors): Negative edges. 
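
    Note: the `cutoff` argument is currently not used; the implementation keeps
    reactant nodes that occur more than 5 times in edge_index.

    Example (illustrative only; 42 stands in for a real molecule node index):

        neg_edges = one_against_most_reactive(
            nodes=[42], edge_index=train_edges, all_pos_edges=all_pos_edges
        )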
451 | """ 452 | all_nodes = torch.cat((edge_index[0], edge_index[1])).flatten().tolist() 453 | print("len(all_nodes)", len(all_nodes), "max ", max(all_nodes)) 454 | 455 | count_reactants = Counter(all_nodes) 456 | reactive_reactants = [item for item in set(all_nodes) if count_reactants[item] > 5] 457 | reactive_reactants = torch.tensor(list(set(reactive_reactants))) 458 | print("num unique nodes", len(reactive_reactants)) 459 | 460 | i_neg = torch.tensor([]) 461 | j_neg = torch.tensor([]) 462 | for node in nodes: 463 | node = int(node) 464 | i_neg_node = reactive_reactants[reactive_reactants != node] 465 | j_neg_node = torch.tensor([node for _ in i_neg_node.flatten()]) 466 | 467 | if node in all_nodes: 468 | i_neg_node, j_neg_node = remove_overlaps( 469 | (i_neg_node, j_neg_node), all_pos_edges 470 | ) 471 | else: 472 | print(f"Node {node} not in reactants.") 473 | 474 | # print('num unique nodes', len(reactive_reactants)) 475 | i_neg = torch.cat((i_neg, i_neg_node)) 476 | j_neg = torch.cat((j_neg, j_neg_node)) 477 | 478 | neg_edges = torch.stack((i_neg, j_neg), dim=0) 479 | 480 | return neg_edges 481 | 482 | 483 | def all_against_all(nodes, edge_index, all_pos_edges, include_unconnected=False): 484 | """Sample negative edges, by keeping a fixed target node and sampling all possible source nodes. 485 | 486 | Args: 487 | nodes (iterable): Fixed reactant 488 | edge_index (2D tensor or tuple with 2 1D tesor): The edges to base the sampling on. 489 | all_pos_edges (2D tensor or tuple with 2 1D tesor): All edges in the graph. 490 | 491 | Return: 492 | neg_edges (2D tensors): Negative edges. 493 | """ 494 | all_nodes_in_edges = torch.cat((edge_index[0], edge_index[1])).flatten().unique() 495 | if include_unconnected: 496 | all_nodes = torch.arange(1, max(all_nodes_in_edges)) 497 | else: 498 | all_nodes = all_nodes_in_edges 499 | print("num unique nodes", len(all_nodes.unique())) 500 | 501 | i_neg = torch.tensor([]) 502 | j_neg = torch.tensor([]) 503 | for node in nodes: 504 | node = int(node) 505 | i_neg_node = all_nodes[all_nodes != node] 506 | j_neg_node = torch.tensor([node for _ in i_neg_node.flatten()]) 507 | 508 | if node in all_nodes: 509 | i_neg_node, j_neg_node = remove_overlaps( 510 | (i_neg_node, j_neg_node), all_pos_edges 511 | ) 512 | else: 513 | print(f"Node {node} not in reactants.") 514 | 515 | print("num unique nodes", len(all_nodes.unique())) 516 | i_neg = torch.cat((i_neg, i_neg_node)) 517 | j_neg = torch.cat((j_neg, j_neg_node)) 518 | 519 | neg_edges = torch.stack((i_neg, j_neg), dim=0) 520 | 521 | return neg_edges 522 | 523 | -------------------------------------------------------------------------------- /utils/reactions_info.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | 4 | with warnings.catch_warnings(): 5 | warnings.filterwarnings("ignore", category=RuntimeWarning) 6 | import graph_tool.all as gt 7 | 8 | 9 | def get_reactants_and_product_index(graph): 10 | """Returns one list with the reactant molecule nodes and one list with the corresponding 11 | products molecule nodes. 
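    Also returns the reaction class and the class id of each reaction (None when the
    corresponding vertex properties are missing).

    Example (illustrative only):

        reactants, products, classes, class_ids = get_reactants_and_product_index(graph)
        # reactants[i] / products[i] are the molecule node indices of reaction i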
12 |     """
13 |     reactant_index = []
14 |     product_index = []
15 | 
16 |     reaction_class = []
17 |     class_id = []
18 | 
19 |     reactions = gt.find_vertex(graph, graph.vertex_properties["labels"], ":Reaction")
20 | 
21 |     for r in reactions:
22 |         in_neighbors = r.in_neighbors()
23 |         in_neighbors = [int(n) for n in in_neighbors]
24 |         if len(in_neighbors) != 0:
25 |             reactant_index.append(in_neighbors)
26 | 
27 |         out_neighbors = r.out_neighbors()
28 |         out_neighbors = [int(n) for n in out_neighbors]
29 |         product_index.append(out_neighbors)
30 | 
31 |         try:
32 |             reaction_class.append(graph.vertex_properties["reaction_class"][r])
33 |             class_id.append(graph.vertex_properties["class_id"][r])
34 |         except KeyError:
35 |             reaction_class.append(None)
36 |             class_id.append(None)
37 | 
38 |     return reactant_index, product_index, reaction_class, class_id
39 | 
40 | 
41 | def get_smiles_to_index_dict(graph):
42 | 
43 |     smiles_to_index = {}
44 |     for v in graph.get_vertices():
45 |         smiles = graph.vertex_properties["smiles"][v]
46 |         smiles_to_index[smiles] = v
47 | 
48 |     return smiles_to_index
49 | 
50 | 
51 | def get_index_to_smiles_dict(graph):
52 |     index_to_smiles = {}
53 | 
54 |     for v in graph.get_vertices():
55 |         smiles = graph.vertex_properties["smiles"][v]
56 |         index_to_smiles[v] = smiles
57 | 
58 |     return index_to_smiles
59 | 
60 | 
61 | def shortest_distance(graph, edges):
62 |     distances = []
63 |     count = 0
64 |     for node_1, node_2 in edges:
65 |         count += 1
66 |         dist = gt.shortest_distance(
67 |             graph, source=int(node_1), target=int(node_2), max_dist=50, directed=False
68 |         )
69 |         if dist == 2147483647:  # graph-tool's sentinel for "unreachable" (int32 max)
70 |             distances.append(50)
71 |         else:
72 |             distances.append(dist)
73 |     return distances
74 | 
75 | 
76 | def shortest_distance_positive(graph, edges):
77 |     distances = []
78 |     count = 0
79 | 
80 |     for node1, node2 in edges:
81 |         count += 1
82 |         graph.remove_edge(graph.edge(node1, node2))
83 |         dist = gt.shortest_distance(
84 |             graph, source=int(node1), target=int(node2), max_dist=50, directed=False
85 |         )
86 |         graph.add_edge(node1, node2)
87 |         if dist == 2147483647:  # graph-tool's sentinel for "unreachable" (int32 max)
88 |             distances.append(50)
89 |         else:
90 |             distances.append(dist)
91 |     return distances
92 | 
93 | 
--------------------------------------------------------------------------------
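
A minimal usage sketch tying the utilities above together: predict_links (utils/evaluate_model.py)
re-scores a set of candidate edges with a previously trained model, and hitsK (utils/metrics.py)
summarises the resulting probabilities. Everything below is illustrative only: the checkpoint
directory, the candidate-edge tensor and the positive/negative ordering are placeholder
assumptions, not artefacts shipped with the repository.

    import torch
    import numpy as np

    from utils.evaluate_model import predict_links
    from utils.metrics import hitsK

    model_dir = "results/my_seal_run"          # assumed: directory with settings.csv and a checkpoint
    edges = torch.load("candidate_edges.pt")   # assumed: 2 x N tensor of molecule-node pairs

    y_prob, links = predict_links(model_dir, edges)

    # Assumption: the returned probabilities keep the positive-then-negative ordering of the
    # input edges (the same split predict_links creates); if not, use `links` to re-align them.
    n = edges.size(1)
    y_true = np.array([1] * (n // 2) + [0] * (n - n // 2))
    print("Hits@50:", hitsK(y_true, np.asarray(y_prob), k=50))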