├── .gitignore ├── LICENSE ├── README.md ├── envs ├── environment.yml ├── poetry.lock ├── pyproject.toml └── test_pyg.py └── lsfml ├── clustering └── cluster.py ├── config ├── config_140.ini ├── config_141.ini ├── config_142.ini ├── config_143.ini ├── config_161.ini ├── config_171.ini ├── config_210.ini ├── config_211.ini ├── config_323.ini └── config_420.ini ├── data ├── drugs_data.tsv ├── experimental_rxndata.csv ├── literature_regio.csv ├── literature_rxndata.csv ├── roche_hte_regio_data.csv └── surf_template.tsv ├── experimental ├── net.py ├── net_utils.py ├── preprocessh5.py ├── property_analysis.py └── train.py ├── fganalysis ├── ertl.py └── fg_analysis.py ├── img ├── NCHEM-22102062A_figure1.png └── regio_example.jpg ├── literature ├── regioselectivity │ ├── graph_mapping.py │ ├── net.py │ ├── net_utils.py │ ├── production.py │ └── train.py └── rxnyield │ ├── net.py │ ├── net_utils.py │ ├── preprocessh5.py │ └── train.py ├── modules ├── gmt.py ├── gnn_blocks.py └── pygdataset.py ├── qml ├── config_14000.ini ├── model1.pkl ├── prod.py ├── qml_net.py └── test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__/ 3 | 4 | *.DS_Store 5 | 6 | lsf.* 7 | .nfs* 8 | 9 | *.pkl 10 | *.pickle 11 | *.pyc 12 | *.json 13 | *.h5 14 | *.txt 15 | *.pt 16 | *.sdf 17 | *.html 18 | *.pdb 19 | *.cdxml 20 | *.xlsx 21 | *.out 22 | *.txtx 23 | *.png 24 | *.jpg 25 | *.csv 26 | *.ini 27 | *.xyz 28 | *.pdf 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (©) 2023 Kenneth Atz, & Gisbert Schneider (ETH Zurich) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enabling late-stage drug diversification by high-throughput experimentation with geometric deep learning
2 | 
3 | [![python](https://img.shields.io/badge/Python-3.8-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org)
4 | [![pytorch](https://img.shields.io/badge/PyTorch-1.13.1-EE4C2C.svg?style=flat&logo=pytorch)](https://pytorch.org)
5 | [![RDKit badge](https://img.shields.io/badge/Powered%20by-RDKit-3838ff.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQBAMAAADt3eJSAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAAFVBMVEXc3NwUFP8UPP9kZP+MjP+0tP////9ZXZotAAAAAXRSTlMAQObYZgAAAAFiS0dEBmFmuH0AAAAHdElNRQfmAwsPGi+MyC9RAAAAQElEQVQI12NgQABGQUEBMENISUkRLKBsbGwEEhIyBgJFsICLC0iIUdnExcUZwnANQWfApKCK4doRBsKtQFgKAQC5Ww1JEHSEkAAAACV0RVh0ZGF0ZTpjcmVhdGUAMjAyMi0wMy0xMVQxNToyNjo0NyswMDowMDzr2J4AAAAldEVYdGRhdGU6bW9kaWZ5ADIwMjItMDMtMTFUMTU6MjY6NDcrMDA6MDBNtmAiAAAAAElFTkSuQmCC)](https://www.rdkit.org/)
6 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
7 | [![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/)
8 | [![DOI:10.1000/XXX-X-XX-XXX-X_X](https://zenodo.org/badge/DOI/10.1000/XXX-X-XX-XXX-X_X.svg)](https://doi.org/10.1000/XXX-X-XX-XXX-X_X)
9 | [![Citations](https://api.juleskreuer.eu/citation-badge.php?doi=10.1000/XXX-X-XX-XXX-X_X)](https://juleskreuer.eu/projekte/citation-badge/)
10 | 
11 | ![](lsfml/img/NCHEM-22102062A_figure1.png)
12 | 
13 | This repository contains a reference implementation to preprocess the data, as well as to train and apply the graph machine learning models introduced in David F. Nippa, Kenneth Atz, Remo Hohler, Alex T. Müller, Andreas Marx, Christian Bartelmus, Georg Wuitschik, Irene Marzuoli, Vera Jost, Jens Wolfard, Martin Binder, Antonia F. Stepan, David B. Konrad, Uwe Grether, Rainer E. Martin & Gisbert Schneider, Journal, X, XX (2023).
14 | 
15 | 
16 | ## 1. Environment
17 | Create and activate the lsfml environment.
18 | 
19 | ```
20 | cd envs/
21 | conda env create -f environment.yml
22 | conda activate lsfml
23 | poetry install
24 | ```
25 | 
26 | Add the path of your cloned `lsfml` repository to the `PYTHONPATH` variable in your `~/.bashrc` file.
27 | 
28 | ```
29 | export PYTHONPATH="${PYTHONPATH}:/lsfml/"
30 | ```
31 | 
32 | Source your `~/.bashrc`.
33 | 
34 | ```
35 | source ~/.bashrc
36 | conda activate lsfml
37 | ```
38 | 
39 | Test your installation by running `test_pyg.py`.
40 | 
41 | ```
42 | python test_pyg.py
43 | >>> torch_geometric.__version__: 2.3.0
44 | >>> torch.__version__: 1.13.1
45 | >>> rdkit.__version__: 2022.09.5
46 | ```
47 | 
48 | ## 2. Data
49 | The `data/` directory contains five files: (i) `drugs_data.tsv` containing the 1344 approved drug molecules, (ii) `experimental_rxndata.csv` containing the 956 reactions generated by automated high-throughput experimentation (HTE) formatted in the simple user-friendly reaction format (SURF), (iii) `literature_rxndata.csv` containing the 1301 reactions extracted from the literature formatted in SURF, (iv) `roche_hte_regio_data.csv` containing the borylation reactions prospectively conducted within the scope of this paper formatted in SURF, and (v) `surf_template.tsv` containing a SURF template.
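The reaction files store one reaction per row with explicit columns for substrates, conditions and products. As a minimal, hypothetical sketch (not part of the original pipeline), `experimental_rxndata.csv` can be inspected with pandas; the column names below are the ones read by `lsfml/experimental/preprocessh5.py`:

```python
import pandas as pd

# Load the HTE reaction table (SURF-formatted CSV, one reaction per row).
df = pd.read_csv("lsfml/data/experimental_rxndata.csv")

# Substrate SMILES ("educt"), selected conditions and the mono-borylation
# fraction ("mono_bo"), as used by the preprocessing script.
print(df[["rxn_id", "educt", "catalyst", "ligand", "solvent", "mono_bo"]].head())
```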
50 | ```
51 | cd lsfml/
52 | ls data/
53 | drugs_data.tsv experimental_rxndata.csv literature_rxndata.csv roche_hte_regio_data.csv surf_template.tsv
54 | ```
55 | 
56 | ## 3. Clustering
57 | `cluster.py` filters the 1344 approved drug molecules in `data/drugs_data.tsv` by molecular weight (default: 200–800 g/mol) and clusters them into a defined number of clusters (default: 8). Molecules are represented by a similarity matrix calculated from extended-connectivity fingerprints (ECFPs). This representation is chosen to capture structural features within the clustering. The resulting clusters are saved as individual `.csv` files in `cluster_analysis/`.
58 | ```
59 | cd clustering/
60 | python cluster.py
61 | ```
62 | 
63 | ## 4. Functional group analysis
64 | `fg_analysis.py` analyzes the functional groups present in `data/experimental_rxndata.csv` and their influence on the reaction outcome. The functional groups are identified following Ertl 2017 (https://doi.org/10.1186/s13321-017-0225-z) and the resulting plots are stored in `fg_analysis/`.
65 | ```
66 | cd fganalysis/
67 | python fg_analysis.py
68 | ```
69 | 
70 | ## 5. Quantum machine learning
71 | `qml/` contains the quantum machine learning (qml) scripts that enable the prediction of DFT-level partial charges. The models are based on Atz, Isert et al. 2022 (https://doi.org/10.1039/D2CP00834C) and were trained on quantum chemical properties calculated using the ωB97X-D functional and the def2-SVP basis set. `prod.py` contains the initialization of the qml model that is imported by scripts in `literature/` and `experimental/`.
72 | ```
73 | ls qml/
74 | config_14000.ini model1.pkl prod.py qml_net.py test_mols test.py
75 | python test.py
76 | >>> QML model has been sent to cpu
77 | >>> Partial charge prediction is conducted with the following mean absolute errors:
78 | >>> CHEMBL1 0.0039913069590710925
79 | >>> CHEMBL2 0.002572119902126642
80 | >>> CHEMBL3 0.0026494266223815776
81 | >>> CHEMBL4 0.0029478159343548434
82 | >>> CHEMBL5 0.0030516667701088154
83 | >>> CHEMBL6 0.002998803614110483
84 | >>> CHEMBL8 0.003159851694464683
85 | ```
86 | 
87 | ## 6. Experimental data (reaction yield + binary reaction outcome)
88 | `experimental/` contains the scripts to process the HTE reaction data and train the different graph neural networks (GNNs) to predict reaction yield and binary reaction outcome. First, `preprocessh5.py` loads the reaction data from `data/experimental_rxndata.csv`, loops over the substrates and reactions, and calculates molecular graphs, different 3D conformations and DFT-level partial charges. The resulting information is stored in `data/experimental_rxndata.h5` (reaction-level data) and `data/experimental_substrates.h5` (substrate graphs).
89 | ```
90 | cd experimental/
91 | python preprocessh5.py
92 | ```
93 | 
94 | Once `data/experimental_rxndata.h5` is generated, the models can be trained via `train.py`. `train.py` imports the data loader from `net_utils.py`, the network architectures from `net.py` and the hyperparameters from one of the config files in `config/`. The config files contain the model and training hyperparameters, as well as information about the data set split (SPLIT, random or eln (i.e. substrate type)), the target (TARGET, mono (i.e. yield) or binary), the graph dimension (GEOMETRY, 1 (i.e. 3D) or 0 (i.e. 2D)) and the electronic properties (QML, 1 (i.e. including partial charges) or 0 (i.e. no partial charges)).
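The config files are plain INI files. As a brief, hypothetical sketch (the actual argument handling lives in `train.py` and may differ in detail), they can be parsed with Python's built-in `configparser`:

```python
import configparser

# Read one of the provided hyperparameter files.
config = configparser.ConfigParser()
config.read("config/config_323.ini")
params = config["PARAMS"]

print(params["SPLIT"], params["TARGET"])  # e.g. random, mono
print(params.getint("GEOMETRY"), params.getint("QML"))  # 1 = 3D graphs, 1 = partial charges
```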
An exemplary config file (`config/config_323.ini`) using a random split, reaction yield as target, 3D graphs and quantum chemical information looks as follows:
95 | 
96 | ```
97 | cat config/config_323.ini
98 | 
99 | [PARAMS]
100 | LR_FACTOR = 1e-4
101 | LR_STEP_SIZE = 100
102 | N_KERNELS = 3
103 | POOLING_HEADS = 4
104 | D_MLP = 512
105 | D_KERNEL = 128
106 | D_EMBEDDING = 128
107 | BATCH_SIZE = 16
108 | SPLIT = random
109 | ELN = ELN036496-146
110 | TARGET = mono
111 | GEOMETRY = 1
112 | QML = 1
113 | FINGERPRINT = ecfp4_2
114 | ```
115 | Further arguments allow setting additional training parameters, such as `-config` (config file in `config/`), `-mode` (neural network type: a: GTNN, b: GNN, c: FNN), `-cv` (cross validation), `-testset` (test set, 1–4), and `-early_stop` (early stopping using a validation set: 1 = yes, 0 = no; if set to 0, the final model is stored at epoch 1000). Once the training and model parameters are chosen, the training script can be run.
116 | 
117 | ```
118 | python train.py -config 323 -mode a -cv 1 -testset 1 -early_stop 0
119 | ```
120 | The training script generates two directories where the models (`models/`) and results (`results/`) are stored.
121 | 
122 | The HTE data set can be analyzed via `property_analysis.py`, which saves plots of the molecular properties of the substrates and of the number of successful reactions for the different reaction conditions into `analysis/`.
123 | 
124 | ```
125 | python property_analysis.py
126 | ```
127 | 
128 | ## 7. Literature data (reaction yield)
129 | The literature data directory follows the same structure as the experimental data directory described above. First, `preprocessh5.py` reads the reactions and substrates from `data/literature_rxndata.csv` and stores the resulting data in `data/literature_rxndata.h5`.
130 | 
131 | ```
132 | cd literature/rxnyield/
133 | python preprocessh5.py
134 | ```
135 | 
136 | Subsequently, the different neural networks (GTNN, GNN, FNN) using different molecular graphs (2D, 3D, 2DQM and 3DQM) can be trained in an identical fashion as described above for the experimental data. The two main differences between the two training procedures are that (i) the number of different catalysts, solvents, ligands and other conditions (temp., conc. etc.) is much more diverse in the literature data set, and (ii) models can only be trained for reaction yield prediction (and not for binary reaction outcome) since the literature data set does not contain negative examples. Once the training and model parameters are chosen, the training script can be run.
137 | 
138 | ```
139 | python train.py -config 420 -mode a -cv 1 -early_stop 0
140 | ```
141 | The training script generates two directories where the models (`models/`) and results (`results/`) are stored.
142 | 
143 | ## 8. Literature data (regioselectivity)
144 | `graph_mapping.py` filters the literature data from `data/literature_rxndata.csv` (e.g. by yield >= 30%, only major products, only mono-borylation etc.), extracts the regioselectivity information, and stores the processed molecules in `data/literature_regio.h5`. Additionally, the file `data/literature_regio.csv` is generated, containing an overview of all structures used to train the regioselectivity models.
145 | ```
146 | cd literature/regioselectivity/
147 | python graph_mapping.py
148 | ```
149 | Deuterium (2H) is placed at the C-H bond undergoing borylation in `data/literature_regio.csv`. This labelling procedure conserves the regioselectivity information within the substrates while influencing neither their steric properties (conformer generation) nor their electronic properties (quantum machine learning).
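For illustration, a toy sketch (hypothetical example molecule, not taken from the data set) showing that such a deuterium label is kept as an explicit node in the RDKit molecular graph:

```python
from rdkit import Chem

# Hypothetical toy substrate with deuterium marking the borylated C-H position.
mol = Chem.MolFromSmiles("[2H]c1ccccc1")

# The 2H atom is an explicit atom in the graph; its heavy-atom neighbour is
# the carbon at which borylation takes place.
for atom in mol.GetAtoms():
    if atom.GetIsotope() == 2:
        carbon = atom.GetNeighbors()[0]
        print("Borylation site: atom index", carbon.GetIdx(), f"({carbon.GetSymbol()})")
```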
150 | 
151 | Once `data/literature_regio.h5` is generated, the different regioselectivity models (aGNN2D, aGNN2DQM, aGNN3D and aGNN3DQM) can be trained in an identical fashion using config files (`config/`) and arguments as described above for the literature and experimental data sets.
152 | ```
153 | python train.py -config 141 -mode a -cv 1 -early_stop 0
154 | ```
155 | 
156 | Trained models can then be applied via `literature/regioselectivity/production.py`, where the model ID and a list of substrate SMILES strings are specified in the `__main__` section. The predicted probabilities of borylation per carbon atom are printed out, and a figure of each substrate highlighting the predicted site(s) of borylation is stored in `regiosel_imgs/`.
157 | 
158 | ```
159 | python production.py > predicted_161.txt
160 | ```
161 | The individual predictions are then stored in figures like the following:
162 | ![](lsfml/img/regio_example.jpg)
163 | 
164 | ## 9. Additional neural network modules
165 | `modules/` contains additional essential neural network modules used by the different graph neural networks. `gnn_blocks.py` contains the message passing functions for 2D and 3D graphs that are used by all GNNs and GTNNs. `gmt.py` contains the graph multiset transformer-based pooling that is used by all GTNNs. `pygdataset.py` contains a dedicated data loader that generates the graph data objects processed by PyG.
166 | ```
167 | ls modules/
168 | gmt.py gnn_blocks.py pygdataset.py
169 | ```
170 | 
171 | ## 10. License
172 | The software was developed at ETH Zurich and is licensed under the MIT license, as described in `LICENSE`.
173 | 
174 | ## 11. Citation
175 | ```
176 | @article{nippa_atz2022enabling,
177 | title={Enabling late-stage drug diversification by high-throughput experimentation with geometric deep learning},
178 | author={Nippa, David F. and Atz, Kenneth and Hohler, Remo and M{\"u}ller, Alex T. and Marx, Andreas and Bartelmus, Christian and Wuitschik, Georg and Marzuoli, Irene and Jost, Vera and Wolfard, Jens and Binder, Martin and Stepan, Antonia F. and Konrad, David B. and Grether, Uwe and Martin, Rainer E. and Schneider, Gisbert},
179 | journal={},
180 | year={}
181 | }
182 | To Do: To be updated with DOI and journal.
183 | ```
184 | 
185 | 
--------------------------------------------------------------------------------
/envs/environment.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: lsfml
3 | channels:
4 |   - pytorch
5 |   - conda-forge
6 |   - pyg
7 |   - nodefaults
8 |   - nvidia
9 | dependencies:
10 |   - python=3.8
11 |   - numpy<=1.20
12 |   - libcurl
13 |   - pytorch=1.13.1
14 |   - cudatoolkit=11.7
15 |   - torchvision=0.14.1
16 |   - torchaudio=0.13.1
17 |   - pyg=2.3.0
18 |   # utilities
19 |   - pip
20 |   - pip:
21 |       - poetry
--------------------------------------------------------------------------------
/envs/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "lsfml"
3 | version = "0.1.0"
4 | description = "Reaction prediction with PyTorch Geometric"
5 | authors = ["Kenneth Atz ",]
6 | license = "See License file"
7 | readme = "README.md"
8 | 
9 | homepage = "https://github.com/atzkenneth/lsfml"
10 | repository = "https://github.com/atzkenneth/lsfml"
11 | documentation = "https://github.com/atzkenneth/lsfml"
12 | 
13 | [tool.poetry.dependencies]
14 | # pip packages
15 | python = "~3.8"
16 | click = "~8"
17 | "ruamel.yaml" = "~0.16"
18 | pyarrow = "~3.0"
19 | tqdm = "~4"
20 | h5py = "~3.7"
21 | lightgbm = "~3.1"
22 | pyopenssl = "20.0.0"
23 | pandas = "^1.3"
24 | rdkit = "^2022.9"
25 | XlsxWriter = "^3.1.2"
26 | openpyxl = "^3.1.2"
27 | networkx = "^3.1"
28 | einops = "^0.6.1"
29 | 
30 | [tool.poetry.dev-dependencies]
31 | # notebooks
32 | ipykernel = "^5.5"
33 | ipython_genutils = "^0.2"
34 | matplotlib = "~3.4"
35 | # testing
36 | pytest = "^7.1.2"
37 | pytest-cov = "^3.0.0"
38 | # formatting, linting, etc
39 | black = "^22.6.0"
40 | flake8 = "^4.0.1"
41 | pydocstyle = "^6.1.1"
42 | isort = "^5.10.1"
43 | yamllint = "^1.27.1"
44 | pre-commit = "*"
45 | 
46 | [build-system]
47 | requires = ["poetry-core>=1.0.0"]
48 | build-backend = "poetry.core.masonry.api"
49 | 
50 | ###############
51 | # Other tools #
52 | ###############
53 | 
54 | [tool.black]
55 | line-length = 99
56 | target-version = ['py38']
57 | 
58 | [tool.isort]
59 | profile = "black"
--------------------------------------------------------------------------------
/envs/test_pyg.py:
--------------------------------------------------------------------------------
1 | import rdkit  # importing all key dependencies below doubles as an installation smoke test
2 | from rdkit import Chem
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | import torch_geometric
9 | from torch_geometric.nn.conv import MessagePassing
10 | from torch_geometric.nn.aggr import GraphMultisetTransformer
11 | from torch_geometric.typing import Adj, Size, Tensor
12 | from torch_geometric.utils.scatter import scatter
13 | 
14 | if __name__ == "__main__":
15 |     print(f"torch_geometric.__version__: {torch_geometric.__version__}")
16 |     print(f"torch.__version__: {torch.__version__}")
17 |     print(f"rdkit.__version__: {rdkit.__version__}")
--------------------------------------------------------------------------------
/lsfml/config/config_140.ini:
--------------------------------------------------------------------------------
1 | [PARAMS]
2 | LR_FACTOR = 1e-4
3 | LR_STEP_SIZE = 100
4 | N_KERNELS = 3
5 | D_MLP = 512
6 | D_KERNEL = 128
7 | D_EMBEDDING = 128
8 | BATCH_SIZE = 16
9 | GEOMETRY = 1
10 | QML = 0
--------------------------------------------------------------------------------
/lsfml/config/config_141.ini:
--------------------------------------------------------------------------------
1 | 
[PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 512 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 1 10 | QML = 1 -------------------------------------------------------------------------------- /lsfml/config/config_142.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 512 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 0 10 | QML = 0 -------------------------------------------------------------------------------- /lsfml/config/config_143.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 256 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 0 10 | QML = 1 -------------------------------------------------------------------------------- /lsfml/config/config_161.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 512 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 1 10 | QML = 1 -------------------------------------------------------------------------------- /lsfml/config/config_171.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 512 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 1 10 | QML = 1 -------------------------------------------------------------------------------- /lsfml/config/config_210.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 512 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 1 10 | QML = 0 -------------------------------------------------------------------------------- /lsfml/config/config_211.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | D_MLP = 512 6 | D_KERNEL = 128 7 | D_EMBEDDING = 128 8 | BATCH_SIZE = 16 9 | GEOMETRY = 0 10 | QML = 0 -------------------------------------------------------------------------------- /lsfml/config/config_323.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | POOLING_HEADS = 4 6 | D_MLP = 512 7 | D_KERNEL = 128 8 | D_EMBEDDING = 128 9 | BATCH_SIZE = 16 10 | SPLIT = random 11 | ELN = ELN036496-146 12 | TARGET = mono 13 | GEOMETRY = 1 14 | QML = 1 15 | FINGERPRINT = ecfp4_2 -------------------------------------------------------------------------------- /lsfml/config/config_420.ini: -------------------------------------------------------------------------------- 1 | [PARAMS] 2 | LR_FACTOR = 1e-4 3 | LR_STEP_SIZE = 100 4 | N_KERNELS = 3 5 | POOLING_HEADS = 4 6 | D_MLP = 512 7 | D_KERNEL = 128 8 | D_EMBEDDING = 128 9 | BATCH_SIZE = 16 10 | SPLIT = random 11 | ELN = ELN036496-146 12 | TARGET = binary 13 | GEOMETRY = 0 14 | QML = 0 15 | FINGERPRINT = ecfp4_2 -------------------------------------------------------------------------------- /lsfml/data/literature_rxndata.csv: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ETHmodlab/lsfml/d0178f1ebfedd73639cee2452fadacc500ca23e1/lsfml/data/literature_rxndata.csv -------------------------------------------------------------------------------- /lsfml/data/roche_hte_regio_data.csv: -------------------------------------------------------------------------------- 1 | rxn_id,startingmat_smiles,product_smiles 2 | 1b,O=C(N1CC/C(CC1)=C2C3=C(CCC4=C\2C=CC(Cl)=C4)C=CC=N3)OCC,O=C(N1CC/C(CC1)=C2C3=C(CCC4=C\2C=C(B5OC(C)(C)C(C)(C)O5)C(Cl)=C4)C=C(B6OC(C)(C)C(C)(C)O6)C=N3)OCC 3 | 25a,O=C1OC2=C(C(O)=C1C(CC(C)=O)C3=CC=CC=C3)C=CC=C2,O=C1OC2=C(C(O)=C1C(CC(C)=O)C3=CC=CC=C3)C=CC(B4OC(C)(C)C(C)(C)O4)=C2 4 | 29a,O=C1NC2=C(C)C=CN=C2N(C3CC3)C4=NC=CC=C14,O=C1NC2=C(C)C=CN=C2N(C3CC3)C4=NC=C(B5OC(C)(C)C(C)(C)O5)C=C14 5 | 37a,CC1=CC=C(N1C2=NN(C)C=C2)C,CC1=CC=C(N1C2=NN(C)C(B3OC(C)(C)C(C)(C)O3)=C2)C 6 | 38a,BrC1=CN=C(C=C1)N2CCOCC2,BrC1=CN=C(C=C1B2OC(C)(C)C(C)(C)O2)N3CCOCC3 7 | 39a,FC(F)(C1=NN(C)C2=C1C=CC(Br)=C2)F,FC(F)(C1=NN(CB2OC(C)(C)C(C)(C)O2)C3=C1C=CC(Br)=C3)F 8 | 45a,O=C(N1CCOCC1)C2=CC=C3C=C(C=CC3=C2)O,O=C(N1CCOCC1)C2=CC(B3OC(C)(C)C(C)(C)O3)=C4C=C(C=C(B5OC(C)(C)C(C)(C)O5)C4=C2)O -------------------------------------------------------------------------------- /lsfml/data/surf_template.tsv: -------------------------------------------------------------------------------- 1 | rxn_id source_id source_type rxn_date rxn_type rxn_name rxn_tech temperature_degC time_h atmosphere stirring_shaking scale_mol concentration_mol startingmat_1_cas startingmat_1_smiles startingmat_1_eq startingmat_2_cas startingmat_2_smiles startingmat_2_eq reagent_1_cas reagent_1_smiles reagent_1_eq reagent_2_cas reagent_2_smiles reagent_2_eq catalyst_1_cas catalyst_1_smiles catalyst_1_eq ligand_1_cas ligand_1_smiles ligand_1_eq additive_1_cas additive_1_smiles additive_1_eq additive_2_cas additive_2_smiles additive_2_eq solvent_1_cas solvent_1_smiles solvent_1_fraction solvent_2_cas solvent_2_smiles solvent_2_fraction product_1_cas product_1_smiles product_1_yield product_1_yieldtype product_1_ms product_1_nmr product_2_cas product_2_smiles product_2_yield product_2_yieldtype product_2_ms product_2_nmr procedure comment 2 | 3 | -------------------------------------------------------------------------------- /lsfml/experimental/net.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import torch
7 | import torch.nn as nn
8 | 
9 | from lsfml.modules.gmt import GraphMultisetTransformer
10 | from lsfml.modules.gnn_blocks import (
11 |     EGNN_sparse,
12 |     EGNN_sparse3D,
13 |     weights_init,
14 |     scatter_sum,
15 | )
16 | 
17 | 
18 | class GraphTransformer(nn.Module):
19 |     """Graph Transformer neural network (GTNN) for yield and binary predictions."""
20 | 
21 |     def __init__(
22 |         self,
23 |         n_kernels=3,
24 |         pooling_heads=8,
25 |         mlp_dim=512,
26 |         kernel_dim=64,
27 |         embeddings_dim=64,
28 |         qml=True,
29 |         geometry=True,
30 |     ):
31 |         """Initialization of GTNN
32 | 
33 |         :param n_kernels: Number of message passing functions, defaults to 3
34 |         :type n_kernels: int, optional
35 |         :param pooling_heads: Number of Transformers, defaults to 8
36 |         :type pooling_heads: int, optional
37 |         :param mlp_dim: Feature dimension within the multilayer perceptrons, defaults to 512
38 |         :type mlp_dim: int, optional
39 |         :param kernel_dim: Feature dimension within the message passing functions, defaults to 64
40 |         :type kernel_dim: int, optional
41 |         :param embeddings_dim: Embedding dimension of the input features (e.g. reaction conditions and atomic features), defaults to 64
42 |         :type embeddings_dim: int, optional
43 |         :param qml: Option to include DFT-level partial charges, defaults to True
44 |         :type qml: bool, optional
45 |         :param geometry: Option to include steric information in the input graph, defaults to True
46 |         :type geometry: bool, optional
47 |         """
48 |         super(GraphTransformer, self).__init__()
49 | 
50 |         self.embeddings_dim = embeddings_dim
51 |         self.m_dim = 16
52 |         self.kernel_dim = kernel_dim
53 |         self.n_kernels = n_kernels
54 |         self.aggr = "add"
55 |         self.pos_dim = 3
56 |         self.pooling_heads = pooling_heads
57 |         self.mlp_dim = mlp_dim
58 |         self.qml = qml
59 |         self.geometry = geometry
60 | 
61 |         dropout = 0.1
62 |         self.dropout = nn.Dropout(dropout)
63 | 
64 |         self.atom_em = nn.Embedding(num_embeddings=10, embedding_dim=self.embeddings_dim)
65 |         self.ring_em = nn.Embedding(num_embeddings=2, embedding_dim=self.embeddings_dim)
66 |         self.hybr_em = nn.Embedding(num_embeddings=4, embedding_dim=self.embeddings_dim)
67 |         self.arom_em = nn.Embedding(num_embeddings=2, embedding_dim=self.embeddings_dim)
68 | 
69 |         if self.qml:
70 |             self.chrg_em = nn.Linear(1, self.embeddings_dim)
71 |             self.pre_egnn_mlp_input_dim = self.embeddings_dim * 5
72 |             self.chrg_em.apply(weights_init)
73 |         else:
74 |             self.pre_egnn_mlp_input_dim = self.embeddings_dim * 4
75 | 
76 |         self.pre_egnn_mlp = nn.Sequential(
77 |             nn.Linear(self.pre_egnn_mlp_input_dim, self.kernel_dim * 2),
78 |             self.dropout,
79 |             nn.SiLU(),
80 |             nn.Linear(self.kernel_dim * 2, self.kernel_dim),
81 |             nn.SiLU(),
82 |             nn.Linear(self.kernel_dim, self.kernel_dim),
83 |             nn.SiLU(),
84 |         )
85 | 
86 |         self.kernels = nn.ModuleList()
87 |         for _ in range(self.n_kernels):
88 |             if self.geometry:
89 |                 self.kernels.append(
90 |                     EGNN_sparse3D(
91 |                         feats_dim=self.kernel_dim,
92 |                         m_dim=self.m_dim,
93 |                         aggr=self.aggr,
94 |                     )
95 |                 )
96 |             else:
97 |                 self.kernels.append(
98 |                     EGNN_sparse(
99 |                         feats_dim=self.kernel_dim,
100 |                         m_dim=self.m_dim,
101 |                         aggr=self.aggr,
102 |                     )
103 |                 )
104 | 
105 |         self.post_egnn_mlp = nn.Sequential(
106 |             nn.Linear(self.kernel_dim * self.n_kernels, self.kernel_dim),
107 |             self.dropout,
108 |             nn.SiLU(),
109 |             nn.Linear(self.kernel_dim, self.kernel_dim),
110 |             nn.SiLU(),
111 |             nn.Linear(self.kernel_dim, self.kernel_dim),
112 |             nn.SiLU(),
113 |         )
114 | 
115
| self.transformers = nn.ModuleList() 116 | for _ in range(self.pooling_heads): 117 | self.transformers.append( 118 | GraphMultisetTransformer( 119 | in_channels=self.kernel_dim, 120 | hidden_channels=self.kernel_dim, 121 | out_channels=self.kernel_dim, 122 | pool_sequences=["GMPool_G", "SelfAtt", "GMPool_I"], 123 | num_heads=1, 124 | layer_norm=True, 125 | ) 126 | ) 127 | 128 | self.lig_emb = nn.Embedding(num_embeddings=6, embedding_dim=self.embeddings_dim) 129 | self.sol_emb = nn.Embedding(num_embeddings=4, embedding_dim=self.embeddings_dim) 130 | 131 | self.post_pooling_mlp_input_dim = self.kernel_dim * (self.pooling_heads) + self.embeddings_dim * 2 132 | 133 | self.post_pooling_mlp = nn.Sequential( 134 | nn.Linear(self.post_pooling_mlp_input_dim, self.mlp_dim), 135 | self.dropout, 136 | nn.SiLU(), 137 | nn.Linear(self.mlp_dim, self.mlp_dim), 138 | nn.SiLU(), 139 | nn.Linear(self.mlp_dim, self.mlp_dim), 140 | nn.SiLU(), 141 | nn.Linear(self.mlp_dim, 1), 142 | ) 143 | 144 | self.transformers.apply(weights_init) 145 | self.kernels.apply(weights_init) 146 | self.pre_egnn_mlp.apply(weights_init) 147 | self.post_egnn_mlp.apply(weights_init) 148 | self.post_pooling_mlp.apply(weights_init) 149 | nn.init.xavier_uniform_(self.atom_em.weight) 150 | nn.init.xavier_uniform_(self.ring_em.weight) 151 | nn.init.xavier_uniform_(self.hybr_em.weight) 152 | nn.init.xavier_uniform_(self.arom_em.weight) 153 | nn.init.xavier_uniform_(self.lig_emb.weight) 154 | nn.init.xavier_uniform_(self.sol_emb.weight) 155 | 156 | def forward(self, g_batch): 157 | """Forward pass of the GTNN. 158 | 159 | :param g_batch: Input graph. 160 | :type g_batch: class 161 | :return: Prediction. 162 | :rtype: Tensor 163 | """ 164 | if self.qml: 165 | features = self.pre_egnn_mlp( 166 | torch.cat( 167 | [ 168 | self.atom_em(g_batch.atom_id), 169 | self.ring_em(g_batch.ring_id), 170 | self.hybr_em(g_batch.hybr_id), 171 | self.arom_em(g_batch.arom_id), 172 | self.chrg_em(g_batch.charges), 173 | ], 174 | dim=1, 175 | ) 176 | ) 177 | else: 178 | features = self.pre_egnn_mlp( 179 | torch.cat( 180 | [ 181 | self.atom_em(g_batch.atom_id), 182 | self.ring_em(g_batch.ring_id), 183 | self.hybr_em(g_batch.hybr_id), 184 | self.arom_em(g_batch.arom_id), 185 | ], 186 | dim=1, 187 | ) 188 | ) 189 | 190 | feature_list = [] 191 | if self.geometry: 192 | features = torch.cat([g_batch.crds_3d, features], dim=1) 193 | for kernel in self.kernels: 194 | features = kernel( 195 | x=features, 196 | edge_index=g_batch.edge_index, 197 | ) 198 | feature_list.append(features[:, self.pos_dim :]) 199 | else: 200 | for kernel in self.kernels: 201 | features = kernel(x=features, edge_index=g_batch.edge_index) 202 | feature_list.append(features) 203 | 204 | features = torch.cat(feature_list, dim=1) 205 | features = self.post_egnn_mlp(features) 206 | 207 | feature_list = [] 208 | for transformer in self.transformers: 209 | feature_list.append(transformer(x=features, batch=g_batch.batch, edge_index=g_batch.edge_index)) 210 | 211 | features = torch.cat(feature_list, dim=1) 212 | del feature_list 213 | 214 | conditions = torch.cat( 215 | [ 216 | self.lig_emb(g_batch.lgnd_id), 217 | self.sol_emb(g_batch.slvn_id), 218 | ], 219 | dim=1, 220 | ) 221 | 222 | features = torch.cat([features, conditions], dim=1) 223 | features = self.post_pooling_mlp(features).squeeze(1) 224 | 225 | return features 226 | 227 | 228 | class EGNN(nn.Module): 229 | """Graph neural network (GNN) using sum pooling for yield and binary predictions.""" 230 | 231 | def __init__(self, n_kernels=3, 
mlp_dim=512, kernel_dim=64, embeddings_dim=64, qml=True, geometry=True):
232 |         """Initialization of GNN
233 | 
234 |         :param n_kernels: Number of message passing functions, defaults to 3
235 |         :type n_kernels: int, optional
236 |         :param mlp_dim: Feature dimension within the multilayer perceptrons, defaults to 512
237 |         :type mlp_dim: int, optional
238 |         :param kernel_dim: Feature dimension within the message passing functions, defaults to 64
239 |         :type kernel_dim: int, optional
240 |         :param embeddings_dim: Embedding dimension of the input features (e.g. reaction conditions and atomic features), defaults to 64
241 |         :type embeddings_dim: int, optional
242 |         :param qml: Option to include DFT-level partial charges, defaults to True
243 |         :type qml: bool, optional
244 |         :param geometry: Option to include steric information in the input graph, defaults to True
245 |         :type geometry: bool, optional
246 |         """
247 |         super(EGNN, self).__init__()
248 | 
249 |         self.embeddings_dim = embeddings_dim
250 |         self.m_dim = 16
251 |         self.kernel_dim = kernel_dim
252 |         self.n_kernels = n_kernels
253 |         self.aggr = "add"
254 |         self.pos_dim = 3
255 |         self.mlp_dim = mlp_dim
256 |         self.qml = qml
257 |         self.geometry = geometry
258 | 
259 |         dropout = 0.1
260 |         self.dropout = nn.Dropout(dropout)
261 | 
262 |         self.atom_em = nn.Embedding(num_embeddings=10, embedding_dim=self.embeddings_dim)
263 |         self.ring_em = nn.Embedding(num_embeddings=2, embedding_dim=self.embeddings_dim)
264 |         self.hybr_em = nn.Embedding(num_embeddings=4, embedding_dim=self.embeddings_dim)
265 |         self.arom_em = nn.Embedding(num_embeddings=2, embedding_dim=self.embeddings_dim)
266 | 
267 |         if self.qml:
268 |             self.chrg_em = nn.Linear(1, self.embeddings_dim)
269 |             self.pre_egnn_mlp_input_dim = self.embeddings_dim * 5
270 |             self.chrg_em.apply(weights_init)
271 |         else:
272 |             self.pre_egnn_mlp_input_dim = self.embeddings_dim * 4
273 | 
274 |         self.pre_egnn_mlp = nn.Sequential(
275 |             nn.Linear(self.pre_egnn_mlp_input_dim, self.kernel_dim * 2),
276 |             self.dropout,
277 |             nn.SiLU(),
278 |             nn.Linear(self.kernel_dim * 2, self.kernel_dim),
279 |             nn.SiLU(),
280 |             nn.Linear(self.kernel_dim, self.kernel_dim),
281 |             nn.SiLU(),
282 |         )
283 | 
284 |         self.kernels = nn.ModuleList()
285 |         for _ in range(self.n_kernels):
286 |             if self.geometry:
287 |                 self.kernels.append(
288 |                     EGNN_sparse3D(
289 |                         feats_dim=self.kernel_dim,
290 |                         m_dim=self.m_dim,
291 |                         aggr=self.aggr,
292 |                     )
293 |                 )
294 |             else:
295 |                 self.kernels.append(
296 |                     EGNN_sparse(
297 |                         feats_dim=self.kernel_dim,
298 |                         m_dim=self.m_dim,
299 |                         aggr=self.aggr,
300 |                     )
301 |                 )
302 | 
303 |         self.post_egnn_mlp = nn.Sequential(
304 |             nn.Linear(self.kernel_dim * self.n_kernels, self.kernel_dim * 2),
305 |             self.dropout,
306 |             nn.SiLU(),
307 |             nn.Linear(self.kernel_dim * 2, self.kernel_dim * 2),
308 |             nn.SiLU(),
309 |             nn.Linear(self.kernel_dim * 2, self.kernel_dim * 2),
310 |             nn.SiLU(),
311 |         )
312 | 
313 |         self.lig_emb = nn.Embedding(num_embeddings=6, embedding_dim=self.embeddings_dim)
314 |         self.sol_emb = nn.Embedding(num_embeddings=4, embedding_dim=self.embeddings_dim)
315 | 
316 |         self.post_pooling_mlp_input_dim = self.kernel_dim * 2 + self.embeddings_dim * 2
317 | 
318 |         self.post_pooling_mlp = nn.Sequential(
319 |             nn.Linear(self.post_pooling_mlp_input_dim, self.mlp_dim),
320 |             self.dropout,
321 |             nn.SiLU(),
322 |             nn.Linear(self.mlp_dim, self.mlp_dim),
323 |             nn.SiLU(),
324 |             nn.Linear(self.mlp_dim, self.mlp_dim),
325 |             nn.SiLU(),
326 |             nn.Linear(self.mlp_dim, 1),
327 |         )
328 | 
329 |         self.kernels.apply(weights_init)
330 |         self.pre_egnn_mlp.apply(weights_init)
331 |         self.post_egnn_mlp.apply(weights_init)
332 |         self.post_pooling_mlp.apply(weights_init)
333 |         nn.init.xavier_uniform_(self.atom_em.weight)
334 |         nn.init.xavier_uniform_(self.ring_em.weight)
335 |         nn.init.xavier_uniform_(self.hybr_em.weight)
336 |         nn.init.xavier_uniform_(self.arom_em.weight)
337 |         nn.init.xavier_uniform_(self.lig_emb.weight)
338 |         nn.init.xavier_uniform_(self.sol_emb.weight)
339 | 
340 |     def forward(self, g_batch):
341 |         """Forward pass of the GNN.
342 | 
343 |         :param g_batch: Input graph.
344 |         :type g_batch: class
345 |         :return: Prediction.
346 |         :rtype: Tensor
347 |         """
348 |         if self.qml:
349 |             features = self.pre_egnn_mlp(
350 |                 torch.cat(
351 |                     [
352 |                         self.atom_em(g_batch.atom_id),
353 |                         self.ring_em(g_batch.ring_id),
354 |                         self.hybr_em(g_batch.hybr_id),
355 |                         self.arom_em(g_batch.arom_id),
356 |                         self.chrg_em(g_batch.charges),
357 |                     ],
358 |                     dim=1,
359 |                 )
360 |             )
361 |         else:
362 |             features = self.pre_egnn_mlp(
363 |                 torch.cat(
364 |                     [
365 |                         self.atom_em(g_batch.atom_id),
366 |                         self.ring_em(g_batch.ring_id),
367 |                         self.hybr_em(g_batch.hybr_id),
368 |                         self.arom_em(g_batch.arom_id),
369 |                     ],
370 |                     dim=1,
371 |                 )
372 |             )
373 | 
374 |         feature_list = []
375 |         if self.geometry:
376 |             features = torch.cat([g_batch.crds_3d, features], dim=1)
377 |             for kernel in self.kernels:
378 |                 features = kernel(
379 |                     x=features,
380 |                     edge_index=g_batch.edge_index,
381 |                 )
382 |                 feature_list.append(features[:, self.pos_dim :])
383 |         else:
384 |             for kernel in self.kernels:
385 |                 features = kernel(x=features, edge_index=g_batch.edge_index)
386 |                 feature_list.append(features)
387 | 
388 |         features = torch.cat(feature_list, dim=1)
389 |         features = self.post_egnn_mlp(features)
390 | 
391 |         del feature_list
392 | 
393 |         features = scatter_sum(features, g_batch.batch, dim=0)
394 | 
395 |         conditions = torch.cat(
396 |             [
397 |                 self.lig_emb(g_batch.lgnd_id),
398 |                 self.sol_emb(g_batch.slvn_id),
399 |             ],
400 |             dim=1,
401 |         )
402 | 
403 |         features = torch.cat([features, conditions], dim=1)
404 |         features = self.post_pooling_mlp(features).squeeze(1)
405 | 
406 |         return features
407 | 
408 | 
409 | class FNN(nn.Module):
410 |     """Feed forward neural network (FNN) for yield and binary predictions."""
411 | 
412 |     def __init__(self, fp_dim=256, mlp_dim=512, kernel_dim=64, embeddings_dim=64):
413 |         """Initialization of FNN
414 | 
415 |         :param fp_dim: Input dimension of the ECFP descriptor, defaults to 256
416 |         :type fp_dim: int, optional
417 |         :param mlp_dim: Feature dimension within the multilayer perceptrons, defaults to 512
418 |         :type mlp_dim: int, optional
419 |         :param kernel_dim: Feature dimension of the fingerprint embedding MLP (the FNN has no message passing), defaults to 64
420 |         :type kernel_dim: int, optional
421 |         :param embeddings_dim: Embedding dimension of the input features (e.g.
reaction conditions), defaults to 64 422 | :type embeddings_dim: int, optional 423 | """ 424 | super(FNN, self).__init__() 425 | 426 | self.embeddings_dim = embeddings_dim 427 | self.kernel_dim = kernel_dim 428 | self.mlp_dim = mlp_dim 429 | self.fp_dim = fp_dim 430 | 431 | dropout = 0.1 432 | self.dropout = nn.Dropout(dropout) 433 | 434 | self.pre_egnn_mlp = nn.Sequential( 435 | nn.Linear(self.fp_dim, self.kernel_dim * 2), 436 | self.dropout, 437 | nn.SiLU(), 438 | nn.Linear(self.kernel_dim * 2, self.kernel_dim * 2), 439 | nn.SiLU(), 440 | nn.Linear(self.kernel_dim * 2, self.kernel_dim), 441 | ) 442 | 443 | self.lig_emb = nn.Embedding(num_embeddings=6, embedding_dim=self.embeddings_dim) 444 | self.sol_emb = nn.Embedding(num_embeddings=4, embedding_dim=self.embeddings_dim) 445 | 446 | self.post_pooling_mlp_input_dim = self.kernel_dim + self.embeddings_dim * 2 447 | 448 | self.post_pooling_mlp = nn.Sequential( 449 | nn.Linear(self.post_pooling_mlp_input_dim, self.mlp_dim), 450 | self.dropout, 451 | nn.SiLU(), 452 | nn.Linear(self.mlp_dim, self.mlp_dim), 453 | nn.SiLU(), 454 | nn.Linear(self.mlp_dim, self.mlp_dim), 455 | nn.SiLU(), 456 | nn.Linear(self.mlp_dim, 1), 457 | ) 458 | 459 | self.post_pooling_mlp.apply(weights_init) 460 | nn.init.xavier_uniform_(self.lig_emb.weight) 461 | nn.init.xavier_uniform_(self.sol_emb.weight) 462 | 463 | def forward(self, g_batch): 464 | """Forward pass of the FNN. 465 | 466 | :param g_batch: Input graph. 467 | :type g_batch: class 468 | :return: Prediction. 469 | :rtype: Tensor 470 | """ 471 | features = self.pre_egnn_mlp(g_batch.ecfp_fp) 472 | 473 | conditions = torch.cat( 474 | [ 475 | self.lig_emb(g_batch.lgnd_id), 476 | self.sol_emb(g_batch.slvn_id), 477 | ], 478 | dim=1, 479 | ) 480 | 481 | features = torch.cat([features, conditions], dim=1) 482 | features = self.post_pooling_mlp(features).squeeze(1) 483 | 484 | return features 485 | -------------------------------------------------------------------------------- /lsfml/experimental/net_utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import random 7 | 8 | import h5py 9 | import numpy as np 10 | import torch 11 | from torch_geometric.data import Data 12 | from lsfml.modules.pygdataset import Dataset 13 | 14 | 15 | random.seed(2) 16 | 17 | 18 | def get_rxn_ids( 19 | data, 20 | split, 21 | eln, 22 | testset, 23 | ): 24 | """Generates the data set split into training, validation and test sets. 
25 | 26 | :param data: Path to h5 file, including preprocessed data, defaults to "../data/experimental_rxndata.h5" 27 | :type data: str, optional 28 | :param split: Type of split (eln or random), defaults to "random" 29 | :type split: str, optional 30 | :param eln: Substrate number for substrate-based split, defaults to "ELN036496-147" 31 | :type eln: str, optional 32 | :param testset: Type of testset, defaults to "1" 33 | :type testset: str, optional 34 | :return: Reaction IDs for training, validation and test split 35 | :rtype: list[str] 36 | """ 37 | 38 | # Load data from h5 file 39 | h5f = h5py.File(data) 40 | 41 | # Load all rxn keys 42 | rxn_ids = list(h5f.keys()) 43 | random.shuffle(rxn_ids) 44 | 45 | # Define subset of rxn keys 46 | if split == "random": 47 | if testset == "1": 48 | tran_ids = rxn_ids[: int(len(rxn_ids) / 2)] 49 | eval_ids = rxn_ids[int(len(rxn_ids) / 4) * 3 :] 50 | test_ids = rxn_ids[int(len(rxn_ids) / 2) : int(len(rxn_ids) / 4) * 3] 51 | if testset == "2": 52 | tran_ids = rxn_ids[: int(len(rxn_ids) / 2)] 53 | eval_ids = rxn_ids[int(len(rxn_ids) / 2) : int(len(rxn_ids) / 4) * 3] 54 | test_ids = rxn_ids[int(len(rxn_ids) / 4) * 3 :] 55 | if testset == "3": 56 | tran_ids = rxn_ids[int(len(rxn_ids) / 2) :] 57 | eval_ids = rxn_ids[: int(len(rxn_ids) / 4)] 58 | test_ids = rxn_ids[int(len(rxn_ids) / 4) : int(len(rxn_ids) / 2)] 59 | if testset == "4": 60 | tran_ids = rxn_ids[int(len(rxn_ids) / 2) :] 61 | eval_ids = rxn_ids[int(len(rxn_ids) / 4) : int(len(rxn_ids) / 2)] 62 | test_ids = rxn_ids[: int(len(rxn_ids) / 4)] 63 | elif split == "eln": 64 | rxn_ids_train = [x for x in rxn_ids if eln not in x] 65 | tran_ids = rxn_ids_train[int(len(rxn_ids_train) / 3) :] 66 | eval_ids = rxn_ids_train[: int(len(rxn_ids_train) / 3)] 67 | test_ids = [x for x in rxn_ids if eln in x] 68 | 69 | return tran_ids, eval_ids, test_ids 70 | 71 | 72 | class DataLSF(Dataset): 73 | """Generates the desired graph objects (2D, 3D, QM) from reading the h5 files.""" 74 | 75 | def __init__( 76 | self, 77 | rxn_ids, 78 | data, 79 | data_substrates, 80 | target, 81 | graph_dim, 82 | fingerprint, 83 | conformers, 84 | ): 85 | """Initialization. 
86 | 
87 |         :param rxn_ids: Reaction IDs from the given split (train, eval, test)
88 |         :type rxn_ids: list[str]
89 |         :param data: Path to h5 file, including preprocessed reaction data, defaults to "../data/experimental_rxndata.h5"
90 |         :type data: str, optional
91 |         :param data_substrates: Path to h5 file, including preprocessed substrate data, defaults to "../data/experimental_substrates.h5"
92 |         :type data_substrates: str, optional
93 |         :param target: Target type (binary or mono), defaults to "binary"
94 |         :type target: str, optional
95 |         :param graph_dim: Indicating 2D or 3D graph structure ("edge_2d" or "edge_3d"), defaults to "edge_2d"
96 |         :type graph_dim: str, optional
97 |         :param fingerprint: Indicating fingerprint type (ecfp4_2 or None), defaults to "ecfp4_2"
98 |         :type fingerprint: str, optional
99 |         :param conformers: List of conformer keys, defaults to ["a", "b", "c", "d", "e"]
100 |         :type conformers: list[str], optional
101 |         """
102 | 
103 |         # Define inputs
104 |         self.target = target
105 |         self.graph_dim = graph_dim
106 |         self.fingerprint = fingerprint
107 |         self.conformers = conformers
108 |         self.rxn_ids = rxn_ids
109 | 
110 |         # Load data from h5 file
111 |         self.h5f = h5py.File(data)
112 |         self.h5f_subs = h5py.File(data_substrates)
113 | 
114 |         # Generate dict (int to rxn keys)
115 |         nums = list(range(0, len(self.rxn_ids)))
116 |         self.idx2rxn = {}
117 |         for x in range(len(self.rxn_ids)):
118 |             self.idx2rxn[nums[x]] = self.rxn_ids[x]
119 | 
120 |         print("\nLoader initialized:")
121 |         print(f"Number of reactions loaded: {len(self.rxn_ids)}")
122 |         print(f"Chosen target (binary or mono): {self.target}")
123 |         print(f"Chosen graph_dim (edge_2d or edge_3d): {self.graph_dim}")
124 |         print(f"Chosen fingerprint (ecfp4_2 or None): {self.fingerprint}")
125 | 
126 |     def __getitem__(self, idx):
127 |         """Returns one reaction as a graph data object.
128 | 
129 |         :param idx: Index of the reaction within the loaded split
130 |         :type idx: int
131 |         :return: Input graph for the neural network.
132 |         :rtype: torch_geometric.data.Data
133 |         """
134 | 
135 |         # int to rxn_id
136 |         rxn_id = self.idx2rxn[idx]
137 | 
138 |         sbst_rxn = rxn_id.split("_")[0]
139 |         sbst_rxn = sbst_rxn.split("-")[-1]
140 | 
141 |         # Pick random conformer
142 |         conformer = random.choice(self.conformers)
143 | 
144 |         # Molecule
145 |         atom_id = np.array(self.h5f_subs[sbst_rxn][f"atom_id_{conformer}"])
146 |         ring_id = np.array(self.h5f_subs[sbst_rxn][f"ring_id_{conformer}"])
147 |         hybr_id = np.array(self.h5f_subs[sbst_rxn][f"hybr_id_{conformer}"])
148 |         arom_id = np.array(self.h5f_subs[sbst_rxn][f"arom_id_{conformer}"])
149 |         charges = np.array(self.h5f_subs[sbst_rxn][f"charges_{conformer}"])
150 |         crds_3d = np.array(self.h5f_subs[sbst_rxn][f"crds_3d_{conformer}"])
151 | 
152 |         if self.fingerprint is not None:
153 |             ecfp_fp = np.array(self.h5f[rxn_id][self.fingerprint])
154 |         else:
155 |             ecfp_fp = np.array([])
156 | 
157 |         # Edge IDs with desired dimension
158 |         edge_index = np.array(self.h5f_subs[sbst_rxn][f"{self.graph_dim}_{conformer}"])
159 | 
160 |         # Conditions
161 |         lgnd_id = np.array(self.h5f[rxn_id]["lgnd_id"])
162 |         slvn_id = np.array(self.h5f[rxn_id]["slvn_id"])
163 | 
164 |         # Targets
165 |         if self.target == "binary":
166 |             rxn_trg = np.array(self.h5f[rxn_id]["mono_id"])
167 |         elif self.target == "mono":
168 |             mo_frct = np.array(self.h5f[rxn_id]["mo_frct"])
169 |             di_frct = np.array(self.h5f[rxn_id]["di_frct"])
170 |             rxn_trg = np.array(float(mo_frct) + float(di_frct))
171 | 
172 |         num_nodes = torch.LongTensor(atom_id).size(0)
173 | 
174 |         graph_data = Data(
175 |             atom_id=torch.LongTensor(atom_id),
176 |             ring_id=torch.LongTensor(ring_id),
177 |             hybr_id=torch.LongTensor(hybr_id),
178 |             arom_id=torch.LongTensor(arom_id),
179 |             lgnd_id=torch.LongTensor(lgnd_id),
180 |             slvn_id=torch.LongTensor(slvn_id),
181 |             charges=torch.FloatTensor(charges),
182 |             ecfp_fp=torch.FloatTensor(ecfp_fp),
183 |             crds_3d=torch.FloatTensor(crds_3d),
184 |             rxn_trg=torch.FloatTensor(rxn_trg),
185 |             edge_index=torch.LongTensor(edge_index),
186 |             num_nodes=num_nodes,
187 |         )
188 | 
189 |         return graph_data
190 | 
191 |     def __len__(self):
192 |         """Return the number of reactions in the split.
193 | 
194 |         :return: length
195 |         :rtype: int
196 |         """
197 |         return len(self.rxn_ids)
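As a usage sketch (hedged; `train.py` wires this up via the config files and may differ in detail), the splits and the data set can be combined with a PyG `DataLoader`, using the h5 files produced by `preprocessh5.py`:

```python
from torch_geometric.loader import DataLoader
from lsfml.experimental.net_utils import DataLSF, get_rxn_ids

# Split reaction IDs (random split, test set 1).
tran_ids, eval_ids, test_ids = get_rxn_ids(
    data="../data/experimental_rxndata.h5",
    split="random",
    eln="ELN036496-146",
    testset="1",
)

# Data set yielding 3D graphs with partial charges and yield targets.
train_data = DataLSF(
    rxn_ids=tran_ids,
    data="../data/experimental_rxndata.h5",
    data_substrates="../data/experimental_substrates.h5",
    target="mono",
    graph_dim="edge_3d",
    fingerprint="ecfp4_2",
    conformers=["a", "b", "c", "d", "e"],
)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
```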
--------------------------------------------------------------------------------
/lsfml/experimental/preprocessh5.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import h5py, os
7 | import networkx as nx
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | from rdkit import Chem
12 | from rdkit.Chem import AllChem
13 | from torch_geometric.data import Data
14 | from torch_geometric.utils import add_self_loops
15 | from torch_geometric.utils.undirected import to_undirected
16 | from tqdm import tqdm
17 | from scipy.spatial.distance import pdist, squareform
18 | 
19 | from lsfml.qml.prod import get_model
20 | from lsfml.utils import (
21 |     get_dict_for_embedding,
22 |     get_fp_from_smi,
23 |     HYBRIDISATIONS,
24 |     AROMATOCITY,
25 |     IS_RING,
26 |     ATOMTYPES,
27 |     QML_ATOMTYPES,
28 |     UTILS_PATH,
29 | )
30 | 
31 | QMLMODEL = get_model(gpu=False)
32 | 
33 | 
34 | HYBRIDISATION_DICT = get_dict_for_embedding(HYBRIDISATIONS)
35 | AROMATOCITY_DICT = get_dict_for_embedding(AROMATOCITY)
36 | IS_RING_DICT = get_dict_for_embedding(IS_RING)
37 | ATOMTYPE_DICT = get_dict_for_embedding(ATOMTYPES)
38 | QML_ATOMTYPE_DICT = get_dict_for_embedding(QML_ATOMTYPES)
39 | 
40 | 
41 | def get_info_from_smi(smi, randomseed, radius):
42 |     """Main function for generating the 2D and 3D molecular graphs, including DFT-level partial charges, given a SMILES-string, a random seed for 3D conformer generation and a distance cutoff for the 3D graph edges.
43 | 
44 |     :param smi: SMILES-string
45 |     :type smi: str
46 |     :param randomseed: random seed (int) for the 3D conformer generation
47 |     :param radius: distance cutoff in Angstrom (float) for the edges of the 3D graph
48 |     :return: tuple including all graph-relevant numpy arrays
49 |     :rtype: tuple
50 |     """
51 |     # Get mol objects from smiles
52 |     mol_no_Hs = Chem.MolFromSmiles(smi)
53 |     mol = Chem.rdmolops.AddHs(mol_no_Hs)
54 | 
55 |     atomids = []
56 |     qml_atomids = []
57 |     is_ring = []
58 |     hyb = []
59 |     arom = []
60 |     crds_3d = []
61 | 
62 |     AllChem.EmbedMolecule(mol, randomSeed=randomseed)
63 |     AllChem.UFFOptimizeMolecule(mol)
64 | 
65 |     for idx, i in enumerate(mol.GetAtoms()):
66 |         atomids.append(ATOMTYPE_DICT[i.GetSymbol()])
67 |         qml_atomids.append(QML_ATOMTYPE_DICT[i.GetSymbol()])
68 |         is_ring.append(IS_RING_DICT[str(i.IsInRing())])
69 |         hyb.append(HYBRIDISATION_DICT[str(i.GetHybridization())])
70 |         arom.append(AROMATOCITY_DICT[str(i.GetIsAromatic())])
71 |         crds_3d.append(list(mol.GetConformer().GetAtomPosition(idx)))
72 | 
73 |     atomids = np.array(atomids)
74 |     qml_atomids = np.array(qml_atomids)
75 |     is_ring = np.array(is_ring)
76 |     hyb = np.array(hyb)
77 |     arom = np.array(arom)
78 |     crds_3d = np.array(crds_3d)
79 | 
80 |     # Edges for the covalent bonds of the 2D graph (both directions per bond)
81 |     edge_dir1 = []
82 |     edge_dir2 = []
83 |     for idx, bond in enumerate(mol.GetBonds()):
84 |         a2 = bond.GetEndAtomIdx()
85 |         a1 = bond.GetBeginAtomIdx()
86 |         edge_dir1.append(a1)
87 |         edge_dir1.append(a2)
88 |         edge_dir2.append(a2)
89 |         edge_dir2.append(a1)
90 | 
91 |     edge_2d = torch.from_numpy(np.array([edge_dir1, edge_dir2]))
92 | 
93 |     # Fully connected 3D graph for the qml partial charge prediction
94 |     qml_atomids = torch.LongTensor(qml_atomids)
95 |     xyzs = torch.FloatTensor(crds_3d)
96 |     edge_index = np.array(nx.complete_graph(qml_atomids.size(0)).edges())
97 |     edge_index = to_undirected(torch.from_numpy(edge_index).t().contiguous())
98 |     edge_index, _ = add_self_loops(edge_index, num_nodes=crds_3d.shape[0])
99 | 
100 |     qml_graph = Data(
101 |         atomids=qml_atomids,
102 |         coords=xyzs,
103 |         edge_index=edge_index,
104 |         num_nodes=qml_atomids.size(0),
105 |     )
106 | 
107 |     charges = QMLMODEL(qml_graph).unsqueeze(1).detach().numpy()
108 | 
109 |     # Get edges for 3d graph
110 |     distance_matrix = squareform(pdist(crds_3d))
111 |     np.fill_diagonal(distance_matrix, float("inf"))  # to remove self-loops
112 |     edge_3d = torch.from_numpy(np.vstack(np.where(distance_matrix <= radius)))  # connect all atom pairs within the cutoff radius
113 | 
114 |     return (
115 |         atomids,
116 |         is_ring,
117 |         hyb,
118 |         arom,
119 |         charges,
120 |         edge_2d,
121 |         edge_3d,
122 |         crds_3d,
123 |     )
124 | 
125 | 
126 | if __name__ == "__main__":
127 |     df = pd.read_csv(os.path.join(UTILS_PATH, "data/experimental_rxndata.csv"))
128 | 
129 |     # Rxn id
130 |     rxn_id = list(df["rxn_id"])
131 | 
132 |     # Substrate
133 |     educt = list(df["educt"])
134 | 
135 |     # Non-molecular conditions
136 |     rxn_scale_mol = list(df["rxn_scale_mol"])
137 |     rxn_temp_C = list(df["rxn_temp_C"])
138 |     rxn_time_h = list(df["rxn_time_h"])
139 |     rxn_atm = list(df["rxn_atm"])
140 |     rxn_c_moll = list(df["rxn_c_moll"])
141 | 
142 |     # Molecular conditions
143 |     catalyst = list(df["catalyst"])
144 |     catalyst_eq = list(df["catalyst_eq"])
145 |     ligand = list(df["ligand"])
146 |     ligand_eq = list(df["ligand_eq"])
147 |     reagent = list(df["reagent"])
148 |     reagent_eq = list(df["reagent_eq"])
149 |     solvent = list(df["solvent"])
150 |     solvent_ratio = list(df["solvent_ratio"])
151 | 
152 |     # Targets
153 |     yes_no = list(df["yes_no"])
154 |     mono_bo = list(df["mono_bo"])
155 |     di_bo = list(df["di_bo"])
156 |     non_bo = list(df["non_bo"])
157 | 
158 |     # Embedding of molecular conditions
159 |     rea_dict = get_dict_for_embedding(reagent)
160 |     lig_dict = get_dict_for_embedding(ligand)
161 |     cat_dict = get_dict_for_embedding(catalyst)
162 |     sol_dict = get_dict_for_embedding(solvent)
163 | 
164 |     print("Ligands in data set:", lig_dict)
165 |     print("Solvents in data set:", sol_dict)
166 | 
167 |     # Get molecule dict for short rxn ids
168 | 
169 |     unique_substrates = {}
170 | 
171 |     for idx, rxn_key in enumerate(rxn_id):
172 |         short_rxn_key = rxn_key.split("_")[0]
173 |         short_rxn_key = short_rxn_key.split("-")[-1]
174 | 
175 |         if short_rxn_key not in unique_substrates:
176 |             unique_substrates[short_rxn_key] = educt[idx]
177 | 
178 |         else:
179 |             pass
180 | 
181 |     print(f"Calculating properties for {len(unique_substrates)} unique substrates")
182 | 
183 |     h5_path = os.path.join(UTILS_PATH, "data/experimental_substrates.h5")
184 | 
185 |     with h5py.File(h5_path, "w") as lsf_container1:
186 |         for rxn_key in tqdm(unique_substrates):
187 |             (
188 |                 atom_id_a,
189 |                 ring_id_a,
190 |                 hybr_id_a,
191 |                 arom_id_a,
192 |                 charges_a,
193 |                 edge_2d_a,
194 |                 edge_3d_a,
195 |                 crds_3d_a,
196 |             ) = get_info_from_smi(unique_substrates[rxn_key], 0xF00A, 4)
197 | 
198 |             (
199 |                 atom_id_b,
200 |                 ring_id_b,
201 |                 hybr_id_b,
202 |                 arom_id_b,
203 |                 charges_b,
204 |                 edge_2d_b,
205 |                 edge_3d_b,
206 |                 crds_3d_b,
207 |             ) = get_info_from_smi(unique_substrates[rxn_key], 0xF00B, 4)
208 | 
209 |             (
210 |                 atom_id_c,
211 |                 ring_id_c,
212 |                 hybr_id_c,
213 |                 arom_id_c,
214 |                 charges_c,
215 |                 edge_2d_c,
216 |                 edge_3d_c,
217 |                 crds_3d_c,
218 |             ) = get_info_from_smi(unique_substrates[rxn_key], 0xF00C, 4)
219 | 
220 |             (
221 |                 atom_id_d,
222 |                 ring_id_d,
223 |                 hybr_id_d,
224 |                 arom_id_d,
225 |                 charges_d,
226 |                 edge_2d_d,
227 |                 edge_3d_d,
228 |                 crds_3d_d,
229 |             ) = get_info_from_smi(unique_substrates[rxn_key], 0xF00D, 4)
230 | 
231 |             (
232 |                 atom_id_e,
233 |                 ring_id_e,
234 |                 hybr_id_e,
235 |                 arom_id_e,
236 |                 charges_e,
237 |                 edge_2d_e,
238 |                 edge_3d_e,
239 |                 crds_3d_e,
240 |             ) = get_info_from_smi(unique_substrates[rxn_key], 0xF00E, 4)
241 | 
242 |             # Substrate ID
243 |             lsf_container1.create_group(rxn_key)
244 | 
245 |             # Molecule
246 |             lsf_container1[rxn_key].create_dataset("atom_id_a", data=atom_id_a)
247 |
lsf_container1[rxn_key].create_dataset("ring_id_a", data=ring_id_a) 248 | lsf_container1[rxn_key].create_dataset("hybr_id_a", data=hybr_id_a) 249 | lsf_container1[rxn_key].create_dataset("arom_id_a", data=arom_id_a) 250 | lsf_container1[rxn_key].create_dataset("charges_a", data=charges_a) 251 | lsf_container1[rxn_key].create_dataset("edge_2d_a", data=edge_2d_a) 252 | lsf_container1[rxn_key].create_dataset("edge_3d_a", data=edge_3d_a) 253 | lsf_container1[rxn_key].create_dataset("crds_3d_a", data=crds_3d_a) 254 | lsf_container1[rxn_key].create_dataset("atom_id_b", data=atom_id_b) 255 | lsf_container1[rxn_key].create_dataset("ring_id_b", data=ring_id_b) 256 | lsf_container1[rxn_key].create_dataset("hybr_id_b", data=hybr_id_b) 257 | lsf_container1[rxn_key].create_dataset("arom_id_b", data=arom_id_b) 258 | lsf_container1[rxn_key].create_dataset("charges_b", data=charges_b) 259 | lsf_container1[rxn_key].create_dataset("edge_2d_b", data=edge_2d_b) 260 | lsf_container1[rxn_key].create_dataset("edge_3d_b", data=edge_3d_b) 261 | lsf_container1[rxn_key].create_dataset("crds_3d_b", data=crds_3d_b) 262 | lsf_container1[rxn_key].create_dataset("atom_id_c", data=atom_id_c) 263 | lsf_container1[rxn_key].create_dataset("ring_id_c", data=ring_id_c) 264 | lsf_container1[rxn_key].create_dataset("hybr_id_c", data=hybr_id_c) 265 | lsf_container1[rxn_key].create_dataset("arom_id_c", data=arom_id_c) 266 | lsf_container1[rxn_key].create_dataset("charges_c", data=charges_c) 267 | lsf_container1[rxn_key].create_dataset("edge_2d_c", data=edge_2d_c) 268 | lsf_container1[rxn_key].create_dataset("edge_3d_c", data=edge_3d_c) 269 | lsf_container1[rxn_key].create_dataset("crds_3d_c", data=crds_3d_c) 270 | lsf_container1[rxn_key].create_dataset("atom_id_d", data=atom_id_d) 271 | lsf_container1[rxn_key].create_dataset("ring_id_d", data=ring_id_d) 272 | lsf_container1[rxn_key].create_dataset("hybr_id_d", data=hybr_id_d) 273 | lsf_container1[rxn_key].create_dataset("arom_id_d", data=arom_id_d) 274 | lsf_container1[rxn_key].create_dataset("charges_d", data=charges_d) 275 | lsf_container1[rxn_key].create_dataset("edge_2d_d", data=edge_2d_d) 276 | lsf_container1[rxn_key].create_dataset("edge_3d_d", data=edge_3d_d) 277 | lsf_container1[rxn_key].create_dataset("crds_3d_d", data=crds_3d_d) 278 | lsf_container1[rxn_key].create_dataset("atom_id_e", data=atom_id_e) 279 | lsf_container1[rxn_key].create_dataset("ring_id_e", data=ring_id_e) 280 | lsf_container1[rxn_key].create_dataset("hybr_id_e", data=hybr_id_e) 281 | lsf_container1[rxn_key].create_dataset("arom_id_e", data=arom_id_e) 282 | lsf_container1[rxn_key].create_dataset("charges_e", data=charges_e) 283 | lsf_container1[rxn_key].create_dataset("edge_2d_e", data=edge_2d_e) 284 | lsf_container1[rxn_key].create_dataset("edge_3d_e", data=edge_3d_e) 285 | lsf_container1[rxn_key].create_dataset("crds_3d_e", data=crds_3d_e) 286 | 287 | wins = 0 288 | loss = 0 289 | 290 | print(f"Transforming {len(rxn_id)} reactions into h5 format") 291 | 292 | with h5py.File("../data/experimental_rxndata.h5", "w") as lsf_container: 293 | for idx, rxn_key in enumerate(tqdm(rxn_id)): 294 | try: 295 | rgnt_id = rea_dict[reagent[idx]] 296 | lgnd_id = lig_dict[ligand[idx]] 297 | clst_id = cat_dict[catalyst[idx]] 298 | slvn_id = sol_dict[solvent[idx]] 299 | 300 | # Create group in h5 for this id 301 | lsf_container.create_group(rxn_key) 302 | 303 | # Add all parameters as datasets to the created group 304 | if ligand_eq[idx] == "none": 305 | lgnd_eq = 0.0 306 | else: 307 | lgnd_eq = ligand_eq[idx] 308 | 309 
| # Molecule 310 | ecfp4_2 = get_fp_from_smi(educt[idx]) 311 | lsf_container[rxn_key].create_dataset("ecfp4_2", data=[ecfp4_2]) 312 | 313 | # Conditions 314 | lsf_container[rxn_key].create_dataset("rgnt_id", data=[int(rgnt_id)]) 315 | lsf_container[rxn_key].create_dataset("lgnd_id", data=[int(lgnd_id)]) 316 | lsf_container[rxn_key].create_dataset("clst_id", data=[int(clst_id)]) 317 | lsf_container[rxn_key].create_dataset("slvn_id", data=[int(slvn_id)]) 318 | lsf_container[rxn_key].create_dataset("rgnt_eq", data=[float(reagent_eq[idx])]) 319 | lsf_container[rxn_key].create_dataset("lgnd_eq", data=[float(lgnd_eq)]) 320 | lsf_container[rxn_key].create_dataset("clst_eq", data=[float(catalyst_eq[idx])]) 321 | lsf_container[rxn_key].create_dataset("rxn_scl", data=[float(rxn_scale_mol[idx])]) 322 | lsf_container[rxn_key].create_dataset("rxn_con", data=[float(rxn_c_moll[idx])]) 323 | lsf_container[rxn_key].create_dataset("rxn_tmp", data=[float(rxn_temp_C[idx])]) 324 | lsf_container[rxn_key].create_dataset("rxn_tme", data=[float(rxn_time_h[idx])]) 325 | 326 | # Targets 327 | lsf_container[rxn_key].create_dataset("mono_id", data=[yes_no[idx]]) 328 | lsf_container[rxn_key].create_dataset("mo_frct", data=[[mono_bo[idx]]]) 329 | lsf_container[rxn_key].create_dataset("di_frct", data=[[di_bo[idx]]]) 330 | lsf_container[rxn_key].create_dataset("no_frct", data=[[non_bo[idx]]]) 331 | 332 | wins += 1 333 | 334 | except: 335 | loss += 1 336 | 337 | print(f"Reactions successfully transformed: {wins}; Reactions failed: {loss}") 338 | -------------------------------------------------------------------------------- /lsfml/experimental/property_analysis.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | from rdkit import Chem 7 | from rdkit.Chem import rdMolDescriptors, Draw 8 | 9 | import pandas as pd 10 | import numpy as np 11 | import os, xlsxwriter 12 | from tqdm import tqdm 13 | from io import BytesIO 14 | 15 | import matplotlib.pyplot as plt 16 | from matplotlib.gridspec import GridSpec 17 | 18 | from lsfml.utils import UTILS_PATH 19 | 20 | fontsize = 22 21 | 22 | 23 | def get_hist_property(wgt, rot, hba, hbd, psa, rng, sp3, ste, name, bins): 24 | """Get histogram of molecular properties. 25 | 26 | Args: 27 | wgt (list): Molecular weight. 28 | rot (list): Rotatable bonds. 29 | hba (list): Hydrogen bond acceptors. 30 | hbd (list): Hydrogen bond donors. 31 | psa (list): Polar surface area. 32 | rng (list): Rings. 33 | sp3 (list): Fraction sp3. 34 | ste (list): Stereogenic centers. 35 | name (str): File name. 36 | bins (int): Number of bins.
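
Example (illustrative; mirrors the call in __main__ below)::

    get_hist_property(wgt, rot, hba, hbd, psa, rng, sp3, ste, "property_all", bins=15)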
37 | """ 38 | fig = plt.figure(figsize=(40, 16)) 39 | gs = GridSpec(nrows=2, ncols=4) 40 | gs.update(wspace=0.4, hspace=0.2) 41 | 42 | ax = fig.add_subplot(111) 43 | ax.spines["top"].set_color("none") 44 | ax.spines["bottom"].set_color("none") 45 | ax.spines["left"].set_color("none") 46 | ax.spines["right"].set_color("none") 47 | ax.tick_params(labelcolor="w", top=False, bottom=False, left=False, right=False) 48 | ax.set_ylabel("Number of molecules", fontsize=fontsize + 8, labelpad=60) 49 | 50 | ax1 = fig.add_subplot(gs[0, 0]) 51 | ax1.hist(wgt, density=False, facecolor="royalblue", bins=bins) 52 | ax1.tick_params(axis="x", labelsize=fontsize) 53 | ax1.tick_params(axis="y", labelsize=fontsize) 54 | ax1.set_xlabel(str("Molecular weight / g/mol"), fontsize=fontsize + 4) 55 | 56 | ax2 = fig.add_subplot(gs[0, 1]) 57 | ax2.hist(rot, density=False, facecolor="royalblue", bins=bins) 58 | ax2.tick_params(axis="x", labelsize=fontsize) 59 | ax2.tick_params(axis="y", labelsize=fontsize) 60 | ax2.set_xlabel(str("Rotatable bonds / $N$"), fontsize=fontsize + 4) 61 | 62 | ax3 = fig.add_subplot(gs[0, 2]) 63 | ax3.hist(hba, density=False, facecolor="royalblue", bins=bins) 64 | ax3.tick_params(axis="x", labelsize=fontsize) 65 | ax3.tick_params(axis="y", labelsize=fontsize) 66 | ax3.set_xlabel(str("Hydrogen bond acceptors / $N$"), fontsize=fontsize + 4) 67 | 68 | ax4 = fig.add_subplot(gs[0, 3]) 69 | ax4.hist(hbd, density=False, facecolor="royalblue", bins=bins) 70 | ax4.tick_params(axis="x", labelsize=fontsize) 71 | ax4.tick_params(axis="y", labelsize=fontsize) 72 | ax4.set_xlabel(str("Hydrogen bond donors / $N$"), fontsize=fontsize + 4) 73 | 74 | ax5 = fig.add_subplot(gs[1, 0]) 75 | ax5.hist(psa, density=False, facecolor="royalblue", bins=bins) 76 | ax5.tick_params(axis="x", labelsize=fontsize) 77 | ax5.tick_params(axis="y", labelsize=fontsize) 78 | ax5.set_xlabel(str("Polar surface area / $A^2$"), fontsize=fontsize + 4) 79 | 80 | ax6 = fig.add_subplot(gs[1, 1]) 81 | ax6.hist(rng, density=False, facecolor="royalblue", bins=bins) 82 | ax6.tick_params(axis="x", labelsize=fontsize) 83 | ax6.tick_params(axis="y", labelsize=fontsize) 84 | ax6.set_xlabel(str("Rings / $N$"), fontsize=fontsize + 4) 85 | 86 | ax7 = fig.add_subplot(gs[1, 2]) 87 | ax7.hist(sp3, density=False, facecolor="royalblue", bins=bins) 88 | ax7.tick_params(axis="x", labelsize=fontsize) 89 | ax7.tick_params(axis="y", labelsize=fontsize) 90 | ax7.set_xlabel(str("Fraction sp3"), fontsize=fontsize + 4) 91 | 92 | ax8 = fig.add_subplot(gs[1, 3]) 93 | ax8.hist(ste, density=False, facecolor="royalblue", bins=bins) 94 | ax8.tick_params(axis="x", labelsize=fontsize) 95 | ax8.tick_params(axis="y", labelsize=fontsize) 96 | ax8.set_xlabel(str("Stereogenic centers / $N$"), fontsize=fontsize + 4) 97 | 98 | out_name = os.path.join("analysis/figures/", name + ".png") 99 | os.makedirs(os.path.dirname(out_name), exist_ok=True) 100 | plt.savefig(out_name, dpi=200) 101 | plt.clf() 102 | 103 | return 104 | 105 | 106 | def get_hist_temp_time(temps, times, scale, concs): 107 | """Get historgram of reaction conditions. 108 | 109 | Args: 110 | temps (list): Reaction temperature. 111 | times (list): Reaction time. 112 | scale (list): Reaction scale. 113 | concs (list): Reaction concentration. 
114 | """ 115 | fig = plt.figure(figsize=(40, 10)) 116 | gs = GridSpec(nrows=1, ncols=4) 117 | gs.update(wspace=0.4, hspace=0.2) 118 | 119 | ax = fig.add_subplot(111) 120 | ax.spines["top"].set_color("none") 121 | ax.spines["bottom"].set_color("none") 122 | ax.spines["left"].set_color("none") 123 | ax.spines["right"].set_color("none") 124 | ax.tick_params(labelcolor="w", top=False, bottom=False, left=False, right=False) 125 | ax.set_ylabel("Number of reactions", fontsize=fontsize + 8, labelpad=60) 126 | 127 | ax1 = fig.add_subplot(gs[0, 0]) 128 | ax1.hist(temps, density=False, facecolor="royalblue", bins=30) 129 | ax1.tick_params(axis="x", labelsize=fontsize) 130 | ax1.tick_params(axis="y", labelsize=fontsize) 131 | ax1.set_xlabel(str("Temperarure / Celsius"), fontsize=fontsize + 4) 132 | 133 | ax2 = fig.add_subplot(gs[0, 1]) 134 | ax2.hist(times, density=False, facecolor="royalblue", bins=30) 135 | ax2.tick_params(axis="x", labelsize=fontsize) 136 | ax2.tick_params(axis="y", labelsize=fontsize) 137 | ax2.set_xlabel(str("Time / hour"), fontsize=fontsize + 4) 138 | 139 | scale = [x * 1000 for x in scale] 140 | ax3 = fig.add_subplot(gs[0, 2]) 141 | ax3.hist(scale, density=False, facecolor="royalblue", bins=30) 142 | ax3.tick_params(axis="x", labelsize=fontsize) 143 | ax3.tick_params(axis="y", labelsize=fontsize) 144 | ax3.set_xlabel(str("Scale / mol"), fontsize=fontsize + 4) 145 | 146 | ax4 = fig.add_subplot(gs[0, 3]) 147 | ax4.hist(concs, density=False, facecolor="royalblue", bins=30) 148 | ax4.tick_params(axis="x", labelsize=fontsize) 149 | ax4.tick_params(axis="y", labelsize=fontsize) 150 | ax4.set_xlabel(str("Concentration / mol/L"), fontsize=fontsize + 4) 151 | 152 | out_name = os.path.join("analysis/figures/histogram_non_molecular_conditions.png") 153 | os.makedirs(os.path.dirname(out_name), exist_ok=True) 154 | plt.savefig(out_name, dpi=300) 155 | plt.clf() 156 | 157 | return 158 | 159 | 160 | def get_propertiest(smiles): 161 | """Calculating molecular properties from a list of SMILES. 162 | 163 | Args: 164 | smiles (list): SMILES strings 165 | 166 | Returns: 167 | lists: Molecular properties. 168 | """ 169 | wgt = [] 170 | rot = [] 171 | hba = [] 172 | hbd = [] 173 | psa = [] 174 | rng = [] 175 | sp3 = [] 176 | ste = [] 177 | 178 | for i, smi in enumerate(tqdm(smiles)): 179 | try: 180 | mol = Chem.MolFromSmiles(smi) 181 | wgt.append(rdMolDescriptors.CalcExactMolWt(mol)) 182 | rot.append(rdMolDescriptors.CalcNumRotatableBonds(mol)) 183 | hba.append(rdMolDescriptors.CalcNumHBA(mol)) 184 | hbd.append(rdMolDescriptors.CalcNumHBD(mol)) 185 | psa.append(rdMolDescriptors.CalcTPSA(mol)) 186 | rng.append(rdMolDescriptors.CalcNumRings(mol)) 187 | sp3.append(rdMolDescriptors.CalcFractionCSP3(mol)) 188 | ste.append(rdMolDescriptors.CalcNumAtomStereoCenters(mol)) 189 | except: 190 | pass 191 | 192 | return wgt, rot, hba, hbd, psa, rng, sp3, ste 193 | 194 | 195 | def get_csv_summary(df, key, with_mol_img=False): 196 | """Save csv from data frame. 197 | 198 | Args: 199 | df (pandas data frame): Data frame consisting of SMILES strings and their IDs. 200 | key (str): key to access SMILES strings. 201 | with_mol_img (bool, optional): Save mol img to xlsx file. Defaults to False. 
202 | """ 203 | smiles_set = list(set(df[key])) 204 | smiles_list = list(df[key]) 205 | 206 | summary_dict = {} 207 | 208 | for smi in smiles_set: 209 | summary_dict[smi] = [] 210 | 211 | yes_no = list(df["yes_no"]) 212 | 213 | for idx, x in enumerate(smiles_list): 214 | summary_dict[x].append(yes_no[idx]) 215 | 216 | smls = [] 217 | wins = [] 218 | loss = [] 219 | totl = [] 220 | 221 | for x in summary_dict: 222 | smls.append(x) 223 | wins.append(summary_dict[x].count(1)) 224 | loss.append(summary_dict[x].count(0)) 225 | totl.append(summary_dict[x].count(0) + summary_dict[x].count(1)) 226 | 227 | df_tmp = pd.DataFrame( 228 | { 229 | "smiles": smls, 230 | "num_reactions_worked": wins, 231 | "num_reactions_failed": loss, 232 | "num_reactions_total": totl, 233 | } 234 | ) 235 | 236 | df_tmp.sort_values(by="num_reactions_worked", ascending=False, inplace=True, ignore_index=True) 237 | 238 | out_name = "analysis/summary/summary_" + str(key) + ".xlsx" 239 | os.makedirs(os.path.dirname(out_name), exist_ok=True) 240 | df_tmp.to_csv( 241 | "analysis/summary/summary_" + str(key) + ".csv", 242 | index=False, 243 | ) 244 | 245 | if with_mol_img: 246 | SaveXlsxFromFrame( 247 | df_tmp, 248 | out_name, 249 | molCols=[ 250 | "smiles", 251 | ], 252 | size=(300, 300), 253 | ) 254 | 255 | return 256 | 257 | 258 | def SaveXlsxFromFrame(frame, outFile, molCols=["ROMol"], size=(300, 300)): 259 | """Generating xlsx file with drawings. 260 | 261 | Args: 262 | frame (pandas data frame): Data frame consisting of SMILES strings and their IDs and other properties. 263 | outFile (str): Name of the files saved. 264 | molCols (list, optional): Columns from which SMILES are saved as drwings. Defaults to ["ROMol"]. 265 | size (tuple, optional): Size of the drawings. Defaults to (300, 300). 266 | """ 267 | cols = list(frame.columns) 268 | 269 | dataTypes = dict(frame.dtypes) 270 | 271 | workbook = xlsxwriter.Workbook(outFile) # New workbook 272 | worksheet = workbook.add_worksheet() # New work sheet 273 | worksheet.set_column("A:A", size[0] / 6.0) # column width 274 | 275 | # Write first row with column names 276 | c2 = 0 277 | molCol_names = [f"{x}_img" for x in molCols] 278 | for x in molCol_names + cols: 279 | worksheet.write_string(0, c2, x) 280 | c2 += 1 281 | 282 | c = 1 283 | for _, row in tqdm(frame.iterrows(), total=len(frame)): 284 | for k, molCol in enumerate(molCols): 285 | image_data = BytesIO() 286 | 287 | # none can not be visualized as molecule 288 | if row[molCol] == "none": 289 | pass 290 | else: 291 | img = Draw.MolToImage(Chem.MolFromSmiles(row[molCol]), size=size) 292 | img.save(image_data, format="PNG") 293 | worksheet.set_row(c, height=size[1]) # looks like height is not in px? 294 | worksheet.insert_image(c, k, "f", {"image_data": image_data}) 295 | 296 | c2 = len(molCols) 297 | for x in cols: 298 | if str(dataTypes[x]) == "object": 299 | # string length is limited in xlsx 300 | worksheet.write_string(c, c2, str(row[x])[:32000]) 301 | elif ("float" in str(dataTypes[x])) or ("int" in str(dataTypes[x])): 302 | if (row[x] != np.nan) or (row[x] != np.inf): 303 | worksheet.write_number(c, c2, row[x]) 304 | elif "datetime" in str(dataTypes[x]): 305 | worksheet.write_datetime(c, c2, row[x]) 306 | c2 += 1 307 | c += 1 308 | 309 | workbook.close() 310 | image_data.close() 311 | 312 | 313 | def get_hist_equiv(ctls_eq, lgnd_eq, rgnt_eq): 314 | """Plot hoistogram. 315 | 316 | Args: 317 | ctls_eq (str): Equivalents of Catalyst. 318 | lgnd_eq (str): Equivalents of Ligand. 
319 | rgnt_eq (list): Equivalents of reagent. 320 | """ 321 | fig = plt.figure(figsize=(30, 14)) 322 | gs = GridSpec(nrows=1, ncols=3) 323 | gs.update(wspace=0.4, hspace=0.2) 324 | 325 | ax = fig.add_subplot(111) 326 | ax.spines["top"].set_color("none") 327 | ax.spines["bottom"].set_color("none") 328 | ax.spines["left"].set_color("none") 329 | ax.spines["right"].set_color("none") 330 | ax.tick_params(labelcolor="w", top=False, bottom=False, left=False, right=False) 331 | ax.set_xlabel("Equivalents / %", fontsize=fontsize + 8, labelpad=60) 332 | ax.set_ylabel("Number of reactions", fontsize=fontsize + 8, labelpad=60) 333 | 334 | ctls_eq = [x * 100 for x in ctls_eq] 335 | ax1 = fig.add_subplot(gs[0, 0]) 336 | ax1.hist(ctls_eq, density=False, facecolor="royalblue", bins=30) 337 | ax1.tick_params(axis="x", labelsize=fontsize) 338 | ax1.tick_params(axis="y", labelsize=fontsize) 339 | ax1.set_xlabel(str("Catalyst"), fontsize=fontsize + 4) 340 | 341 | lgnd_eq = [x * 100 for x in lgnd_eq] 342 | ax2 = fig.add_subplot(gs[0, 1]) 343 | ax2.hist(lgnd_eq, density=False, facecolor="royalblue", bins=30) 344 | ax2.tick_params(axis="x", labelsize=fontsize) 345 | ax2.tick_params(axis="y", labelsize=fontsize) 346 | ax2.set_xlabel(str("Ligand"), fontsize=fontsize + 4) 347 | 348 | rgnt_eq = [x * 100 for x in rgnt_eq] 349 | ax3 = fig.add_subplot(gs[0, 2]) 350 | ax3.hist(rgnt_eq, density=False, facecolor="royalblue", bins=30) 351 | ax3.tick_params(axis="x", labelsize=fontsize) 352 | ax3.tick_params(axis="y", labelsize=fontsize) 353 | ax3.set_xlabel(str("Reagent"), fontsize=fontsize + 4) 354 | 355 | out_name = os.path.join("analysis/figures/histogram_equivalents.png") 356 | os.makedirs(os.path.dirname(out_name), exist_ok=True) 357 | plt.savefig(out_name, dpi=300) 358 | plt.clf() 359 | 360 | return 361 | 362 | 363 | def yield_hist(yields, name): 364 | """Plot histogram of reaction yields. 365 | 366 | Args: 367 | yields (list): List of reaction yields. 368 | name (str): Name of output file.
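
Example (illustrative; mirrors the call in __main__ below)::

    yield_hist(yields, "analysis/figures/hist_rxn_yield.png")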
369 | 370 | Returns: 371 | None; the figure is saved under the given file name. 372 | """ 373 | plt.figure(figsize=(8, 8)) 374 | yields = [float(x) * 100 for x in yields] 375 | plt.hist(yields, density=False, color="royalblue", bins=20) 376 | plt.ylabel("Occurrence", fontsize=18) 377 | plt.xlabel("Reaction yield / %", fontsize=18) 378 | plt.tick_params(axis="x", labelsize=16) 379 | plt.tick_params(axis="y", labelsize=16) 380 | plt.savefig(name, dpi=600) 381 | plt.show() 382 | plt.clf() 383 | 384 | return 385 | 386 | 387 | if __name__ == "__main__": 388 | # read csv 389 | df = pd.read_csv(os.path.join(UTILS_PATH, "data/experimental_rxndata.csv")) 390 | smiles = list(set(df["educt"])) 391 | print(f"Number of different educts: {len(smiles)}") 392 | 393 | # gets all properties from smiles list 394 | print("Calculating property distribution and generating histogram:") 395 | wgt, rot, hba, hbd, psa, rng, sp3, ste = get_properties(smiles) 396 | 397 | # plot histogram of properties from all smiles 398 | get_hist_property(wgt, rot, hba, hbd, psa, rng, sp3, ste, "property_all", bins=15) 399 | 400 | # get csv summaries 401 | print("Generating csv and xls files:") 402 | get_csv_summary(df, "educt", with_mol_img=True) 403 | get_csv_summary(df, "catalyst", with_mol_img=False) 404 | get_csv_summary(df, "reagent", with_mol_img=True) 405 | get_csv_summary(df, "solvent", with_mol_img=True) 406 | get_csv_summary(df, "ligand", with_mol_img=True) 407 | get_csv_summary(df, "rxn_atm", with_mol_img=False) 408 | 409 | # get histogram of time/temp 410 | print("Generating rxn histogram:") 411 | temps = list(df["rxn_temp_C"]) 412 | times = list(df["rxn_time_h"]) 413 | scale = list(df["rxn_scale_mol"]) 414 | concs = list(df["rxn_c_moll"]) 415 | get_hist_temp_time(temps, times, scale, concs) 416 | 417 | # get histogram of equivalents 418 | ctls_eq = list(df["catalyst_eq"]) 419 | lgnd_eq = list(df["ligand_eq"]) 420 | rgnt_eq = list(df["reagent_eq"]) 421 | get_hist_equiv(ctls_eq, lgnd_eq, rgnt_eq) 422 | 423 | # yield hist 424 | yields = [1 - x for x in list(df["non_bo"])] 425 | yield_hist(yields, "analysis/figures/hist_rxn_yield.png") 426 | 427 | yields = [x for x in yields if x > 0] 428 | yield_hist(yields, "analysis/figures/hist_rxn_yield_pos.png") 429 | 430 | print("All Done!") 431 | -------------------------------------------------------------------------------- /lsfml/experimental/train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import argparse 7 | import configparser 8 | import os 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch_geometric.loader import DataLoader 15 | 16 | from lsfml.experimental.net import EGNN, FNN, GraphTransformer 17 | from lsfml.experimental.net_utils import DataLSF, get_rxn_ids 18 | from lsfml.utils import mae_loss, UTILS_PATH 19 | 20 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | SUB_DATA = os.path.join(UTILS_PATH, "data/experimental_substrates.h5") 22 | RXN_DATA = os.path.join(UTILS_PATH, "data/experimental_rxndata.h5") 23 | 24 | 25 | def train( 26 | model, 27 | optimizer, 28 | criterion, 29 | train_loader, 30 | ): 31 | """Train loop.
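Runs one epoch over the training loader; the passed criterion (MSE) drives backpropagation, while the epoch-mean MAE is returned for monitoring.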
32 | 33 | :param model: Model 34 | :type model: class 35 | :param optimizer: Optimizer 36 | :type optimizer: class 37 | :param criterion: Loss 38 | :type criterion: class 39 | :param train_loader: Data loader 40 | :type train_loader: torch_geometric.loader.dataloader.DataLoader 41 | :return: Mean MAE loss over the training set 42 | :rtype: numpy.float64 43 | """ 44 | model.train() 45 | training_loss = [] 46 | 47 | for g in train_loader: 48 | g = g.to(DEVICE) 49 | optimizer.zero_grad() 50 | 51 | pred = model(g) 52 | 53 | loss = criterion(pred, g.rxn_trg) 54 | loss.backward() 55 | optimizer.step() 56 | 57 | with torch.no_grad(): 58 | mae = mae_loss(pred, g.rxn_trg) 59 | training_loss.append(mae) 60 | 61 | return np.mean(training_loss) 62 | 63 | 64 | def eval( 65 | model, 66 | eval_loader, 67 | ): 68 | """Validation & test loop. 69 | 70 | :param model: Model 71 | :type model: class 72 | :param eval_loader: Data loader 73 | :type eval_loader: torch_geometric.loader.dataloader.DataLoader 74 | :return: tuple including essential information to quantify network performance, such as MAE, predictions, labels, etc. 75 | :rtype: tuple 76 | """ 77 | 78 | model.eval() 79 | eval_loss = [] 80 | 81 | preds = [] 82 | ys = [] 83 | 84 | with torch.no_grad(): 85 | for g in eval_loader: 86 | g = g.to(DEVICE) 87 | pred = model(g) 88 | mae = mae_loss(pred, g.rxn_trg) 89 | eval_loss.append(mae) 90 | ys.append(g.rxn_trg) 91 | preds.append(pred) 92 | 93 | return np.mean(eval_loss), ys, preds 94 | 95 | 96 | if __name__ == "__main__": 97 | # python train.py -config 420 -mode a -cv 1 -testset 1 -early_stop 0 98 | 99 | # Make Folders for Results and Models 100 | os.makedirs("results/", exist_ok=True) 101 | os.makedirs("models/", exist_ok=True) 102 | 103 | # Read Passed Arguments 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("-config", type=str, default="100") 106 | parser.add_argument("-mode", type=str, default="a") 107 | parser.add_argument("-cv", type=str, default="1") 108 | parser.add_argument("-testset", type=str, default="1") 109 | parser.add_argument("-early_stop", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | # Define Configuration for Model and Dataset 113 | config = configparser.ConfigParser() 114 | CONFIG_PATH = os.path.join(UTILS_PATH, f"config/config_{str(args.config)}.ini") 115 | config.read(CONFIG_PATH) 116 | print({section: dict(config[section]) for section in config.sections()}) 117 | early_stop = True if args.early_stop >= 1 else False 118 | 119 | LR_FACTOR = float(config["PARAMS"]["LR_FACTOR"]) 120 | LR_STEP_SIZE = int(config["PARAMS"]["LR_STEP_SIZE"]) 121 | N_KERNELS = int(config["PARAMS"]["N_KERNELS"]) 122 | POOLING_HEADS = int(config["PARAMS"]["POOLING_HEADS"]) 123 | D_MLP = int(config["PARAMS"]["D_MLP"]) 124 | D_KERNEL = int(config["PARAMS"]["D_KERNEL"]) 125 | D_EMBEDDING = int(config["PARAMS"]["D_EMBEDDING"]) 126 | BATCH_SIZE = int(config["PARAMS"]["BATCH_SIZE"]) 127 | SPLIT = str(config["PARAMS"]["SPLIT"]) 128 | ELN = str(config["PARAMS"]["ELN"]) 129 | TARGET = str(config["PARAMS"]["TARGET"]) 130 | QML = int(config["PARAMS"]["QML"]) 131 | GEOMETRY = int(config["PARAMS"]["GEOMETRY"]) 132 | FINGERPRINT = str(config["PARAMS"]["FINGERPRINT"]) 133 | QML = True if QML >= 1 else False 134 | GEOMETRY = True if GEOMETRY >= 1 else False 135 | GRAPH_DIM = "edge_3d" if GEOMETRY >= 1 else "edge_2d" 136 | FINGERPRINT = FINGERPRINT if args.mode == "c" else None 137 | FP_DIM = 1024 if FINGERPRINT == "ecfp6_1" else 256 138 | CONFORMER_LIST = ["a", "b", "c", "d", "e"] 139 | 140 | # Initialize Model
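# A brief orientation on the three modes (sketch; see lsfml/experimental/net.py for the actual definitions):
#   -mode a: GraphTransformer, a graph neural network with POOLING_HEADS attention-based pooling heads
#   -mode b: EGNN, an equivariant graph neural network
#   -mode c: FNN, a fingerprint-based baseline (the only mode that uses FINGERPRINT / FP_DIM)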
141 | if args.mode == "a": 142 | model = GraphTransformer( 143 | n_kernels=N_KERNELS, 144 | pooling_heads=POOLING_HEADS, 145 | mlp_dim=D_MLP, 146 | kernel_dim=D_KERNEL, 147 | embeddings_dim=D_EMBEDDING, 148 | qml=QML, 149 | geometry=GEOMETRY, 150 | ) 151 | elif args.mode == "b": 152 | model = EGNN( 153 | n_kernels=N_KERNELS, 154 | mlp_dim=D_MLP, 155 | kernel_dim=D_KERNEL, 156 | embeddings_dim=D_EMBEDDING, 157 | qml=QML, 158 | geometry=GEOMETRY, 159 | ) 160 | elif args.mode == "c": 161 | model = FNN( 162 | fp_dim=FP_DIM, 163 | mlp_dim=D_MLP, 164 | kernel_dim=D_KERNEL, 165 | embeddings_dim=D_EMBEDDING, 166 | ) 167 | 168 | model = model.to(DEVICE) 169 | 170 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 171 | model_parameters = sum([np.prod(e.size()) for e in model_parameters]) 172 | print("\nmodel_parameters", model_parameters) 173 | 174 | optimizer = torch.optim.Adam( 175 | model.parameters(), 176 | lr=LR_FACTOR, 177 | weight_decay=1e-10, 178 | ) 179 | criterion = nn.MSELoss() 180 | scheduler = torch.optim.lr_scheduler.StepLR( 181 | optimizer, 182 | step_size=LR_STEP_SIZE, 183 | gamma=0.5, 184 | verbose=False, 185 | ) 186 | 187 | # Neural Network Training 188 | tr_losses = [] 189 | ev_losses = [] 190 | 191 | if early_stop: 192 | # Get Datasets 193 | tran_ids, eval_ids, test_ids = get_rxn_ids( 194 | data=RXN_DATA, 195 | split=SPLIT, 196 | eln=ELN, 197 | testset=args.testset, 198 | ) 199 | train_data = DataLSF( 200 | rxn_ids=tran_ids, 201 | data=RXN_DATA, 202 | data_substrates=SUB_DATA, 203 | target=TARGET, 204 | graph_dim=GRAPH_DIM, 205 | fingerprint=FINGERPRINT, 206 | conformers=CONFORMER_LIST, 207 | ) 208 | train_loader = DataLoader( 209 | train_data, 210 | batch_size=BATCH_SIZE, 211 | shuffle=True, 212 | num_workers=2, 213 | ) 214 | eval_data = DataLSF( 215 | rxn_ids=eval_ids, 216 | data=RXN_DATA, 217 | data_substrates=SUB_DATA, 218 | target=TARGET, 219 | graph_dim=GRAPH_DIM, 220 | fingerprint=FINGERPRINT, 221 | conformers=CONFORMER_LIST, 222 | ) 223 | eval_loader = DataLoader( 224 | eval_data, 225 | batch_size=BATCH_SIZE, 226 | shuffle=True, 227 | num_workers=2, 228 | ) 229 | test_data = DataLSF( 230 | rxn_ids=test_ids, 231 | data=RXN_DATA, 232 | data_substrates=SUB_DATA, 233 | target=TARGET, 234 | graph_dim=GRAPH_DIM, 235 | fingerprint=FINGERPRINT, 236 | conformers=CONFORMER_LIST, 237 | ) 238 | test_loader = DataLoader( 239 | test_data, 240 | batch_size=BATCH_SIZE, 241 | shuffle=True, 242 | num_workers=2, 243 | ) 244 | 245 | # Training with Early Stopping 246 | min_mae = 1000 247 | 248 | for epoch in range(1000): 249 | # Training and Eval Loops 250 | tr_l = train(model, optimizer, criterion, train_loader) 251 | ev_l, ev_ys, ev_pred = eval(model, eval_loader) 252 | tr_losses.append(tr_l) 253 | ev_losses.append(ev_l) 254 | scheduler.step() 255 | 256 | if ev_l <= min_mae: 257 | # Define new min-loss 258 | min_mae = ev_l 259 | 260 | # Test model 261 | te_l, te_ys, te_pred = eval(model, test_loader) 262 | 263 | ys_saved = [item for sublist in te_ys for item in sublist] 264 | pred_saved = [item for sublist in te_pred for item in sublist] 265 | 266 | # Save Model and Loss + Predictions 267 | if SPLIT == "eln": 268 | torch.save(model.state_dict(), f"models/config_{args.config}_{args.mode}_{TARGET}_{args.cv}.pt") 269 | torch.save( 270 | [tr_losses, ev_losses, ys_saved, pred_saved, ELN, TARGET], 271 | f"results/config_{args.config}_{args.mode}_{args.cv}.pt", 272 | ) 273 | elif SPLIT == "random": 274 | torch.save( 275 | model.state_dict(), 276 | 
f"models/config_{args.config}_{args.mode}_{TARGET}_{args.cv}_{args.testset}.pt", 277 | ) 278 | torch.save( 279 | [tr_losses, ev_losses, ys_saved, pred_saved, ELN, TARGET], 280 | f"results/config_{args.config}_{args.mode}_{args.cv}_{args.testset}.pt", 281 | ) 282 | 283 | else: 284 | # Get Datasets 285 | tran_ids, eval_ids, test_ids = get_rxn_ids( 286 | data=RXN_DATA, 287 | split=SPLIT, 288 | eln=ELN, 289 | testset=args.testset, 290 | ) 291 | tran_ids += eval_ids 292 | train_data = DataLSF( 293 | rxn_ids=tran_ids, 294 | data=RXN_DATA, 295 | data_substrates=SUB_DATA, 296 | target=TARGET, 297 | graph_dim=GRAPH_DIM, 298 | fingerprint=FINGERPRINT, 299 | conformers=CONFORMER_LIST, 300 | ) 301 | train_loader = DataLoader( 302 | train_data, 303 | batch_size=BATCH_SIZE, 304 | shuffle=True, 305 | num_workers=2, 306 | ) 307 | test_data = DataLSF( 308 | rxn_ids=test_ids, 309 | data=RXN_DATA, 310 | data_substrates=SUB_DATA, 311 | target=TARGET, 312 | graph_dim=GRAPH_DIM, 313 | fingerprint=FINGERPRINT, 314 | conformers=CONFORMER_LIST, 315 | ) 316 | test_loader = DataLoader( 317 | test_data, 318 | batch_size=BATCH_SIZE, 319 | shuffle=True, 320 | num_workers=2, 321 | ) 322 | 323 | # Training without Early Stopping 324 | for epoch in range(1000): 325 | # Training Loop 326 | tr_l = train(model, optimizer, criterion, train_loader) 327 | tr_losses.append(tr_l) 328 | scheduler.step() 329 | 330 | if epoch >= 999: 331 | # Test model 332 | te_l, te_ys, te_pred = eval(model, test_loader) 333 | 334 | ys_saved = [item for sublist in te_ys for item in sublist] 335 | pred_saved = [item for sublist in te_pred for item in sublist] 336 | 337 | # Save Model and Save Loos + Predictions 338 | if SPLIT == "eln": 339 | torch.save(model.state_dict(), f"models/config_{args.config}_{args.mode}_{TARGET}_{args.cv}.pt") 340 | torch.save( 341 | [tr_losses, ev_losses, ys_saved, pred_saved, ELN, TARGET], 342 | f"results/config_{args.config}_{args.mode}_{args.cv}.pt", 343 | ) 344 | elif SPLIT == "random": 345 | torch.save( 346 | model.state_dict(), 347 | f"models/config_{args.config}_{args.mode}_{TARGET}_{args.cv}_{args.testset}.pt", 348 | ) 349 | torch.save( 350 | [tr_losses, ev_losses, ys_saved, pred_saved, ELN, TARGET], 351 | f"results/config_{args.config}_{args.mode}_{args.cv}_{args.testset}.pt", 352 | ) 353 | -------------------------------------------------------------------------------- /lsfml/fganalysis/ertl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Original authors: Richard Hall and Guillaume Godin 5 | # This file is part of the RDKit. 6 | # The contents are covered by the terms of the BSD license 7 | # which is included in the file license.txt, found at the root 8 | # of the RDKit source tree. 9 | 10 | from collections import namedtuple 11 | 12 | # 13 | # 14 | # Richard hall 2017 15 | # IFG main code 16 | # Guillaume Godin 2017 17 | # refine output function 18 | # astex_ifg: identify functional groups a la Ertl, J. 
Cheminform (2017) 9:36 19 | from rdkit import Chem 20 | 21 | 22 | def merge(mol, marked, aset): 23 | 24 | bset = set() 25 | for idx in aset: 26 | atom = mol.GetAtomWithIdx(idx) 27 | for nbr in atom.GetNeighbors(): 28 | jdx = nbr.GetIdx() 29 | if jdx in marked: 30 | marked.remove(jdx) 31 | bset.add(jdx) 32 | if not bset: 33 | return 34 | merge(mol, marked, bset) 35 | aset.update(bset) 36 | 37 | 38 | # atoms connected by non-aromatic double or triple bond to any heteroatom 39 | # c=O should not match (see fig1, box 15). I think using A instead of * should sort that out? 40 | PATT_DOUBLE_TRIPLE = Chem.MolFromSmarts("A=,#[!#6]") 41 | # atoms in non aromatic carbon-carbon double or triple bonds 42 | PATT_CC_DOUBLE_TRIPLE = Chem.MolFromSmarts("C=,#C") 43 | # acetal carbons, i.e. sp3 carbons connected to two or more oxygens, 44 | # nitrogens or sulfurs; these O, N or S atoms must have only single bonds 45 | PATT_ACETAL = Chem.MolFromSmarts("[CX4](-[O,N,S])-[O,N,S]") 46 | # all atoms in oxirane, aziridine and thiirane rings 47 | PATT_OXIRANE_ETC = Chem.MolFromSmarts("[O,N,S]1CC1") 48 | 49 | PATT_TUPLE = (PATT_DOUBLE_TRIPLE, PATT_CC_DOUBLE_TRIPLE, PATT_ACETAL, PATT_OXIRANE_ETC) 50 | 51 | 52 | def identify_functional_groups(mol): 53 | 54 | marked = set() 55 | # mark all heteroatoms in a molecule, including halogens 56 | for atom in mol.GetAtoms(): 57 | if atom.GetAtomicNum() not in (6, 1): # would we ever have hydrogen? 58 | marked.add(atom.GetIdx()) 59 | 60 | # mark the four specific types of carbon atom 61 | for patt in PATT_TUPLE: 62 | for path in mol.GetSubstructMatches(patt): 63 | for atomindex in path: 64 | marked.add(atomindex) 65 | 66 | # merge all connected marked atoms to a single FG 67 | groups = [] 68 | while marked: 69 | grp = set([marked.pop()]) 70 | merge(mol, marked, grp) 71 | groups.append(grp) 72 | 73 | # extract also connected unmarked carbon atoms 74 | ifg = namedtuple("IFG", ["atomIds", "atoms", "type"]) 75 | ifgs = [] 76 | for g in groups: 77 | uca = set() 78 | for atomidx in g: 79 | for n in mol.GetAtomWithIdx(atomidx).GetNeighbors(): 80 | if n.GetAtomicNum() == 6: 81 | uca.add(n.GetIdx()) 82 | ifgs.append( 83 | ifg( 84 | atomIds=tuple(list(g)), 85 | atoms=Chem.MolFragmentToSmiles(mol, g, canonical=True), 86 | type=Chem.MolFragmentToSmiles(mol, g.union(uca), canonical=True), 87 | ) 88 | ) 89 | return ifgs 90 | -------------------------------------------------------------------------------- /lsfml/fganalysis/fg_analysis.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import os 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | from ertl import identify_functional_groups 12 | from rdkit import Chem 13 | from tqdm import tqdm 14 | 15 | PARAMS = [2, 40, "fg_analysis/"] 16 | FOLDER_NAME = PARAMS[2] 17 | os.makedirs(FOLDER_NAME, exist_ok=True) 18 | 19 | 20 | def get_smiles_2_rxn_dict(smiles, yes_no): 21 | """Creates a dict for SMILES-strings to sum of successful reactions. 22 | 23 | :param smiles: SMILES-strings 24 | :type smiles: list[str] 25 | :param yes_no: Reaction outcomes. 26 | :type yes_no: list[int] 27 | :return: Dict SMILES-strings (str): sum of successful reactions (int).
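For example, {"CCO": 3} (an illustrative entry) would record three successful reactions for one substrate.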
28 | :rtype: dict 29 | """ 30 | smiles_2_rxn_dict = {} 31 | 32 | for i, smi in enumerate(tqdm(smiles)): 33 | if smi not in smiles_2_rxn_dict: 34 | smiles_2_rxn_dict[smi] = 0 35 | smiles_2_rxn_dict[smi] += yes_no[i] 36 | 37 | elif smi in smiles_2_rxn_dict: 38 | smiles_2_rxn_dict[smi] += yes_no[i] 39 | 40 | return smiles_2_rxn_dict 41 | 42 | 43 | def get_smiles_2_negative_rxn_dict(smiles, yes_no): 44 | """Creates a dict for SMILES-strings to sum of failed reactions. 45 | 46 | :param smiles: SMILES-strings 47 | :type smiles: list[str] 48 | :param yes_no: Reaction outcomes. 49 | :type yes_no: list[int] 50 | :return: Dict SMILES-strings (str): sum of failed reactions (int). 51 | :rtype: dict 52 | """ 53 | smiles_2_rxn_dict = {} 54 | 55 | for i, smi in enumerate(tqdm(smiles)): 56 | failure = abs(yes_no[i] - 1) 57 | 58 | if smi not in smiles_2_rxn_dict: 59 | smiles_2_rxn_dict[smi] = 0 60 | smiles_2_rxn_dict[smi] += failure 61 | 62 | elif smi in smiles_2_rxn_dict: 63 | smiles_2_rxn_dict[smi] += failure 64 | 65 | return smiles_2_rxn_dict 66 | 67 | 68 | def plt_barplot(sorted_model_dict, model_dict_std, name, ylabel, xlabel, ylim, imgsize1, imgsize2): 69 | """Creates a bar plot given a list of mean values and standard deviations. 70 | 71 | :param sorted_model_dict: Dict for mean values. 72 | :type sorted_model_dict: dict 73 | :param model_dict_std: Dict for standard deviations. 74 | :type model_dict_std: dict 75 | :param name: File name. 76 | :type name: str 77 | :param ylabel: Y-axis label. 78 | :type ylabel: str 79 | :param xlabel: X-axis label. 80 | :type xlabel: str 81 | :param ylim: Y-axis limit. 82 | :type ylim: boolean 83 | :param imgsize1: Figure size 1. 84 | :type imgsize1: int 85 | :param imgsize2: Figure size 2. 86 | :type imgsize2: int 87 | """ 88 | plt.figure(figsize=(imgsize1, imgsize2)) 89 | 90 | keys = list(sorted_model_dict.keys()) 91 | accs = [] 92 | stds = [] 93 | 94 | for k in sorted_model_dict: 95 | accs.append(np.mean(np.array(sorted_model_dict[k]))) 96 | stds.append(np.mean(np.array(model_dict_std[k]))) 97 | 98 | plt.bar(keys, accs, color="lightskyblue") # grey 99 | plt.errorbar(keys, accs, yerr=stds, color="black", ls="none", elinewidth=2.5) 100 | plt.tick_params(axis="x", labelsize=30, rotation=90) 101 | plt.tick_params(axis="y", labelsize=34) 102 | plt.xlabel(f"\n{xlabel}", fontsize=34) 103 | plt.ylabel(f"\n{ylabel}", fontsize=34) 104 | 105 | if ylim: 106 | bottom, top = plt.xlim() 107 | plt.ylim((min(accs) - 10, max(accs) + 3)) 108 | plt.tight_layout() 109 | plt.savefig(f"{FOLDER_NAME}{name}.png", dpi=400) 110 | plt.clf() 111 | 112 | 113 | def plt_barplot_two_colors( 114 | sorted_model_dict, model_dict_std, name, ylabel, xlabel, ylim, imgsize1, imgsize2, dict_to_check 115 | ): 116 | """Creates a bar plot given a list of mean values and standard deviations. Additionally, specific bars can be colored in a second color. 117 | 118 | :param sorted_model_dict: Dict for mean values. 119 | :type sorted_model_dict: dict 120 | :param model_dict_std: Dict for standard deviations. 121 | :type model_dict_std: dict 122 | :param name: File name. 123 | :type name: str 124 | :param ylabel: Y-axis label. 125 | :type ylabel: str 126 | :param xlabel: X-axis label. 127 | :type xlabel: str 128 | :param ylim: Y-axis limit. 129 | :type ylim: boolean 130 | :param imgsize1: Figure size 1. 131 | :type imgsize1: int 132 | :param imgsize2: Figure size 2. 133 | :type imgsize2: int 134 | :param dict_to_check: Name of keys for second color. 
135 | :type dict_to_check: str 136 | """ 137 | plt.figure(figsize=(imgsize1, imgsize2)) 138 | 139 | keys = list(sorted_model_dict.keys()) 140 | accs = [] 141 | stds = [] 142 | 143 | for k in sorted_model_dict: 144 | accs.append(np.mean(np.array(sorted_model_dict[k]))) 145 | stds.append(np.mean(np.array(model_dict_std[k]))) 146 | 147 | barlist = plt.bar(keys, accs, color="lightskyblue") 148 | 149 | for ( 150 | i, 151 | k, 152 | ) in enumerate(keys): 153 | if k not in dict_to_check: 154 | barlist[i].set_color("orange") 155 | 156 | plt.errorbar(keys, accs, yerr=stds, color="black", ls="none", elinewidth=2.5) 157 | plt.tick_params(axis="x", labelsize=30, rotation=90) 158 | plt.tick_params(axis="y", labelsize=34) 159 | plt.xlabel(f"\n{xlabel}", fontsize=34) 160 | plt.ylabel(f"\n{ylabel}", fontsize=34) 161 | 162 | if ylim: 163 | bottom, top = plt.xlim() 164 | plt.ylim((min(accs) - 10, max(accs) + 3)) 165 | plt.tight_layout() 166 | plt.savefig(f"{FOLDER_NAME}{name}.png", dpi=400) 167 | plt.clf() 168 | 169 | 170 | def plt_barplot_with_adapt_color( 171 | sorted_model_dict, model_dict_std, name, ylabel, xlabel, ylim, imgsize1, imgsize2, bar_color 172 | ): 173 | """Creates a bar plot given a list of mean values and standard deviations. Additionally, the color can be specified. 174 | 175 | :param sorted_model_dict: Dict for mean values. 176 | :type sorted_model_dict: dict 177 | :param model_dict_std: Dict for standard deviations. 178 | :type model_dict_std: dict 179 | :param name: File name. 180 | :type name: str 181 | :param ylabel: Y-axis label. 182 | :type ylabel: str 183 | :param xlabel: X-axis label. 184 | :type xlabel: str 185 | :param ylim: Y-axis limit. 186 | :type ylim: boolean 187 | :param imgsize1: Figure size 1. 188 | :type imgsize1: int 189 | :param imgsize2: Figure size 2. 190 | :type imgsize2: int 191 | :param bar_color: Color. 192 | :type bar_color: str 193 | """ 194 | plt.figure(figsize=(imgsize1, imgsize2)) 195 | 196 | keys = list(sorted_model_dict.keys()) 197 | 198 | accs = [] 199 | stds = [] 200 | 201 | for k in sorted_model_dict: 202 | accs.append(np.mean(np.array(sorted_model_dict[k]))) 203 | stds.append(np.mean(np.array(model_dict_std[k]))) 204 | plt.bar(keys, accs, color=bar_color)  # draw the bars in the requested color 205 | plt.errorbar(keys, accs, yerr=stds, color="black", ls="none", elinewidth=2.5) 206 | plt.tick_params(axis="x", labelsize=30, rotation=90) 207 | plt.tick_params(axis="y", labelsize=34) 208 | plt.xlabel(f"\n{xlabel}", fontsize=34) 209 | plt.ylabel(f"\n{ylabel}", fontsize=34) 210 | 211 | if ylim: 212 | bottom, top = plt.xlim() 213 | plt.ylim((min(accs) - 10, max(accs) + 3)) 214 | plt.tight_layout() 215 | plt.savefig(f"{FOLDER_NAME}{name}.png", dpi=400) 216 | plt.clf() 217 | 218 | 219 | def get_fg_occurrence(smiles): 220 | """Counts the occurrence of each functional group across a list of SMILES-strings. 221 | 222 | :param smiles: SMILES-strings 223 | :type smiles: list[str] 224 | :return: dict functional group (str): number of occurrences (int).
225 | :rtype: dict 226 | """ 227 | fg_dict = {} 228 | 229 | uniques = list(set(smiles)) 230 | 231 | for u in uniques: 232 | try: 233 | m = Chem.MolFromSmiles(u) 234 | fgs = identify_functional_groups(m) 235 | 236 | tmp_fgs = [] 237 | 238 | for f in fgs: 239 | tmp_fgs.append(f[PARAMS[0]]) 240 | 241 | for fg in tmp_fgs: 242 | if fg not in fg_dict: 243 | fg_dict[fg] = 1 244 | 245 | elif fg in fg_dict: 246 | fg_dict[fg] += 1 247 | except: 248 | print("skipping:", u) 249 | 250 | return fg_dict 251 | 252 | 253 | def get_fg_tolerance(smiles_2_rxn_dict): 254 | """Sums the reaction outcomes per functional group over a dict of SMILES-strings to reaction counts. 255 | 256 | :param smiles_2_rxn_dict: Dict SMILES-string (str): number of reactions (int). 257 | :type smiles_2_rxn_dict: dict 258 | :return: dict functional group (str): summed reaction counts (int). 259 | :rtype: dict 260 | """ 261 | fg_tolerance_dict = {} 262 | 263 | for smi in tqdm(smiles_2_rxn_dict): 264 | rxn = smiles_2_rxn_dict[smi] 265 | 266 | try: 267 | m = Chem.MolFromSmiles(smi) 268 | fgs = identify_functional_groups(m) 269 | 270 | tmp_fgs = [] 271 | 272 | for f in fgs: 273 | tmp_fgs.append(f[PARAMS[0]]) 274 | 275 | for fg in tmp_fgs: 276 | if fg not in fg_tolerance_dict: 277 | fg_tolerance_dict[fg] = rxn 278 | 279 | elif fg in fg_tolerance_dict: 280 | fg_tolerance_dict[fg] += rxn 281 | 282 | except: 283 | print("skipping:", smi) 284 | 285 | return fg_tolerance_dict 286 | 287 | 288 | if __name__ == "__main__": 289 | # Smiles from lsf-space 290 | df = pd.read_csv("../data/experimental_rxndata.csv") 291 | smiles = list(df["educt"]) 292 | yes_no = list(df["yes_no"]) 293 | yes_no = [float(x) for x in yes_no] 294 | 295 | # Get dicts for absolute and relative success/failure rate of the individual FGs 296 | print( 297 | f"\nExtracting all functional groups from the substrates present in the {len(smiles)} " 298 | "reactions of the experimental data set:" 299 | ) 300 | smiles_2_rxn_dict = get_smiles_2_rxn_dict(smiles, yes_no) 301 | smiles_2_negative_rxn_dict = get_smiles_2_negative_rxn_dict(smiles, yes_no) 302 | 303 | print( 304 | f"\nCalculating the relative success/failure rate of the functional groups across the {len(smiles_2_rxn_dict)} " 305 | "unique substrates:" 306 | ) 307 | fg_tolerance_dict = get_fg_tolerance(smiles_2_rxn_dict) 308 | fg_intolerance_dict = get_fg_tolerance(smiles_2_negative_rxn_dict) 309 | 310 | print("\nCreating the five plots:") 311 | 312 | # Green barplot 313 | sorted_fgs = sorted(fg_tolerance_dict, key=lambda k: fg_tolerance_dict[k], reverse=True) 314 | sorted_fg_dict = {} 315 | model_dict_std = {} 316 | for k in sorted_fgs: 317 | sorted_fg_dict[k] = fg_tolerance_dict[k] 318 | model_dict_std[k] = 0 319 | 320 | plt_barplot_with_adapt_color( 321 | sorted_fg_dict, 322 | model_dict_std, 323 | "lsf_space_success", 324 | "Successful reactions / $N$", 325 | "Functional group / SMILES", 326 | None, 327 | 22, 328 | 14, 329 | "palegreen", 330 | ) 331 | print( 332 | f"1. Plotted the functional groups for the unique {len(smiles_2_rxn_dict)} substrates by success " 333 | "(absolute number)."
334 | ) 335 | 336 | # Red barplot 337 | sorted_fgs = sorted(fg_intolerance_dict, key=lambda k: fg_intolerance_dict[k], reverse=True) 338 | sorted_fg_dict = {} 339 | model_dict_std = {} 340 | for k in sorted_fgs: 341 | sorted_fg_dict[k] = fg_intolerance_dict[k] 342 | model_dict_std[k] = 0 343 | 344 | plt_barplot_with_adapt_color( 345 | sorted_fg_dict, 346 | model_dict_std, 347 | "lsf_space_failure", 348 | "Failed reactions / $N$", 349 | "Functional group / SMILES", 350 | None, 351 | 22, 352 | 14, 353 | "lightcoral", 354 | ) 355 | print( 356 | f"2. Plotted the functional groups for the unique {len(smiles_2_rxn_dict)} substrates by failure " 357 | "(absolute number)." 358 | ) 359 | 360 | # Barplot of LSF-space 361 | fg_dict = get_fg_occurrence(smiles) 362 | sorted_fgs = sorted(fg_dict, key=lambda k: fg_dict[k], reverse=True) 363 | sorted_fg_dict = {} 364 | model_dict_std = {} 365 | for k in sorted_fgs: 366 | sorted_fg_dict[k] = fg_dict[k] 367 | model_dict_std[k] = 0 368 | 369 | plt_barplot( 370 | sorted_fg_dict, 371 | model_dict_std, 372 | "lsf_space_number", 373 | "Occurrence in LSF-space library / $N$", 374 | "Functional group / SMILES", 375 | None, 376 | 22, 377 | 14, 378 | ) 379 | dict_to_check = sorted_fg_dict 380 | print( 381 | f"3. Plotted the {len(list(fg_dict.keys()))} unique functional groups for the {len(smiles_2_rxn_dict)} " 382 | "substrates by occurrence (absolute number)." 383 | ) 384 | 385 | # Bar plot percentage of failed 386 | percent_dict = {} 387 | 388 | for fg in fg_dict: 389 | percent_dict[fg] = fg_intolerance_dict[fg] / (fg_dict[fg] * 24) 390 | 391 | sorted_fgs = sorted(percent_dict, key=lambda k: percent_dict[k], reverse=True) 392 | sorted_fg_dict = {} 393 | model_dict_std = {} 394 | for k in sorted_fgs: 395 | sorted_fg_dict[k] = percent_dict[k] 396 | model_dict_std[k] = 0 397 | 398 | plt_barplot_with_adapt_color( 399 | sorted_fg_dict, 400 | model_dict_std, 401 | "lsf_space_failed_fraction", 402 | "Fraction of failed reactions", 403 | "Functional group / SMILES", 404 | None, 405 | 22, 406 | 14, 407 | "lightskyblue", 408 | ) 409 | print( 410 | f"4. Plotted the functional groups from the unique {len(smiles_2_rxn_dict)} substrates by " 411 | "failure (relative number)." 412 | ) 413 | 414 | # Barplot of Drug-space 415 | drug_data = pd.read_csv("../clustering/cluster_analysis/filtered_list/filtered_list.csv") 416 | smiles = list(drug_data["smiles_list"]) 417 | fg_dict = get_fg_occurrence(smiles) 418 | 419 | occurrences = [] 420 | for k in fg_dict: 421 | occurrences.append(fg_dict[k]) 422 | top_100 = sorted(occurrences)[-PARAMS[1]] 423 | 424 | sorted_fgs = sorted(fg_dict, key=lambda k: fg_dict[k], reverse=True) 425 | sorted_fg_dict = {} 426 | model_dict_std = {} 427 | for k in sorted_fgs: 428 | if fg_dict[k] >= top_100: 429 | sorted_fg_dict[k] = fg_dict[k] 430 | model_dict_std[k] = 0 431 | 432 | plt_barplot_two_colors( 433 | sorted_fg_dict, 434 | model_dict_std, 435 | "drug_lsf_space_comparison", 436 | "Occurrence in drug-space library / $N$", 437 | "Functional group / SMILES", 438 | None, 439 | 22, 440 | 14, 441 | dict_to_check, 442 | ) 443 | print( 444 | f"5. Plotted the {len(sorted_fg_dict)} / {len(fg_dict)} most abundant functional groups of the " 445 | "drug-space library and highlighted the ones not present in the LSF-space library in orange."
446 | ) 447 | -------------------------------------------------------------------------------- /lsfml/img/NCHEM-22102062A_figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ETHmodlab/lsfml/d0178f1ebfedd73639cee2452fadacc500ca23e1/lsfml/img/NCHEM-22102062A_figure1.png -------------------------------------------------------------------------------- /lsfml/img/regio_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ETHmodlab/lsfml/d0178f1ebfedd73639cee2452fadacc500ca23e1/lsfml/img/regio_example.jpg -------------------------------------------------------------------------------- /lsfml/literature/regioselectivity/graph_mapping.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import h5py, os 7 | import networkx as nx 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from rdkit import Chem 12 | from rdkit.Chem import AllChem 13 | from torch_geometric.data import Data 14 | from torch_geometric.utils import add_self_loops 15 | from torch_geometric.utils.undirected import to_undirected 16 | from tqdm import tqdm 17 | from scipy.spatial.distance import pdist, squareform 18 | 19 | from lsfml.qml.prod import get_model 20 | from lsfml.utils import get_dict_for_embedding, AROMATOCITY, IS_RING, QML_ATOMTYPES, UTILS_PATH 21 | 22 | QMLMODEL = get_model(gpu=False) 23 | 24 | HYBRIDISATION_DICT = {"SP3": 0, "SP2": 1, "SP": 2, "UNSPECIFIED": 3, "S": 3} 25 | AROMATOCITY_DICT = get_dict_for_embedding(AROMATOCITY) 26 | IS_RING_DICT = get_dict_for_embedding(IS_RING) 27 | QML_ATOMTYPE_DICT = get_dict_for_embedding(QML_ATOMTYPES) 28 | ATOMTYPE_DICT = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4, "P": 5, "S": 6, "Cl": 7, "Br": 8, "I": 9} 29 | 30 | 31 | def get_rms(smi, patt, repl): 32 | """Takes a SMILES-string and replaces a substructure (patt), such as a functional group, with another substructure (repl). 33 | 34 | :param smi: SMILES-string. 35 | :type smi: str 36 | :param patt: Pattern to be replaced. 37 | :type patt: rdkit.Chem.Mol 38 | :param repl: Pattern to replace. 39 | :type repl: rdkit.Chem.Mol 40 | :return: SMILES-string. 41 | :rtype: str 42 | """ 43 | m = Chem.MolFromSmiles(smi) 44 | rms = Chem.MolToSmiles(AllChem.ReplaceSubstructs(m, patt, repl)[0]) 45 | 46 | m = Chem.MolFromSmiles(rms) 47 | while rms != Chem.MolToSmiles(AllChem.ReplaceSubstructs(m, patt, repl)[0]): 48 | print( 49 | f"More than one Boron: {rms == Chem.MolToSmiles(AllChem.ReplaceSubstructs(m, patt, repl)[0])}" 50 | f", {rms}, {Chem.MolToSmiles(AllChem.ReplaceSubstructs(m, patt, repl)[0])}" 51 | ) 52 | rms = Chem.MolToSmiles(AllChem.ReplaceSubstructs(m, patt, repl)[0]) 53 | m = Chem.MolFromSmiles(rms) 54 | 55 | return rms 56 | 57 | 58 | def get_regioselectivity(rms, seed, radius): 59 | """Main function to generate the 2D and 3D molecular graphs and to extract the regioselectivity given a SMILES-string and a seed for 3D conformer generation.
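Atoms bonded to the deuterium label ([2H], introduced by get_rms at the former boron position) are marked as positive regioselectivity targets, and every carbon bearing at least one hydrogen is flagged as a potential target (pot_trg).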
60 | 61 | :param rms: SMILES-string 62 | :type rms: str 63 | :param seed: Random seed for 3D conformer generation 64 | :type seed: int 65 | :return: tuple including all graph-relevant numpy arrays 66 | :rtype: tuple 67 | """ 68 | mol_no_Hs = Chem.MolFromSmiles(rms) 69 | mol = Chem.rdmolops.AddHs(mol_no_Hs) 70 | 71 | atomids = [] 72 | qml_atomids = [] 73 | is_ring = [] 74 | hyb = [] 75 | arom = [] 76 | crds_3d = [] 77 | pot_trg = [] 78 | 79 | AllChem.EmbedMolecule(mol, randomSeed=seed) 80 | AllChem.UFFOptimizeMolecule(mol) 81 | 82 | for idx, i in enumerate(mol.GetAtoms()): 83 | atomids.append(ATOMTYPE_DICT[i.GetSymbol()]) 84 | qml_atomids.append(QML_ATOMTYPE_DICT[i.GetSymbol()]) 85 | is_ring.append(IS_RING_DICT[str(i.IsInRing())]) 86 | hyb.append(HYBRIDISATION_DICT[str(i.GetHybridization())]) 87 | arom.append(AROMATOCITY_DICT[str(i.GetIsAromatic())]) 88 | crds_3d.append(list(mol.GetConformer().GetAtomPosition(idx))) 89 | 90 | nghbrs = [x.GetSymbol() for x in i.GetNeighbors()] 91 | if (i.GetSymbol() == "C") and ("H" in nghbrs): 92 | pot_trg.append(1) 93 | else: 94 | pot_trg.append(0) 95 | 96 | trg_atoms = [] 97 | edge_dir1 = [] 98 | edge_dir2 = [] 99 | for idx, bond in enumerate(mol.GetBonds()): 100 | a2 = bond.GetEndAtomIdx() 101 | a1 = bond.GetBeginAtomIdx() 102 | edge_dir1.append(a1) 103 | edge_dir1.append(a2) 104 | edge_dir2.append(a2) 105 | edge_dir2.append(a1) 106 | 107 | if mol.GetAtoms()[a1].GetIsotope() == 2: 108 | trg_atoms.append(a2) 109 | 110 | if mol.GetAtoms()[a2].GetIsotope() == 2: 111 | trg_atoms.append(a1) 112 | 113 | edge_2d = torch.from_numpy(np.array([edge_dir1, edge_dir2])) 114 | 115 | target = [0 for i, x in enumerate(atomids)] 116 | for trg_atom in trg_atoms: 117 | target[trg_atom] = 1 118 | 119 | atomids = np.array(atomids) 120 | target = np.array(target) 121 | qml_atomids = np.array(qml_atomids) 122 | is_ring = np.array(is_ring) 123 | hyb = np.array(hyb) 124 | arom = np.array(arom) 125 | crds_3d = np.array(crds_3d) 126 | pot_trg = np.array(pot_trg) 127 | 128 | # 3D graph for qml and qml prediction 129 | qml_atomids = torch.LongTensor(qml_atomids) 130 | xyzs = torch.FloatTensor(crds_3d) 131 | edge_index = np.array(nx.complete_graph(qml_atomids.size(0)).edges()) 132 | edge_index = to_undirected(torch.from_numpy(edge_index).t().contiguous()) 133 | edge_index, _ = add_self_loops(edge_index, num_nodes=crds_3d.shape[0]) 134 | 135 | qml_graph = Data( 136 | atomids=qml_atomids, 137 | coords=xyzs, 138 | edge_index=edge_index, 139 | num_nodes=qml_atomids.size(0), 140 | ) 141 | 142 | charges = QMLMODEL(qml_graph).unsqueeze(1).detach().numpy() 143 | 144 | # Get edges for 3d graph 145 | distance_matrix = squareform(pdist(crds_3d)) 146 | np.fill_diagonal(distance_matrix, float("inf")) # to remove self-loops 147 | edge_3d = torch.from_numpy(np.vstack(np.where(distance_matrix <= radius))) 148 | 149 | return ( 150 | atomids, 151 | is_ring, 152 | hyb, 153 | arom, 154 | charges, 155 | edge_2d, 156 | edge_3d, 157 | crds_3d, 158 | pot_trg, 159 | target, 160 | ) 161 | 162 | 163 | if __name__ == "__main__": 164 | # Read csv 165 | df = pd.read_csv(os.path.join(UTILS_PATH, "data/literature_rxndata.csv"), encoding="unicode_escape") 166 | smiles = list(df["product_1_smiles"]) 167 | rxn_yield = list(df["product_1_yield"]) 168 | 169 | print(f"Initial number of reactions: {len(smiles)}") 170 | rxn_yield = [0.0 if np.isnan(x) else x for x in rxn_yield] 171 | smiles = [s for i, s in enumerate(smiles) if rxn_yield[i] >= 0.3] 172 | print(f"Number of reactions after removing yields <= 30%: 
{len(smiles)}") 173 | smiles = [s for i, s in enumerate(smiles) if len(str(smiles[i])) >= 5] 174 | print(f"Number of reactions after removing nan SMILES: {len(smiles)}") 175 | uniques = list(set(smiles)) 176 | print(f"Number of reactions after removing duplicate SMILES: {len(uniques)}") 177 | 178 | repl = Chem.MolFromSmiles("[2H]") 179 | patt = Chem.MolFromSmarts("B2OC(C)(C)C(O2)(C)C") 180 | 181 | wins = 0 182 | loss = 0 183 | rxn_key = 0 184 | 185 | all_smiles = [] 186 | h5_path = os.path.join(UTILS_PATH, "data/literature_regio.h5") 187 | 188 | with h5py.File(h5_path, "w") as lsf_container: 189 | for smi in tqdm(uniques): 190 | try: 191 | rms = get_rms(smi, patt, repl) 192 | 193 | if "[2H]" in rms: 194 | ( 195 | atom_id, 196 | ring_id, 197 | hybr_id, 198 | arom_id, 199 | charges, 200 | edge_2d, 201 | edge_3d, 202 | crds_3d, 203 | pot_trg, 204 | reg_trg, 205 | ) = get_regioselectivity(rms, 0xF10D, 4) 206 | 207 | # Create group in h5 for this id 208 | lsf_container.create_group(str(rxn_key)) 209 | 210 | # Molecule 211 | lsf_container[str(rxn_key)].create_dataset("atom_id", data=atom_id) 212 | lsf_container[str(rxn_key)].create_dataset("ring_id", data=ring_id) 213 | lsf_container[str(rxn_key)].create_dataset("hybr_id", data=hybr_id) 214 | lsf_container[str(rxn_key)].create_dataset("arom_id", data=arom_id) 215 | lsf_container[str(rxn_key)].create_dataset("edge_2d", data=edge_2d) 216 | lsf_container[str(rxn_key)].create_dataset("edge_3d", data=edge_3d) 217 | lsf_container[str(rxn_key)].create_dataset("charges", data=charges) 218 | lsf_container[str(rxn_key)].create_dataset("crds_3d", data=crds_3d) 219 | lsf_container[str(rxn_key)].create_dataset("pot_trg", data=pot_trg) 220 | lsf_container[str(rxn_key)].create_dataset("reg_trg", data=reg_trg) 221 | 222 | all_smiles.append(rms) 223 | 224 | wins += 1 225 | rxn_key += 1 226 | 227 | else: 228 | print(f"No boron in product or unconventional boron in product: {rms}, {smi}") 229 | 230 | except: 231 | loss += 1 232 | 233 | print(f"Reactions sucessfully transformed: {wins}; Reactions failed {loss}") 234 | 235 | df = pd.DataFrame( 236 | { 237 | "all_smiles": all_smiles, 238 | } 239 | ) 240 | df.to_csv(os.path.join(UTILS_PATH, "data/literature_regio.csv"), sep=",", encoding="utf-8", index=False) 241 | -------------------------------------------------------------------------------- /lsfml/literature/regioselectivity/net.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from lsfml.modules.gnn_blocks import ( 10 | EGNN_sparse, 11 | EGNN_sparse3D, 12 | weights_init, 13 | ) 14 | 15 | 16 | class Atomistic_EGNN(nn.Module): 17 | """Atomistic graph neural network (aGNN) for regioselectivity predictions.""" 18 | 19 | def __init__(self, n_kernels=3, mlp_dim=512, kernel_dim=64, embeddings_dim=64, qml=True, geometry=True): 20 | """Initialization of aGNN 21 | 22 | :param n_kernels: Number of message passing functions, defaults to 3 23 | :type n_kernels: int, optional 24 | :param mlp_dim: Feature dimension within the multi layer perceptrons, defaults to 512 25 | :type mlp_dim: int, optional 26 | :param kernel_dim: Feature dimension within the message passing fucntions, defaults to 64 27 | :type kernel_dim: int, optional 28 | :param embeddings_dim: Embedding dimension of the input features (e.g. 
reaction conditions and atomic features), defaults to 64 29 | :type embeddings_dim: int, optional 30 | :param qml: Option to include DFT-level partial charges, defaults to True 31 | :type qml: bool, optional 32 | :param geometry: Option to include steric information in the input graph, defaults to True 33 | :type geometry: bool, optional 34 | """ 35 | super(Atomistic_EGNN, self).__init__() 36 | 37 | self.embeddings_dim = embeddings_dim 38 | self.m_dim = 16 39 | self.kernel_dim = kernel_dim 40 | self.n_kernels = n_kernels 41 | self.aggr = "add" 42 | self.pos_dim = 3 43 | self.mlp_dim = mlp_dim 44 | self.qml = qml 45 | self.geometry = geometry 46 | 47 | dropout = 0.1 48 | self.dropout = nn.Dropout(dropout) 49 | 50 | self.atom_em = nn.Embedding(num_embeddings=10, embedding_dim=self.embeddings_dim) 51 | self.ring_em = nn.Embedding(num_embeddings=2, embedding_dim=self.embeddings_dim) 52 | self.hybr_em = nn.Embedding(num_embeddings=4, embedding_dim=self.embeddings_dim) 53 | self.arom_em = nn.Embedding(num_embeddings=2, embedding_dim=self.embeddings_dim) 54 | 55 | if self.qml: 56 | self.chrg_em = nn.Linear(1, self.embeddings_dim) 57 | self.pre_egnn_mlp_input_dim = self.embeddings_dim * 5 58 | self.chrg_em.apply(weights_init) 59 | else: 60 | self.pre_egnn_mlp_input_dim = self.embeddings_dim * 4 61 | 62 | self.pre_egnn_mlp = nn.Sequential( 63 | nn.Linear(self.pre_egnn_mlp_input_dim, self.kernel_dim * 2), 64 | self.dropout, 65 | nn.SiLU(), 66 | nn.Linear(self.kernel_dim * 2, self.kernel_dim), 67 | nn.SiLU(), 68 | nn.Linear(self.kernel_dim, self.kernel_dim), 69 | nn.SiLU(), 70 | ) 71 | 72 | self.kernels = nn.ModuleList() 73 | for _ in range(self.n_kernels): 74 | if self.geometry: 75 | self.kernels.append( 76 | EGNN_sparse3D( 77 | feats_dim=self.kernel_dim, 78 | m_dim=self.m_dim, 79 | aggr=self.aggr, 80 | ) 81 | ) 82 | else: 83 | self.kernels.append( 84 | EGNN_sparse( 85 | feats_dim=self.kernel_dim, 86 | m_dim=self.m_dim, 87 | aggr=self.aggr, 88 | ) 89 | ) 90 | 91 | self.post_egnn_mlp = nn.Sequential( 92 | nn.Linear(self.kernel_dim * self.n_kernels, self.mlp_dim), 93 | self.dropout, 94 | nn.SiLU(), 95 | nn.Linear(self.mlp_dim, self.mlp_dim), 96 | nn.SiLU(), 97 | nn.Linear(self.mlp_dim, self.mlp_dim), 98 | nn.SiLU(), 99 | nn.Linear(self.mlp_dim, 1), 100 | nn.Sigmoid(), 101 | ) 102 | 103 | self.kernels.apply(weights_init) 104 | self.pre_egnn_mlp.apply(weights_init) 105 | self.post_egnn_mlp.apply(weights_init) 106 | nn.init.xavier_uniform_(self.atom_em.weight) 107 | nn.init.xavier_uniform_(self.ring_em.weight) 108 | nn.init.xavier_uniform_(self.hybr_em.weight) 109 | nn.init.xavier_uniform_(self.arom_em.weight) 110 | 111 | def forward(self, g_batch): 112 | """Forward pass of the atomistic GNN. 113 | 114 | :param g_batch: Input graph. 115 | :type g_batch: class 116 | :return: Regioselectivity, 0 - 1 per atom. 
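One sigmoid-activated score per node, concatenated over all molecules in the batch.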
117 | :rtype: Tensor 118 | """ 119 | if self.qml: 120 | features = self.pre_egnn_mlp( 121 | torch.cat( 122 | [ 123 | self.atom_em(g_batch.atom_id), 124 | self.ring_em(g_batch.ring_id), 125 | self.hybr_em(g_batch.hybr_id), 126 | self.arom_em(g_batch.arom_id), 127 | self.chrg_em(g_batch.charges), 128 | ], 129 | dim=1, 130 | ) 131 | ) 132 | else: 133 | features = self.pre_egnn_mlp( 134 | torch.cat( 135 | [ 136 | self.atom_em(g_batch.atom_id), 137 | self.ring_em(g_batch.ring_id), 138 | self.hybr_em(g_batch.hybr_id), 139 | self.arom_em(g_batch.arom_id), 140 | ], 141 | dim=1, 142 | ) 143 | ) 144 | 145 | feature_list = [] 146 | if self.geometry: 147 | features = torch.cat([g_batch.crds_3d, features], dim=1) 148 | for kernel in self.kernels: 149 | features = kernel( 150 | x=features, 151 | edge_index=g_batch.edge_index, 152 | ) 153 | feature_list.append(features[:, self.pos_dim :]) 154 | else: 155 | for kernel in self.kernels: 156 | features = kernel(x=features, edge_index=g_batch.edge_index) 157 | feature_list.append(features) 158 | 159 | features = torch.cat(feature_list, dim=1) 160 | features = self.post_egnn_mlp(features).squeeze(1) 161 | 162 | return features 163 | -------------------------------------------------------------------------------- /lsfml/literature/regioselectivity/net_utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import random 7 | 8 | import h5py 9 | import numpy as np 10 | import torch 11 | from torch_geometric.data import Data 12 | from lsfml.modules.pygdataset import Dataset 13 | 14 | random.seed(2) 15 | 16 | 17 | def get_rxn_ids( 18 | data, 19 | ): 20 | """Generates the data set split into training, validation and test sets. 21 | 22 | :param data: Path to h5 file, including preprocessed data, defaults to "../../data/literature_regio.h5" 23 | :type data: str, optional 24 | :return: Reaction IDs for training, validation and test split 25 | :rtype: list[str] 26 | """ 27 | # Load data from h5 file 28 | h5f = h5py.File(data) 29 | 30 | # Load all rxn keys 31 | rxn_ids = list(h5f.keys()) 32 | random.shuffle(rxn_ids) 33 | 34 | # Define subset of rxn keys 35 | tran_ids = rxn_ids[: int(len(rxn_ids) / 2)] 36 | eval_ids = rxn_ids[int(len(rxn_ids) / 4) * 3 :] 37 | test_ids = rxn_ids[int(len(rxn_ids) / 2) : int(len(rxn_ids) / 4) * 3] 38 | 39 | return tran_ids, eval_ids, test_ids 40 | 41 | 42 | class DataLSF(Dataset): 43 | """Generates the desired graph objects (2D, 3D, QM) from reading the h5 files.""" 44 | 45 | def __init__( 46 | self, 47 | rxn_ids, 48 | data, 49 | graph_dim, 50 | ): 51 | """Initialization. 
52 | 
53 |         :param rxn_ids: Reaction IDs from the given split (train, eval, test)
54 |         :type rxn_ids: list[str]
55 |         :param data: Path to h5 file, including preprocessed data, defaults to "../../data/literature_regio.h5"
56 |         :type data: str, optional
57 |         :param graph_dim: Indicating 2D or 3D graph structure ("edge_2d" or "edge_3d"), defaults to "edge_2d"
58 |         :type graph_dim: str, optional
59 |         """
60 |         # Define inputs
61 |         self.graph_dim = graph_dim
62 |         self.rxn_ids = rxn_ids
63 | 
64 |         # Load data from h5 file
65 |         self.h5f = h5py.File(data)
66 | 
67 |         # Generate dict (int to rxn keys)
68 |         nums = list(range(0, len(self.rxn_ids)))
69 |         self.idx2rxn = {}
70 |         for x in range(len(self.rxn_ids)):
71 |             self.idx2rxn[nums[x]] = self.rxn_ids[x]
72 | 
73 |         print("\nLoader initialized:")
74 |         print(f"Number of reactions loaded: {len(self.rxn_ids)}")
75 |         print(f"Chosen graph_dim (edge_2d or edge_3d): {self.graph_dim}")
76 | 
77 |     def __getitem__(self, idx):
78 |         """Get a single reaction graph by index.
79 | 
80 |         :param idx: Reaction ID
81 |         :type idx: str
82 |         :return: Input graph for the neural network.
83 |         :rtype: torch_geometric.data.Data
84 |         """
85 |         # int to rxn_id
86 |         rxn_id = self.idx2rxn[idx]
87 | 
88 |         # Molecule
89 |         atom_id = np.array(self.h5f[str(rxn_id)]["atom_id"])
90 |         ring_id = np.array(self.h5f[str(rxn_id)]["ring_id"])
91 |         hybr_id = np.array(self.h5f[str(rxn_id)]["hybr_id"])
92 |         arom_id = np.array(self.h5f[str(rxn_id)]["arom_id"])
93 |         charges = np.array(self.h5f[str(rxn_id)]["charges"])
94 |         crds_3d = np.array(self.h5f[str(rxn_id)]["crds_3d"])
95 |         pot_trg = np.array(self.h5f[str(rxn_id)]["pot_trg"])
96 |         # print(idx, rxn_id, atom_id)
97 | 
98 |         # Edge IDs with desired dimension
99 |         edge_index = np.array(self.h5f[str(rxn_id)][self.graph_dim])
100 | 
101 |         # Targets
102 |         rxn_trg = np.array(self.h5f[str(rxn_id)]["reg_trg"])
103 | 
104 |         num_nodes = torch.LongTensor(atom_id).size(0)
105 | 
106 |         graph_data = Data(
107 |             atom_id=torch.LongTensor(atom_id),
108 |             pot_trg=torch.LongTensor(pot_trg),
109 |             ring_id=torch.LongTensor(ring_id),
110 |             hybr_id=torch.LongTensor(hybr_id),
111 |             arom_id=torch.LongTensor(arom_id),
112 |             charges=torch.FloatTensor(charges),
113 |             crds_3d=torch.FloatTensor(crds_3d),
114 |             rxn_trg=torch.FloatTensor(rxn_trg),
115 |             edge_index=torch.LongTensor(edge_index),
116 |             num_nodes=num_nodes,
117 |             rxn_id=rxn_id,
118 |         )
119 | 
120 |         return graph_data
121 | 
122 |     def __len__(self):
123 |         """Get length.
124 | 
125 |         :return: length
126 |         :rtype: int
127 |         """
128 |         return len(self.rxn_ids)
129 | 
--------------------------------------------------------------------------------
/lsfml/literature/regioselectivity/production.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | from rdkit import Chem
7 | from rdkit.Chem import Draw
8 | from rdkit.Chem.Draw import SimilarityMaps
9 | 
10 | import configparser
11 | import io
12 | import os
13 | import time
14 | 
15 | import numpy as np
16 | import torch
17 | from PIL import Image
18 | from torch_geometric.data import Data
19 | 
20 | from lsfml.literature.regioselectivity.graph_mapping import (
21 |     get_regioselectivity,
22 | )
23 | from lsfml.literature.regioselectivity.net import (
24 |     Atomistic_EGNN,
25 | )
26 | 
27 | ATOMTYPE_DICT = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4, "P": 5, "S": 6, "Cl": 7, "Br": 8, "I": 9}
28 | 
29 | 
30 | def get_predictions(smiles, models):
31 |     """Main function to apply regioselectivity predictions given a list of SMILES strings and model IDs.
32 | 
33 |     :param smiles: List of SMILES strings
34 |     :type smiles: list[str]
35 |     :param models: List of model IDs
36 |     :type models: list[str]
37 |     """
38 |     CONFIG_PATH = "config/"
39 |     config = configparser.ConfigParser()
40 | 
41 |     for model_id in models:
42 |         # Load model
43 |         CONFIG_NAME = f"config_{model_id}.ini"
44 |         config.read(CONFIG_PATH + CONFIG_NAME)
45 | 
46 |         N_KERNELS = int(config["PARAMS"]["N_KERNELS"])
47 |         D_MLP = int(config["PARAMS"]["D_MLP"])
48 |         D_KERNEL = int(config["PARAMS"]["D_KERNEL"])
49 |         D_EMBEDDING = int(config["PARAMS"]["D_EMBEDDING"])
50 |         QML = int(config["PARAMS"]["QML"])
51 |         GEOMETRY = int(config["PARAMS"]["GEOMETRY"])
52 |         QML = True if QML >= 1 else False
53 |         GEOMETRY = True if GEOMETRY >= 1 else False
54 | 
55 |         model = Atomistic_EGNN(
56 |             n_kernels=N_KERNELS,
57 |             mlp_dim=D_MLP,
58 |             kernel_dim=D_KERNEL,
59 |             embeddings_dim=D_EMBEDDING,
60 |             qml=QML,
61 |             geometry=GEOMETRY,
62 |         )
63 | 
64 |         model.load_state_dict(
65 |             torch.load(
66 |                 f"models/config_{model_id}_1.pt",
67 |                 map_location=torch.device("cpu"),
68 |             )
69 |         )
70 | 
71 |         for j, smi in enumerate(smiles):
72 |             name = f"regiosel_{j}"
73 |             print(name, smi)
74 |             preds_stat = []
75 | 
76 |             for _ in range(3):
77 |                 seeds = [0xF00A, 0xF00B, 0xF00C, 0xF00E, 0xF00F, 0xF10D, 0xF20D, 0xF00D]
78 |                 pred_list = []
79 | 
80 |                 for k in seeds:
81 |                     (
82 |                         atom_id,
83 |                         ring_id,
84 |                         hybr_id,
85 |                         arom_id,
86 |                         charges,
87 |                         edge_2d,
88 |                         edge_3d,
89 |                         crds_3d,
90 |                         pot_trg,
91 |                         rxn_trg,
92 |                     ) = get_regioselectivity(smi, k)
93 | 
94 |                     # Generate graph
95 |                     num_nodes = torch.LongTensor(atom_id).size(0)
96 | 
97 |                     graph_data = Data(
98 |                         atom_id=torch.LongTensor(atom_id),
99 |                         ring_id=torch.LongTensor(ring_id),
100 |                         hybr_id=torch.LongTensor(hybr_id),
101 |                         arom_id=torch.LongTensor(arom_id),
102 |                         charges=torch.FloatTensor(charges),
103 |                         crds_3d=torch.FloatTensor(crds_3d),
104 |                         rxn_trg=torch.FloatTensor(rxn_trg),
105 |                         edge_index=torch.LongTensor(edge_3d),  # TODO: !!!
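                        # (Hedged note: edge_3d is used unconditionally here; the TODO presumably
                        # refers to selecting edge_2d vs. edge_3d according to the model's
                        # GEOMETRY flag, as train.py does via GRAPH_DIM.)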
106 | num_nodes=num_nodes, 107 | ) 108 | 109 | pred = model(graph_data) 110 | pred = [float(item) for item in pred] 111 | pred_list.append(pred) 112 | 113 | preds_stat.append(np.mean(np.array(pred_list), axis=0)) 114 | 115 | pred = np.mean(np.array(preds_stat), axis=0) 116 | pred_stds = np.std(np.array(preds_stat), axis=0) 117 | 118 | # Get image 119 | mol_no_Hs = Chem.MolFromSmiles(smi) 120 | mol = Chem.AddHs(mol_no_Hs) 121 | 122 | atomids = [] 123 | pot_trg = [] 124 | for idx, i in enumerate(mol.GetAtoms()): 125 | atomids.append(ATOMTYPE_DICT[i.GetSymbol()]) 126 | nghbrs = [x.GetSymbol() for x in i.GetNeighbors()] 127 | if (i.GetSymbol() == "C") and ("H" in nghbrs): 128 | pot_trg.append(1) 129 | else: 130 | pot_trg.append(0) 131 | 132 | atomids = np.array(atomids) 133 | pot_trg = np.array(pot_trg) 134 | 135 | trth = rxn_trg 136 | pred = pred 137 | for i, x in enumerate(pred): 138 | if atom_id[i] == 1: 139 | print(int(x * 10000) / 100, int(pred_stds[i] * 10000) / 100, hybr_id[i], atom_id[i]) 140 | 141 | # Remove Hs for image 142 | trth2 = [int(-x) for i, x in enumerate(trth) if atomids[i] != 0] 143 | pred2 = [-x for i, x in enumerate(pred) if atomids[i] != 0] 144 | pot_trg = [int(x) for i, x in enumerate(pot_trg) if atomids[i] != 0] 145 | pred2 = [x * pot_trg[i] for i, x in enumerate(pred2)] 146 | trth2 = [float(x * pot_trg[i]) for i, x in enumerate(trth2)] 147 | RemoveHs = Chem.MolFromSmiles(smi) 148 | 149 | # Image 150 | d = Draw.MolDraw2DCairo(650, 650) 151 | d.SetFontSize(26) 152 | s = d.drawOptions() 153 | s.bondLineWidth = 6 154 | SimilarityMaps.GetSimilarityMapFromWeights(RemoveHs, list(pred2), draw2d=d) 155 | d.FinishDrawing() 156 | data = d.GetDrawingText() 157 | bio = io.BytesIO(data) 158 | img = Image.open(bio) 159 | img.save(f"regiosel_imgs/AVG_mol_{name}_{model_id}_pred.png") 160 | time.sleep(3) 161 | 162 | 163 | if __name__ == "__main__": 164 | os.makedirs("regiosel_imgs", exist_ok=True) 165 | smiles = [ 166 | "c1cc(OC)cc2cc(C(=O)OCC)[nH]c12", 167 | "c1c(Br)ccc2c1c(C(=O)C(C)(C)C)cn2S(=O)(=O)c1ccc(C)cc1", 168 | "c1ccc(C(=O)N(CCCCCC)CCCCCC)cc1Br", 169 | "O=C(N1CCOCC1)C2=C(Br)C=CC=C2", 170 | "O=C(OCC)c(ccc1)c2c1cc[nH]2", 171 | "CSC1=CC(C2OCCO2)=CC=C1", 172 | ] 173 | models = [ 174 | "161", 175 | ] 176 | get_predictions(smiles, models) 177 | -------------------------------------------------------------------------------- /lsfml/literature/regioselectivity/train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | 7 | import argparse 8 | import configparser 9 | import os 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from torch_geometric.loader import DataLoader 15 | 16 | from lsfml.literature.regioselectivity.net import ( 17 | Atomistic_EGNN, 18 | ) 19 | from lsfml.literature.regioselectivity.net_utils import ( 20 | DataLSF, 21 | get_rxn_ids, 22 | ) 23 | from lsfml.utils import mae_loss, UTILS_PATH 24 | 25 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | RXN_DATA = os.path.join(UTILS_PATH, "data/literature_regio.h5") 27 | 28 | 29 | def train( 30 | model, 31 | optimizer, 32 | criterion, 33 | train_loader, 34 | ): 35 | """Train loop. 
36 | 
37 |     :param model: Model
38 |     :type model: class
39 |     :param optimizer: Optimizer
40 |     :type optimizer: class
41 |     :param criterion: Loss
42 |     :type criterion: class
43 |     :param train_loader: Data loader
44 |     :type train_loader: torch_geometric.loader.dataloader.DataLoader
45 |     :return: Mean absolute error (MAE) over the training batches
46 |     :rtype: numpy.float64
47 |     """
48 |     model.train()
49 |     training_loss = []
50 | 
51 |     for g in train_loader:
52 |         g = g.to(DEVICE)
53 |         optimizer.zero_grad()
54 | 
55 |         pred = model(g)
56 | 
57 |         loss = criterion(pred, g.rxn_trg)
58 |         loss.backward()
59 |         optimizer.step()
60 | 
61 |         with torch.no_grad():
62 |             mae = mae_loss(pred, g.rxn_trg)
63 |             training_loss.append(mae)
64 | 
65 |     return np.mean(training_loss)
66 | 
67 | 
68 | def eval(
69 |     model,
70 |     eval_loader,
71 | ):
72 |     """Validation & test loop.
73 | 
74 |     :param model: Model
75 |     :type model: class
76 |     :param eval_loader: Data loader
77 |     :type eval_loader: torch_geometric.loader.dataloader.DataLoader
78 |     :return: tuple including essential information to quantify network performance, such as MAE, predictions, labels etc.
79 |     :rtype: tuple
80 |     """
81 |     model.eval()
82 |     eval_loss = []
83 | 
84 |     preds = []
85 |     ys = []
86 | 
87 |     rxn_ids = []
88 |     atm_ids = []
89 |     pt_trgs = []
90 | 
91 |     with torch.no_grad():
92 |         for g in eval_loader:
93 |             g = g.to(DEVICE)
94 |             pred = model(g)
95 |             mae = mae_loss(pred, g.rxn_trg)
96 |             eval_loss.append(mae)
97 |             ys.append(g.rxn_trg)
98 |             preds.append(pred)
99 |             rxn_ids.append(g.rxn_id)
100 |             atm_ids.append(g.atom_id)
101 |             pt_trgs.append(g.pot_trg)
102 | 
103 |     return np.mean(eval_loss), ys, preds, rxn_ids, pt_trgs
104 | 
105 | 
106 | if __name__ == "__main__":
107 |     # python train.py -config 141 -mode a -cv 1 -early_stop 0
108 | 
109 |     # Make Folders for Results and Models
110 |     os.makedirs("results/", exist_ok=True)
111 |     os.makedirs("models/", exist_ok=True)
112 | 
113 |     # Read Passed Arguments
114 |     parser = argparse.ArgumentParser()
115 |     parser.add_argument("-config", type=str, default="100")
116 |     parser.add_argument("-mode", type=str, default="a")
117 |     parser.add_argument("-cv", type=str, default="1")
118 |     parser.add_argument("-early_stop", type=int, default=1)
119 |     args = parser.parse_args()
120 | 
121 |     # Define Configuration for Model and Dataset
122 |     config = configparser.ConfigParser()
123 |     CONFIG_PATH = os.path.join(UTILS_PATH, f"config/config_{str(args.config)}.ini")
124 |     config.read(CONFIG_PATH)
125 |     print({section: dict(config[section]) for section in config.sections()})
126 |     early_stop = True if args.early_stop >= 1 else False
127 | 
128 |     LR_FACTOR = float(config["PARAMS"]["LR_FACTOR"])
129 |     LR_STEP_SIZE = int(config["PARAMS"]["LR_STEP_SIZE"])
130 |     N_KERNELS = int(config["PARAMS"]["N_KERNELS"])
131 |     D_MLP = int(config["PARAMS"]["D_MLP"])
132 |     D_KERNEL = int(config["PARAMS"]["D_KERNEL"])
133 |     D_EMBEDDING = int(config["PARAMS"]["D_EMBEDDING"])
134 |     BATCH_SIZE = int(config["PARAMS"]["BATCH_SIZE"])
135 |     QML = int(config["PARAMS"]["QML"])
136 |     GEOMETRY = int(config["PARAMS"]["GEOMETRY"])
137 |     QML = True if QML >= 1 else False
138 |     GEOMETRY = True if GEOMETRY >= 1 else False
139 |     GRAPH_DIM = "edge_3d" if GEOMETRY >= 1 else "edge_2d"
140 | 
141 |     # Initialize Model
142 |     model = Atomistic_EGNN(
143 |         n_kernels=N_KERNELS,
144 |         mlp_dim=D_MLP,
145 |         kernel_dim=D_KERNEL,
146 |         embeddings_dim=D_EMBEDDING,
147 |         qml=QML,
148 |         geometry=GEOMETRY,
149 |     )
150 |     model = model.to(DEVICE)
151 | 
152 |     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
153 |     model_parameters = sum([np.prod(e.size()) for e in model_parameters])
154 |     print("\nmodel_parameters", model_parameters)
155 | 
156 |     optimizer = torch.optim.Adam(
157 |         model.parameters(),
158 |         lr=LR_FACTOR,
159 |         weight_decay=1e-10,
160 |     )
161 |     criterion = nn.MSELoss()
162 |     scheduler = torch.optim.lr_scheduler.StepLR(
163 |         optimizer,
164 |         step_size=LR_STEP_SIZE,
165 |         gamma=0.5,
166 |         verbose=False,
167 |     )
168 | 
169 |     # Neural Network Training
170 |     tr_losses = []
171 |     ev_losses = []
172 | 
173 |     if early_stop:
174 |         # Get Datasets
175 |         tran_ids, eval_ids, test_ids = get_rxn_ids(
176 |             data=RXN_DATA,
177 |         )
178 |         train_data = DataLSF(
179 |             rxn_ids=tran_ids,
180 |             data=RXN_DATA,
181 |             graph_dim=GRAPH_DIM,
182 |         )
183 |         train_loader = DataLoader(
184 |             train_data,
185 |             batch_size=BATCH_SIZE,
186 |             shuffle=True,
187 |             num_workers=2,
188 |         )
189 |         eval_data = DataLSF(
190 |             rxn_ids=eval_ids,
191 |             data=RXN_DATA,
192 |             graph_dim=GRAPH_DIM,
193 |         )
194 |         eval_loader = DataLoader(
195 |             eval_data,
196 |             batch_size=BATCH_SIZE,
197 |             shuffle=True,
198 |             num_workers=2,
199 |         )
200 |         test_data = DataLSF(
201 |             rxn_ids=test_ids,
202 |             data=RXN_DATA,
203 |             graph_dim=GRAPH_DIM,
204 |         )
205 |         test_loader = DataLoader(
206 |             test_data,
207 |             batch_size=BATCH_SIZE,
208 |             shuffle=True,
209 |             num_workers=2,
210 |         )
211 | 
212 |         # Training with Early Stopping
213 |         min_mae = 100
214 | 
215 |         for epoch in range(1000):
216 |             tr_l = train(model, optimizer, criterion, train_loader)
217 |             ev_l, ev_ys, ev_pred, ev_rxns, ev_pt_trgs = eval(model, eval_loader)
218 |             tr_losses.append(tr_l)
219 |             ev_losses.append(ev_l)
220 |             scheduler.step()
221 | 
222 |             print(epoch, tr_l, ev_l)
223 | 
224 |             if ev_l <= min_mae:
225 |                 # Define new min-loss
226 |                 min_mae = ev_l
227 | 
228 |                 # Test model
229 |                 te_l, te_ys, te_pred, te_rxns, te_pt_trgs = eval(model, test_loader)
230 | 
231 |                 ys_saved = [float(item) for sublist in te_ys for item in sublist]
232 |                 pred_saved = [float(item) for sublist in te_pred for item in sublist]
233 |                 rxns_saved = [str(item) for sublist in te_rxns for item in sublist]
234 |                 pt_trgs_saved = [int(item) for sublist in te_pt_trgs for item in sublist]
235 | 
236 |                 print(len(ys_saved), len(pred_saved), len(rxns_saved), len(pt_trgs_saved))
237 | 
238 |                 # Save Model and Save Loss + Predictions
239 |                 torch.save(model.state_dict(), f"models/config_{args.config}_{args.cv}.pt")
240 |                 torch.save(
241 |                     [tr_losses, ev_losses, ys_saved, pred_saved, rxns_saved, pt_trgs_saved],
242 |                     f"results/config_{args.config}_{args.cv}.pt",
243 |                 )
244 |     else:
245 |         # Get Datasets
246 |         tran_ids, eval_ids, test_ids = get_rxn_ids(
247 |             data=RXN_DATA,
248 |         )
249 |         tran_ids += eval_ids
250 |         train_data = DataLSF(
251 |             rxn_ids=tran_ids,
252 |             data=RXN_DATA,
253 |             graph_dim=GRAPH_DIM,
254 |         )
255 |         train_loader = DataLoader(
256 |             train_data,
257 |             batch_size=BATCH_SIZE,
258 |             shuffle=True,
259 |             num_workers=2,
260 |         )
261 |         test_data = DataLSF(
262 |             rxn_ids=test_ids,
263 |             data=RXN_DATA,
264 |             graph_dim=GRAPH_DIM,
265 |         )
266 |         test_loader = DataLoader(
267 |             test_data,
268 |             batch_size=BATCH_SIZE,
269 |             shuffle=True,
270 |             num_workers=2,
271 |         )
272 | 
273 |         # Training without Early Stopping
274 |         for epoch in range(1000):
275 |             tr_l = train(model, optimizer, criterion, train_loader)
276 |             tr_losses.append(tr_l)
277 |             scheduler.step()
278 | 
279 |             if epoch >= 999:
280 |                 # Test model
281 |                 te_l, te_ys, te_pred, te_rxns, te_pt_trgs = eval(model, test_loader)
282 | 
283 |                 ys_saved = [float(item) for sublist in te_ys for item in sublist]
284 |                 pred_saved = [float(item) for sublist in te_pred for item in sublist]
285 |                 rxns_saved = [str(item) for sublist in te_rxns for item in sublist]
286 |                 pt_trgs_saved = [int(item) for sublist in te_pt_trgs for item in sublist]
287 | 
288 |                 # Save Model and Save Loss + Predictions
289 |                 torch.save(model.state_dict(), f"models/config_{args.config}_{args.cv}.pt")
290 |                 torch.save(
291 |                     [tr_losses, ev_losses, ys_saved, pred_saved, rxns_saved, pt_trgs_saved],
292 |                     f"results/config_{args.config}_{args.cv}.pt",
293 |                 )
294 | 
--------------------------------------------------------------------------------
/lsfml/literature/rxnyield/net_utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import random
7 | 
8 | import h5py
9 | import numpy as np
10 | import torch
11 | from torch_geometric.data import Data
12 | from lsfml.modules.pygdataset import Dataset
13 | 
14 | random.seed(2)
15 | 
16 | 
17 | def get_rxn_ids(
18 |     data,
19 | ):
20 |     """Generates the data set split into training, validation and test sets.
21 | 
22 |     :param data: Path to h5 file, including preprocessed data, defaults to "../../data/literature_rxndata.h5"
23 |     :type data: str, optional
24 |     :return: Reaction IDs for training, validation and test split
25 |     :rtype: list[str]
26 |     """
27 |     # Load data from h5 file
28 |     h5f = h5py.File(data)
29 | 
30 |     # Load all rxn keys
31 |     rxn_ids = list(h5f.keys())
32 |     random.shuffle(rxn_ids)
33 | 
34 |     # Define subset of rxn keys
35 |     tran_ids = rxn_ids[: int(len(rxn_ids) / 2)]
36 |     eval_ids = rxn_ids[int(len(rxn_ids) / 2) : int(len(rxn_ids) / 4) * 3]
37 |     test_ids = rxn_ids[int(len(rxn_ids) / 4) * 3 :]
38 | 
39 |     return tran_ids, eval_ids, test_ids
40 | 
41 | 
42 | class DataLSF(Dataset):
43 |     def __init__(
44 |         self,
45 |         rxn_ids,
46 |         data,
47 |         graph_dim,
48 |         fingerprint,
49 |     ):
50 |         """Initialization.
51 | 
52 |         :param rxn_ids: Reaction IDs from the given split (train, eval, test)
53 |         :type rxn_ids: list[str]
54 |         :param data: Path to h5 file, including preprocessed data, defaults to "../../data/literature_rxndata.h5"
55 |         :type data: str, optional
56 |         :param graph_dim: Indicating 2D or 3D graph structure ("edge_2d" or "edge_3d"), defaults to "edge_2d"
57 |         :type graph_dim: str, optional
58 |         :param fingerprint: Indicating fingerprint type (ecfp4_2 or None), defaults to "ecfp4_2"
59 |         :type fingerprint: str, optional
60 |         """
61 |         # Define inputs
62 |         self.graph_dim = graph_dim
63 |         self.fingerprint = fingerprint
64 |         self.rxn_ids = rxn_ids
65 | 
66 |         # Load data from h5 file
67 |         self.h5f = h5py.File(data)
68 | 
69 |         # Generate dict (int to rxn keys)
70 |         nums = list(range(0, len(self.rxn_ids)))
71 |         self.idx2rxn = {}
72 |         for x in range(len(self.rxn_ids)):
73 |             self.idx2rxn[nums[x]] = self.rxn_ids[x]
74 | 
75 |         print("\nLoader initialized:")
76 |         print(f"Number of reactions loaded: {len(self.rxn_ids)}")
77 |         print(f"Chosen graph_dim (edge_2d or edge_3d): {self.graph_dim}")
78 |         print(f"Chosen fingerprint (ecfp4_2 or ecfp6_1): {self.fingerprint}")
79 | 
80 |     def __getitem__(self, idx):
81 |         """Get a single reaction graph by index.
82 | 
83 |         :param idx: Reaction ID
84 |         :type idx: str
85 |         :return: Input graph for the neural network.
86 |         :rtype: torch_geometric.data.Data
87 |         """
88 |         # int to rxn_id
89 |         rxn_id = self.idx2rxn[idx]
90 | 
91 |         # Molecule
92 |         atom_id = np.array(self.h5f[str(rxn_id)]["atom_id"])
93 |         ring_id = np.array(self.h5f[str(rxn_id)]["ring_id"])
94 |         hybr_id = np.array(self.h5f[str(rxn_id)]["hybr_id"])
95 |         arom_id = np.array(self.h5f[str(rxn_id)]["arom_id"])
96 |         charges = np.array(self.h5f[str(rxn_id)]["charges"])
97 |         crds_3d = np.array(self.h5f[str(rxn_id)]["crds_3d"])
98 | 
99 |         if self.fingerprint is not None:
100 |             ecfp_fp = np.array(self.h5f[str(rxn_id)][self.fingerprint])
101 |         else:
102 |             ecfp_fp = np.array([])
103 | 
104 |         # Edge IDs with desired dimension
105 |         edge_index = np.array(self.h5f[str(rxn_id)][self.graph_dim])
106 | 
107 |         # Conditions
108 |         rgnt_id = np.array(self.h5f[str(rxn_id)]["rgnt_id"])
109 |         lgnd_id = np.array(self.h5f[str(rxn_id)]["lgnd_id"])
110 |         clst_id = np.array(self.h5f[str(rxn_id)]["clst_id"])
111 |         slvn_id = np.array(self.h5f[str(rxn_id)]["slvn_id"])
112 | 
113 |         # Targets
114 |         rxn_trg = np.array(self.h5f[str(rxn_id)]["trg_rxn"])
115 | 
116 |         num_nodes = torch.LongTensor(atom_id).size(0)
117 | 
118 |         graph_data = Data(
119 |             atom_id=torch.LongTensor(atom_id),
120 |             ring_id=torch.LongTensor(ring_id),
121 |             hybr_id=torch.LongTensor(hybr_id),
122 |             arom_id=torch.LongTensor(arom_id),
123 |             rgnt_id=torch.LongTensor(rgnt_id),
124 |             lgnd_id=torch.LongTensor(lgnd_id),
125 |             clst_id=torch.LongTensor(clst_id),
126 |             slvn_id=torch.LongTensor(slvn_id),
127 |             charges=torch.FloatTensor(charges),
128 |             ecfp_fp=torch.FloatTensor(ecfp_fp),
129 |             crds_3d=torch.FloatTensor(crds_3d),
130 |             rxn_trg=torch.FloatTensor(rxn_trg),
131 |             edge_index=torch.LongTensor(edge_index),
132 |             num_nodes=num_nodes,
133 |         )
134 | 
135 |         return graph_data
136 | 
137 |     def __len__(self):
138 |         """Get length.
139 | 
140 |         :return: length
141 |         :rtype: int
142 |         """
143 |         return len(self.rxn_ids)
144 | 
--------------------------------------------------------------------------------
/lsfml/literature/rxnyield/preprocessh5.py:
--------------------------------------------------------------------------------
1 | #!
/usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import h5py, os 7 | import networkx as nx 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from rdkit import Chem 12 | from rdkit.Chem import AllChem, rdMolDescriptors 13 | from torch_geometric.data import Data 14 | from torch_geometric.utils import add_self_loops 15 | from torch_geometric.utils.undirected import to_undirected 16 | from tqdm import tqdm 17 | from scipy.spatial.distance import pdist, squareform 18 | 19 | from lsfml.qml.prod import get_model 20 | from lsfml.utils import ( 21 | get_dict_for_embedding, 22 | HYBRIDISATIONS, 23 | AROMATOCITY, 24 | IS_RING, 25 | ATOMTYPES, 26 | QML_ATOMTYPES, 27 | UTILS_PATH, 28 | ) 29 | 30 | QMLMODEL = get_model(gpu=False) 31 | 32 | 33 | HYBRIDISATION_DICT = get_dict_for_embedding(HYBRIDISATIONS) 34 | AROMATOCITY_DICT = get_dict_for_embedding(AROMATOCITY) 35 | IS_RING_DICT = get_dict_for_embedding(IS_RING) 36 | ATOMTYPE_DICT = get_dict_for_embedding(ATOMTYPES) 37 | QML_ATOMTYPE_DICT = get_dict_for_embedding(QML_ATOMTYPES) 38 | 39 | sol_dict = { 40 | "O1CCCC1": 0, 41 | "C=1C=C(C=CC1C)C": 1, 42 | "N#CC": 2, 43 | "CCCCCC": 3, 44 | "O(C)C1CCCC1": 4, 45 | "O(C)C(C)(C)C": 5, 46 | "C1CCCCC1": 6, 47 | "C1CCCCCCC1": 7, 48 | "ClCCCl": 8, 49 | } 50 | 51 | rea_dict = { 52 | "O1B(OC(C)(C)C1(C)C)B2OC(C)(C)C(O2)(C)C": 0, 53 | "O1BOC(C)(C)C1(C)C": 1, 54 | } 55 | 56 | cat_dict = { 57 | "[O-]1(C)[Ir+]234([O-](C)[Ir+]1567[CH]=8CC[CH]7=[CH]6CC[CH]85)[CH]=9CC[CH]4=[CH]3CC[CH]92": 0, 58 | "[Cl-]1[Ir+]234([Cl-][Ir+]1567[CH]=8CC[CH]7=[CH]6CC[CH]85)[CH]=9CC[CH]4=[CH]3CC[CH]92": 1, 59 | "[OH-]1[Ir+]234([OH-][Ir+]1567[CH]=8CC[CH]7=[CH]6CC[CH]85)[CH]=9CC[CH]4=[CH]3CC[CH]92": 2, 60 | "O1=C([CH-]C(=O[Ir+]1234[CH]=5CC[CH]4=[CH]3CC[CH]52)C)C": 3, 61 | } 62 | 63 | lig_dict = { 64 | "N=1C=CC(=CC1C=2N=CC=C(C2)C(C)(C)C)C(C)(C)C": 0, 65 | "O=C1C=CC=2C=CC=C(C3=CN=C(C=C3)C=4N=CC=CC4)C2N1": 1, 66 | "N=1C=C(C(=C2C=CC3=C(N=CC(=C3C)C)C12)C)C": 2, 67 | "O=S(=O)([O-])CC=1C=NC(=CC1)C2=NC=C(C=C2)C.CCCC[N+](CCCC)(CCCC)CCCC": 3, 68 | "O=C(NC=1C=CC=CC1C=2C=NC(=CC2)C3=NC=CC=C3)NC4CCCCC4": 4, 69 | "N=1C=CC=CC1N2B(NC=3C=CC=CC32)B4NC=5C=CC=CC5N4C6=NC=CC=C6": 5, 70 | "O=C(NC1=CC=CC2=C1NC(=C2C)C)C=3C=NC(=CC3)C4=NC=CC=C4": 6, 71 | "N=1C=CC=C2C=CC=3C=CC(=NC3C12)C": 7, 72 | "O(C1=CC=CC(=C1C=2C(OC)=CC=CC2P(C=3C=C(C=C(C3)C)C)C=4C=C(C=C(C4)C)C)P(C=5C=C(C=C(C5)C)C)C=6C=C(C=C(C6)C)C)C": 8, 73 | "N=1C=CC(=CC1C=2N=CC=C(C2)C)C": 9, 74 | "FC(F)(F)C1OB(OC1)C=2C=CC=CC2C=3C=NC(=CC3)C4=NC=CC=C4": 10, 75 | "N=1C=CC=CC1C=2N=CC=CC2": 11, 76 | } 77 | 78 | 79 | def get_info_from_smi(smi, radius): 80 | """Main function for extracting relevant reaction conditions and generating the 2D and 3D molecular graphs given a SMILES-string. 
81 | 
82 |     :param smi: SMILES-string
83 |     :type smi: str
84 |     :return: tuple including all graph-relevant numpy arrays
85 |     :rtype: tuple
86 |     """
87 | 
88 |     # Get mol objects from smiles
89 |     mol_no_Hs = Chem.MolFromSmiles(smi)
90 |     mol = Chem.rdmolops.AddHs(mol_no_Hs)
91 | 
92 |     ecfp4_fp = np.array(rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 2, nBits=256))
93 | 
94 |     atomids = []
95 |     qml_atomids = []
96 |     is_ring = []
97 |     hyb = []
98 |     arom = []
99 |     crds_3d = []
100 | 
101 |     AllChem.EmbedMolecule(mol, randomSeed=0xF00D)
102 |     AllChem.UFFOptimizeMolecule(mol)
103 | 
104 |     for idx, i in enumerate(mol.GetAtoms()):
105 |         atomids.append(ATOMTYPE_DICT[i.GetSymbol()])
106 |         qml_atomids.append(QML_ATOMTYPE_DICT[i.GetSymbol()])
107 |         is_ring.append(IS_RING_DICT[str(i.IsInRing())])
108 |         hyb.append(HYBRIDISATION_DICT[str(i.GetHybridization())])
109 |         arom.append(AROMATOCITY_DICT[str(i.GetIsAromatic())])
110 |         crds_3d.append(list(mol.GetConformer().GetAtomPosition(idx)))
111 | 
112 |     atomids = np.array(atomids)
113 |     qml_atomids = np.array(qml_atomids)
114 |     is_ring = np.array(is_ring)
115 |     hyb = np.array(hyb)
116 |     arom = np.array(arom)
117 |     crds_3d = np.array(crds_3d)
118 | 
119 |     # Edges for covalent bonds in sdf file
120 |     edge_dir1 = []
121 |     edge_dir2 = []
122 |     for idx, bond in enumerate(mol.GetBonds()):
123 |         a2 = bond.GetEndAtomIdx()
124 |         a1 = bond.GetBeginAtomIdx()
125 |         edge_dir1.append(a1)
126 |         edge_dir1.append(a2)
127 |         edge_dir2.append(a2)
128 |         edge_dir2.append(a1)
129 | 
130 |     edge_2d = torch.from_numpy(np.array([edge_dir1, edge_dir2]))
131 | 
132 |     # 3D graph for QML charge prediction
133 |     qml_atomids = torch.LongTensor(qml_atomids)
134 |     xyzs = torch.FloatTensor(crds_3d)
135 |     edge_index = np.array(nx.complete_graph(qml_atomids.size(0)).edges())
136 |     edge_index = to_undirected(torch.from_numpy(edge_index).t().contiguous())
137 |     edge_index, _ = add_self_loops(edge_index, num_nodes=crds_3d.shape[0])
138 | 
139 |     qml_graph = Data(
140 |         atomids=qml_atomids,
141 |         coords=xyzs,
142 |         edge_index=edge_index,
143 |         num_nodes=qml_atomids.size(0),
144 |     )
145 | 
146 |     charges = QMLMODEL(qml_graph).unsqueeze(1).detach().numpy()
147 | 
148 |     # Get edges for 3d graph
149 |     distance_matrix = squareform(pdist(crds_3d))
150 |     np.fill_diagonal(distance_matrix, float("inf"))  # to remove self-loops
151 |     edge_3d = torch.from_numpy(np.vstack(np.where(distance_matrix <= radius)))
152 | 
153 |     return (
154 |         atomids,
155 |         is_ring,
156 |         hyb,
157 |         arom,
158 |         charges,
159 |         edge_2d,
160 |         edge_3d,
161 |         crds_3d,
162 |         ecfp4_fp,
163 |     )
164 | 
165 | 
166 | if __name__ == "__main__":
167 |     df = pd.read_csv(os.path.join(UTILS_PATH, "data/literature_rxndata.csv"), encoding="unicode_escape")
168 | 
169 |     # Rxn id
170 |     rxn_id = list(df["rxn_id"])
171 | 
172 |     # Substrate
173 |     educt = list(df["stmat_1_smiles"])
174 | 
175 |     # Molecular conditions
176 |     catalyst = list(df["catalyst_1_smiles"])
177 |     catalyst_eq = list(df["catalyst_1_eq"])
178 |     ligand = list(df["ligand_1_smiles"])
179 |     ligand_eq = list(df["ligand_1_eq"])
180 |     reagent = list(df["reagent_1_smiles"])
181 |     reagent_eq = list(df["reagent_1_eq"])
182 |     solvent = list(df["solvent_1_smiles"])
183 |     solvent_ratio = list(df["solvent_1_fraction"])
184 | 
185 |     # Targets
186 |     trg = list(df["product_1_yield"])
187 | 
188 |     # Get molecule-dict for short rxids
189 |     print("Calculating properties for all substrates")
190 | 
191 |     wins = 0
192 |     loss = 0
193 | 
194 |     print(f"Transforming {len(rxn_id)} reactions into h5 format")
195 | 
196 |     h5_path = os.path.join(UTILS_PATH, "data/literature_rxndata.h5")
197 | 
198 |     with h5py.File(h5_path, "w") as lsf_container:
199 |         for idx, rxn_key in enumerate(tqdm(rxn_id)):
200 |             try:
201 |                 (
202 |                     atom_id,
203 |                     ring_id,
204 |                     hybr_id,
205 |                     arom_id,
206 |                     charges,
207 |                     edge_2d,
208 |                     edge_3d,
209 |                     crds_3d,
210 |                     ecfp4_2,
211 |                 ) = get_info_from_smi(educt[idx], 4)
212 | 
213 |                 rgnt_id = rea_dict[reagent[idx]]
214 |                 lgnd_id = lig_dict[ligand[idx]]
215 |                 clst_id = cat_dict[catalyst[idx]]
216 |                 slvn_id = sol_dict[solvent[idx]]
217 |                 trg_rxn = trg[idx]
218 | 
219 |                 if np.isnan(trg_rxn):
220 |                     trg_rxn = 0.0
221 | 
222 |                 # Create group in h5 for this id
223 |                 lsf_container.create_group(str(rxn_key))
224 | 
225 |                 # Molecule
226 |                 lsf_container[str(rxn_key)].create_dataset("atom_id", data=atom_id)
227 |                 lsf_container[str(rxn_key)].create_dataset("ring_id", data=ring_id)
228 |                 lsf_container[str(rxn_key)].create_dataset("hybr_id", data=hybr_id)
229 |                 lsf_container[str(rxn_key)].create_dataset("arom_id", data=arom_id)
230 |                 lsf_container[str(rxn_key)].create_dataset("edge_2d", data=edge_2d)
231 |                 lsf_container[str(rxn_key)].create_dataset("edge_3d", data=edge_3d)
232 |                 lsf_container[str(rxn_key)].create_dataset("charges", data=charges)
233 |                 lsf_container[str(rxn_key)].create_dataset("crds_3d", data=crds_3d)
234 |                 lsf_container[str(rxn_key)].create_dataset("ecfp4_2", data=[ecfp4_2])
235 | 
236 |                 # Conditions
237 |                 lsf_container[str(rxn_key)].create_dataset("rgnt_id", data=[int(rgnt_id)])
238 |                 lsf_container[str(rxn_key)].create_dataset("lgnd_id", data=[int(lgnd_id)])
239 |                 lsf_container[str(rxn_key)].create_dataset("clst_id", data=[int(clst_id)])
240 |                 lsf_container[str(rxn_key)].create_dataset("slvn_id", data=[int(slvn_id)])
241 | 
242 |                 # Target
243 |                 lsf_container[str(rxn_key)].create_dataset("trg_rxn", data=[trg_rxn])
244 | 
245 |                 wins += 1
246 | 
247 |             except Exception:
248 |                 loss += 1
249 | 
250 |     print(f"Reactions successfully transformed: {wins}; Reactions failed: {loss}")
251 | 
--------------------------------------------------------------------------------
/lsfml/literature/rxnyield/train.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import argparse
7 | import configparser
8 | import os
9 | 
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch_geometric.loader import DataLoader
15 | 
16 | from lsfml.literature.rxnyield.net import (
17 |     EGNN,
18 |     FNN,
19 |     GraphTransformer,
20 | )
21 | from lsfml.literature.rxnyield.net_utils import (
22 |     DataLSF,
23 |     get_rxn_ids,
24 | )
25 | 
26 | from lsfml.utils import mae_loss, UTILS_PATH
27 | 
28 | 
29 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30 | RXN_DATA = os.path.join(UTILS_PATH, "data/literature_rxndata.h5")
31 | 
32 | 
33 | def train(
34 |     model,
35 |     optimizer,
36 |     criterion,
37 |     train_loader,
38 | ):
39 |     """Train loop.
40 | 
41 |     :param model: Model
42 |     :type model: class
43 |     :param optimizer: Optimizer
44 |     :type optimizer: class
45 |     :param criterion: Loss
46 |     :type criterion: class
47 |     :param train_loader: Data loader
48 |     :type train_loader: torch_geometric.loader.dataloader.DataLoader
49 |     :return: Mean absolute error (MAE) over the training batches
50 |     :rtype: numpy.float64
51 |     """
52 |     model.train()
53 |     training_loss = []
54 | 
55 |     for g in train_loader:
56 |         g = g.to(DEVICE)
57 |         optimizer.zero_grad()
58 | 
59 |         pred = model(g)
60 | 
61 |         loss = criterion(pred, g.rxn_trg)
62 |         loss.backward()
63 |         optimizer.step()
64 | 
65 |         with torch.no_grad():
66 |             mae = mae_loss(pred, g.rxn_trg)
67 |             training_loss.append(mae)
68 | 
69 |     return np.mean(training_loss)
70 | 
71 | 
72 | def eval(
73 |     model,
74 |     eval_loader,
75 | ):
76 |     """Validation & test loop.
77 | 
78 |     :param model: Model
79 |     :type model: class
80 |     :param eval_loader: Data loader
81 |     :type eval_loader: torch_geometric.loader.dataloader.DataLoader
82 |     :return: tuple including essential information to quantify network performance, such as MAE, predictions, labels etc.
83 |     :rtype: tuple
84 |     """
85 |     model.eval()
86 |     eval_loss = []
87 | 
88 |     preds = []
89 |     ys = []
90 | 
91 |     with torch.no_grad():
92 |         for g in eval_loader:
93 |             g = g.to(DEVICE)
94 |             pred = model(g)
95 |             mae = mae_loss(pred, g.rxn_trg)
96 |             eval_loss.append(mae)
97 |             ys.append(g.rxn_trg)
98 |             preds.append(pred)
99 | 
100 |     return np.mean(eval_loss), ys, preds
101 | 
102 | 
103 | if __name__ == "__main__":
104 |     # python train.py -config 420 -mode a -cv 1 -early_stop 0
105 | 
106 |     # Make Folders for Results and Models
107 |     os.makedirs("results/", exist_ok=True)
108 |     os.makedirs("models/", exist_ok=True)
109 | 
110 |     # Read Passed Arguments
111 |     parser = argparse.ArgumentParser()
112 |     parser.add_argument("-config", type=str, default="100")
113 |     parser.add_argument("-mode", type=str, default="a")
114 |     parser.add_argument("-cv", type=str, default="1")
115 |     parser.add_argument("-early_stop", type=int, default=1)
116 |     args = parser.parse_args()
117 | 
118 |     # Define Configuration for Model and Dataset
119 |     config = configparser.ConfigParser()
120 |     CONFIG_PATH = os.path.join(UTILS_PATH, f"config/config_{str(args.config)}.ini")
121 |     config.read(CONFIG_PATH)
122 |     print({section: dict(config[section]) for section in config.sections()})
123 |     early_stop = True if args.early_stop >= 1 else False
124 | 
125 |     LR_FACTOR = float(config["PARAMS"]["LR_FACTOR"])
126 |     LR_STEP_SIZE = int(config["PARAMS"]["LR_STEP_SIZE"])
127 |     N_KERNELS = int(config["PARAMS"]["N_KERNELS"])
128 |     POOLING_HEADS = int(config["PARAMS"]["POOLING_HEADS"])
129 |     D_MLP = int(config["PARAMS"]["D_MLP"])
130 |     D_KERNEL = int(config["PARAMS"]["D_KERNEL"])
131 |     D_EMBEDDING = int(config["PARAMS"]["D_EMBEDDING"])
132 |     BATCH_SIZE = int(config["PARAMS"]["BATCH_SIZE"])
133 |     QML = int(config["PARAMS"]["QML"])
134 |     GEOMETRY = int(config["PARAMS"]["GEOMETRY"])
135 |     FINGERPRINT = str(config["PARAMS"]["FINGERPRINT"])
136 |     QML = True if QML >= 1 else False
137 |     GEOMETRY = True if GEOMETRY >= 1 else False
138 |     GRAPH_DIM = "edge_3d" if GEOMETRY >= 1 else "edge_2d"
139 |     FINGERPRINT = FINGERPRINT if args.mode == "c" else None
140 |     FP_DIM = 1024 if FINGERPRINT == "ecfp6_1" else 256
141 | 
142 |     # Initialize Model
143 |     if args.mode == "a":
144 |         model = GraphTransformer(
145 |             n_kernels=N_KERNELS,
146 |             pooling_heads=POOLING_HEADS,
147 |             mlp_dim=D_MLP,
148 |             kernel_dim=D_KERNEL,
149 |             embeddings_dim=D_EMBEDDING,
150 |             qml=QML,
151 |             geometry=GEOMETRY,
152 |         )
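    # args.mode selects the architecture: "a" = GraphTransformer (GNN with graph multiset pooling),
    # "b" = EGNN, "c" = fingerprint-based FNN (ECFP input).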
153 |     elif args.mode == "b":
154 |         model = EGNN(
155 |             n_kernels=N_KERNELS,
156 |             mlp_dim=D_MLP,
157 |             kernel_dim=D_KERNEL,
158 |             embeddings_dim=D_EMBEDDING,
159 |             qml=QML,
160 |             geometry=GEOMETRY,
161 |         )
162 |     elif args.mode == "c":
163 |         model = FNN(
164 |             fp_dim=FP_DIM,
165 |             mlp_dim=D_MLP,
166 |             kernel_dim=D_KERNEL,
167 |             embeddings_dim=D_EMBEDDING,
168 |         )
169 | 
170 |     model = model.to(DEVICE)
171 | 
172 |     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
173 |     model_parameters = sum([np.prod(e.size()) for e in model_parameters])
174 |     print("\nmodel_parameters", model_parameters)
175 | 
176 |     optimizer = torch.optim.Adam(model.parameters(), lr=LR_FACTOR, weight_decay=1e-10)
177 |     criterion = nn.MSELoss()
178 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=0.5, verbose=False)
179 | 
180 |     # Neural Network Training
181 |     tr_losses = []
182 |     ev_losses = []
183 | 
184 |     if early_stop:
185 |         # Get Datasets
186 |         tran_ids, eval_ids, test_ids = get_rxn_ids(
187 |             data=RXN_DATA,
188 |         )
189 |         train_data = DataLSF(
190 |             rxn_ids=tran_ids,
191 |             data=RXN_DATA,
192 |             graph_dim=GRAPH_DIM,
193 |             fingerprint=FINGERPRINT,
194 |         )
195 |         train_loader = DataLoader(
196 |             train_data,
197 |             batch_size=BATCH_SIZE,
198 |             shuffle=True,
199 |             num_workers=2,
200 |         )
201 |         eval_data = DataLSF(
202 |             rxn_ids=eval_ids,
203 |             data=RXN_DATA,
204 |             graph_dim=GRAPH_DIM,
205 |             fingerprint=FINGERPRINT,
206 |         )
207 |         eval_loader = DataLoader(
208 |             eval_data,
209 |             batch_size=BATCH_SIZE,
210 |             shuffle=True,
211 |             num_workers=2,
212 |         )
213 |         test_data = DataLSF(
214 |             rxn_ids=test_ids,
215 |             data=RXN_DATA,
216 |             graph_dim=GRAPH_DIM,
217 |             fingerprint=FINGERPRINT,
218 |         )
219 |         test_loader = DataLoader(
220 |             test_data,
221 |             batch_size=BATCH_SIZE,
222 |             shuffle=True,
223 |             num_workers=2,
224 |         )
225 | 
226 |         # Training with Early Stopping
227 |         min_mae = 1000
228 | 
229 |         for epoch in range(1000):
230 |             # Training and Eval Loops
231 |             tr_l = train(model, optimizer, criterion, train_loader)
232 |             ev_l, ev_ys, ev_pred = eval(model, eval_loader)
233 |             tr_losses.append(tr_l)
234 |             ev_losses.append(ev_l)
235 |             scheduler.step()
236 | 
237 |             if ev_l <= min_mae:
238 |                 # Define new min-loss
239 |                 min_mae = ev_l
240 | 
241 |                 # Test model
242 |                 te_l, te_ys, te_pred = eval(model, test_loader)
243 | 
244 |                 ys_saved = [item for sublist in te_ys for item in sublist]
245 |                 pred_saved = [item for sublist in te_pred for item in sublist]
246 | 
247 |                 # Save Model and Save Loss + Predictions
248 |                 torch.save(model.state_dict(), f"models/config_{str(args.config)}_{str(args.mode)}_{args.cv}.pt")
249 |                 torch.save(
250 |                     [tr_losses, ev_losses, ys_saved, pred_saved],
251 |                     f"results/config_{str(args.config)}_{str(args.mode)}_{args.cv}.pt",
252 |                 )
253 | 
254 |     else:
255 |         # Get Datasets
256 |         tran_ids, eval_ids, test_ids = get_rxn_ids(
257 |             data=RXN_DATA,
258 |         )
259 |         tran_ids += eval_ids
260 |         train_data = DataLSF(
261 |             rxn_ids=tran_ids,
262 |             data=RXN_DATA,
263 |             graph_dim=GRAPH_DIM,
264 |             fingerprint=FINGERPRINT,
265 |         )
266 |         train_loader = DataLoader(
267 |             train_data,
268 |             batch_size=BATCH_SIZE,
269 |             shuffle=True,
270 |             num_workers=2,
271 |         )
272 |         test_data = DataLSF(
273 |             rxn_ids=test_ids,
274 |             data=RXN_DATA,
275 |             graph_dim=GRAPH_DIM,
276 |             fingerprint=FINGERPRINT,
277 |         )
278 |         test_loader = DataLoader(
279 |             test_data,
280 |             batch_size=BATCH_SIZE,
281 |             shuffle=True,
282 |             num_workers=2,
283 |         )
284 | 
285 |         # Training without Early Stopping
286 |         for epoch in range(1000):
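            # Without early stopping, the model trains for the full 1000 epochs (validation data
            # folded into training above) and is tested once after the final epoch.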
287 |             # Training Loop
288 |             tr_l = train(model, optimizer, criterion, train_loader)
289 |             tr_losses.append(tr_l)
290 |             scheduler.step()
291 | 
292 |             if epoch >= 999:
293 |                 # Test model
294 |                 te_l, te_ys, te_pred = eval(model, test_loader)
295 | 
296 |                 ys_saved = [item for sublist in te_ys for item in sublist]
297 |                 pred_saved = [item for sublist in te_pred for item in sublist]
298 | 
299 |                 # Save Model and Save Loss + Predictions
300 |                 torch.save(model.state_dict(), f"models/config_{str(args.config)}_{str(args.mode)}_{args.cv}.pt")
301 |                 torch.save(
302 |                     [tr_losses, ev_losses, ys_saved, pred_saved],
303 |                     f"results/config_{str(args.config)}_{str(args.mode)}_{args.cv}.pt",
304 |                 )
305 | 
--------------------------------------------------------------------------------
/lsfml/modules/gmt.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Original Paper: The global Graph Multiset Transformer pooling operator from the "Accurate Learning
5 | # of Graph Representations with Graph Multiset Pooling" paper (https://arxiv.org/abs/2102.11533).
6 | # Code adapted from PyTorch Geometric:
7 | # https://pytorch-geometric.readthedocs.io/en/2.0.3/_modules/torch_geometric/nn/glob/gmt.html
8 | 
9 | import math
10 | from typing import List, Optional, Tuple, Type
11 | 
12 | import torch
13 | from torch import Tensor
14 | from torch.nn import LayerNorm, Linear
15 | from torch_geometric.nn import GCNConv
16 | from torch_geometric.utils import to_dense_batch
17 | 
18 | 
19 | class MAB(torch.nn.Module):
20 |     def __init__(
21 |         self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, Conv: Optional[Type] = None, layer_norm: bool = False
22 |     ):
23 |         super().__init__()
24 |         self.dim_V = dim_V
25 |         self.num_heads = num_heads
26 |         self.layer_norm = layer_norm
27 | 
28 |         self.fc_q = Linear(dim_Q, dim_V)
29 | 
30 |         if Conv is None:
31 |             self.layer_k = Linear(dim_K, dim_V)
32 |             self.layer_v = Linear(dim_K, dim_V)
33 |         else:
34 |             self.layer_k = Conv(dim_K, dim_V)
35 |             self.layer_v = Conv(dim_K, dim_V)
36 | 
37 |         if layer_norm:
38 |             self.ln0 = LayerNorm(dim_V)
39 |             self.ln1 = LayerNorm(dim_V)
40 | 
41 |         self.fc_o = Linear(dim_V, dim_V)
42 | 
43 |     def reset_parameters(self):
44 |         self.fc_q.reset_parameters()
45 |         self.layer_k.reset_parameters()
46 |         self.layer_v.reset_parameters()
47 |         if self.layer_norm:
48 |             self.ln0.reset_parameters()
49 |             self.ln1.reset_parameters()
50 |         self.fc_o.reset_parameters()
51 | 
52 | 
53 |     def forward(
54 |         self,
55 |         Q: Tensor,
56 |         K: Tensor,
57 |         graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
58 |         mask: Optional[Tensor] = None,
59 |     ) -> Tensor:
60 |         Q = self.fc_q(Q)
61 | 
62 |         if graph is not None:
63 |             x, edge_index, batch = graph
64 |             K, V = self.layer_k(x, edge_index), self.layer_v(x, edge_index)
65 |             K, _ = to_dense_batch(K, batch)
66 |             V, _ = to_dense_batch(V, batch)
67 |         else:
68 |             K, V = self.layer_k(K), self.layer_v(K)
69 | 
70 |         dim_split = self.dim_V // self.num_heads
71 |         Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
72 |         K_ = torch.cat(K.split(dim_split, 2), dim=0)
73 |         V_ = torch.cat(V.split(dim_split, 2), dim=0)
74 | 
75 |         if mask is not None:
76 |             mask = torch.cat([mask for _ in range(self.num_heads)], 0)
77 |             attention_score = Q_.bmm(K_.transpose(1, 2))
78 |             attention_score = attention_score / math.sqrt(self.dim_V)
79 |             A = torch.softmax(mask + attention_score, 1)
80 |         else:
81 |             A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 1)
82 | 
83 |         out = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
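84 |         # Re-merge the attention heads: the head-stacked batch dimension is split back
85 |         # and concatenated along the feature dimension, restoring (batch, num_queries, dim_V).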
86 |         if self.layer_norm:
87 |             out = self.ln0(out)
88 | 
89 |         out = out + self.fc_o(out).relu()
90 | 
91 |         if self.layer_norm:
92 |             out = self.ln1(out)
93 | 
94 |         return out
95 | 
96 | 
97 | class SAB(torch.nn.Module):
98 |     def __init__(
99 |         self,
100 |         in_channels: int,
101 |         out_channels: int,
102 |         num_heads: int,
103 |         Conv: Optional[Type] = None,
104 |         layer_norm: bool = False,
105 |     ):
106 |         super().__init__()
107 |         self.mab = MAB(in_channels, in_channels, out_channels, num_heads, Conv=Conv, layer_norm=layer_norm)
108 | 
109 |     def reset_parameters(self):
110 |         self.mab.reset_parameters()
111 | 
112 |     def forward(
113 |         self,
114 |         x: Tensor,
115 |         graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
116 |         mask: Optional[Tensor] = None,
117 |     ) -> Tensor:
118 |         return self.mab(x, x, graph, mask)
119 | 
120 | 
121 | class PMA(torch.nn.Module):
122 |     def __init__(
123 |         self, channels: int, num_heads: int, num_seeds: int, Conv: Optional[Type] = None, layer_norm: bool = False
124 |     ):
125 |         super().__init__()
126 |         self.S = torch.nn.Parameter(torch.Tensor(1, num_seeds, channels))
127 |         self.mab = MAB(channels, channels, channels, num_heads, Conv=Conv, layer_norm=layer_norm)
128 | 
129 |         self.reset_parameters()
130 | 
131 |     def reset_parameters(self):
132 |         torch.nn.init.xavier_uniform_(self.S)
133 |         self.mab.reset_parameters()
134 | 
135 |     def forward(
136 |         self,
137 |         x: Tensor,
138 |         graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
139 |         mask: Optional[Tensor] = None,
140 |     ) -> Tensor:
141 |         return self.mab(self.S.repeat(x.size(0), 1, 1), x, graph, mask)
142 | 
143 | 
144 | class GraphMultisetTransformer(torch.nn.Module):
145 |     r"""The global Graph Multiset Transformer pooling operator from the
146 |     `"Accurate Learning of Graph Representations
147 |     with Graph Multiset Pooling" <https://arxiv.org/abs/2102.11533>`_ paper.
148 |     The Graph Multiset Transformer clusters nodes of the entire graph via
149 |     attention-based pooling operations (:obj:`"GMPool_G"` or
150 |     :obj:`"GMPool_I"`).
151 |     In addition, self-attention (:obj:`"SelfAtt"`) can be used to calculate
152 |     the inter-relationships among nodes.
153 |     Args:
154 |         in_channels (int): Size of each input sample.
155 |         hidden_channels (int): Size of each hidden sample.
156 |         out_channels (int): Size of each output sample.
157 |         conv (Type, optional): A graph neural network layer
158 |             for calculating hidden representations of nodes for
159 |             :obj:`"GMPool_G"` (one of
160 |             :class:`~torch_geometric.nn.conv.GCNConv`,
161 |             :class:`~torch_geometric.nn.conv.GraphConv` or
162 |             :class:`~torch_geometric.nn.conv.GATConv`).
163 |             (default: :class:`~torch_geometric.nn.conv.GCNConv`)
164 |         num_nodes (int, optional): The number of average
165 |             or maximum nodes. (default: :obj:`300`)
166 |         pooling_ratio (float, optional): Graph pooling ratio
167 |             for each pooling. (default: :obj:`0.25`)
168 |         pool_sequences ([str], optional): A sequence of pooling layers
169 |             consisting of Graph Multiset Transformer submodules (one of
170 |             :obj:`["GMPool_I"]`,
171 |             :obj:`["GMPool_G"]`,
172 |             :obj:`["GMPool_G", "GMPool_I"]`,
173 |             :obj:`["GMPool_G", "SelfAtt", "GMPool_I"]` or
174 |             :obj:`["GMPool_G", "SelfAtt", "SelfAtt", "GMPool_I"]`).
175 |             (default: :obj:`["GMPool_G", "SelfAtt", "GMPool_I"]`)
176 |         num_heads (int, optional): Number of attention heads.
177 |             (default: :obj:`4`)
178 |         layer_norm (bool, optional): If set to :obj:`True`, will make use of
179 |             layer normalization.
(default: :obj:`False`) 179 | """ 180 | 181 | def __init__( 182 | self, 183 | in_channels: int, 184 | hidden_channels: int, 185 | out_channels: int, 186 | Conv: Optional[Type] = None, 187 | num_nodes: int = 300, 188 | pooling_ratio: float = 0.25, 189 | pool_sequences: List[str] = ["GMPool_G", "SelfAtt", "GMPool_I"], 190 | num_heads: int = 4, 191 | layer_norm: bool = False, 192 | ): 193 | super().__init__() 194 | self.in_channels = in_channels 195 | self.hidden_channels = hidden_channels 196 | self.out_channels = out_channels 197 | self.Conv = Conv or GCNConv 198 | self.num_nodes = num_nodes 199 | self.pooling_ratio = pooling_ratio 200 | self.pool_sequences = pool_sequences 201 | self.num_heads = num_heads 202 | self.layer_norm = layer_norm 203 | 204 | self.lin1 = Linear(in_channels, hidden_channels) 205 | self.lin2 = Linear(hidden_channels, out_channels) 206 | 207 | self.pools = torch.nn.ModuleList() 208 | num_out_nodes = math.ceil(num_nodes * pooling_ratio) 209 | for i, pool_type in enumerate(pool_sequences): 210 | if pool_type not in ["GMPool_G", "GMPool_I", "SelfAtt"]: 211 | raise ValueError( 212 | "Elements in 'pool_sequences' should be one " "of 'GMPool_G', 'GMPool_I', or 'SelfAtt'" 213 | ) 214 | 215 | if i == len(pool_sequences) - 1: 216 | num_out_nodes = 1 217 | 218 | if pool_type == "GMPool_G": 219 | self.pools.append( 220 | PMA(hidden_channels, num_heads, num_out_nodes, Conv=self.Conv, layer_norm=layer_norm) 221 | ) 222 | num_out_nodes = math.ceil(num_out_nodes * self.pooling_ratio) 223 | 224 | elif pool_type == "GMPool_I": 225 | self.pools.append(PMA(hidden_channels, num_heads, num_out_nodes, Conv=None, layer_norm=layer_norm)) 226 | num_out_nodes = math.ceil(num_out_nodes * self.pooling_ratio) 227 | 228 | elif pool_type == "SelfAtt": 229 | self.pools.append(SAB(hidden_channels, hidden_channels, num_heads, Conv=None, layer_norm=layer_norm)) 230 | 231 | def reset_parameters(self): 232 | self.lin1.reset_parameters() 233 | self.lin2.reset_parameters() 234 | for pool in self.pools: 235 | pool.reset_parameters() 236 | 237 | def forward(self, x: Tensor, batch: Tensor, edge_index: Optional[Tensor] = None) -> Tensor: 238 | """""" 239 | x = self.lin1(x) 240 | batch_x, mask = to_dense_batch(x, batch) 241 | mask = (~mask).unsqueeze(1).to(dtype=x.dtype) * -1e9 242 | 243 | for i, (name, pool) in enumerate(zip(self.pool_sequences, self.pools)): 244 | graph = (x, edge_index, batch) if name == "GMPool_G" else None 245 | batch_x = pool(batch_x, graph, mask) 246 | mask = None 247 | 248 | return self.lin2(batch_x.squeeze(1)) 249 | 250 | def __repr__(self) -> str: 251 | return ( 252 | f"{self.__class__.__name__}({self.in_channels}, " 253 | f"{self.out_channels}, pool_sequences={self.pool_sequences})" 254 | ) 255 | -------------------------------------------------------------------------------- /lsfml/modules/gnn_blocks.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 3 | # 4 | # Copyright (©) 2023, ETH Zurich 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch_geometric.nn import MessagePassing 9 | from torch_geometric.typing import Adj, Size, Tensor 10 | from typing import Optional 11 | 12 | 13 | def weights_init(m): 14 | """Xavier uniform weight initialization. 15 | 16 | :param m: A list of learnable linear PyTorch modules. 
17 |     :type m: [torch.nn.modules.linear.Linear]
18 |     """
19 |     if isinstance(m, nn.Linear):
20 |         nn.init.xavier_uniform_(m.weight)
21 |         nn.init.zeros_(m.bias)
22 | 
23 | 
24 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
25 |     if dim < 0:
26 |         dim = other.dim() + dim
27 |     if src.dim() == 1:
28 |         for _ in range(0, dim):
29 |             src = src.unsqueeze(0)
30 |     for _ in range(src.dim(), other.dim()):
31 |         src = src.unsqueeze(-1)
32 |     src = src.expand_as(other)
33 |     return src
34 | 
35 | 
36 | def scatter_sum(
37 |     src: torch.Tensor,
38 |     index: torch.Tensor,
39 |     dim: int = -1,
40 |     out: Optional[torch.Tensor] = None,
41 |     dim_size: Optional[int] = None,
42 | ) -> torch.Tensor:
43 |     index = broadcast(index, src, dim)
44 |     if out is None:
45 |         size = list(src.size())
46 |         if dim_size is not None:
47 |             size[dim] = dim_size
48 |         elif index.numel() == 0:
49 |             size[dim] = 0
50 |         else:
51 |             size[dim] = int(index.max()) + 1
52 |         out = torch.zeros(size, dtype=src.dtype, device=src.device)
53 |         return out.scatter_add_(dim, index, src)
54 |     else:
55 |         return out.scatter_add_(dim, index, src)
56 | 
57 | 
58 | class EGNN_sparse(MessagePassing):
59 |     """torch geometric message-passing layer for 2D molecular graphs."""
60 | 
61 |     def __init__(self, feats_dim, m_dim=32, dropout=0.1, aggr="add", **kwargs):
62 |         """Initialization of the 2D message passing layer.
63 | 
64 |         :param feats_dim: Node feature dimension.
65 |         :type feats_dim: int
66 |         :param m_dim: Message passing feature dimension, defaults to 32
67 |         :type m_dim: int, optional
68 |         :param dropout: Dropout value, defaults to 0.1
69 |         :type dropout: float, optional
70 |         :param aggr: Message aggregation type, defaults to "add"
71 |         :type aggr: str, optional
72 |         """
73 |         assert aggr in {
74 |             "add",
75 |             "sum",
76 |             "max",
77 |             "mean",
78 |         }, "pool method must be a valid option"
79 | 
80 |         kwargs.setdefault("aggr", aggr)
81 |         super(EGNN_sparse, self).__init__(**kwargs)
82 | 
83 |         self.feats_dim = feats_dim
84 |         self.m_dim = m_dim
85 | 
86 |         self.edge_input_dim = feats_dim * 2
87 | 
88 |         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
89 | 
90 |         self.edge_norm1 = nn.LayerNorm(m_dim)
91 |         self.edge_norm2 = nn.LayerNorm(m_dim)
92 | 
93 |         self.edge_mlp = nn.Sequential(
94 |             nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
95 |             self.dropout,
96 |             nn.SiLU(),
97 |             nn.Linear(self.edge_input_dim * 2, m_dim),
98 |             nn.SiLU(),
99 |         )
100 | 
101 |         self.node_norm1 = nn.LayerNorm(feats_dim)
102 |         self.node_norm2 = nn.LayerNorm(feats_dim)
103 | 
104 |         self.node_mlp = nn.Sequential(
105 |             nn.Linear(feats_dim + m_dim, feats_dim * 2),
106 |             self.dropout,
107 |             nn.SiLU(),
108 |             nn.Linear(feats_dim * 2, feats_dim),
109 |         )
110 | 
111 |         self.apply(self.init_)
112 | 
113 |     def init_(self, module):
114 |         if type(module) in {nn.Linear}:
115 |             nn.init.xavier_normal_(module.weight)
116 |             nn.init.zeros_(module.bias)
117 | 
118 |     def forward(
119 |         self,
120 |         x: Tensor,
121 |         edge_index: Adj,
122 |     ):
123 |         """Forward pass of the message passing function.
124 | 
125 |         :param x: Node features.
126 |         :type x: Tensor
127 |         :param edge_index: Edge indices.
128 |         :type edge_index: Adj
129 |         :return: Updated node features.
130 |         :rtype: Tensor
131 |         """
132 |         hidden_out = self.propagate(edge_index, x=x)
133 | 
134 |         return hidden_out
135 | 
136 |     def message(self, x_i, x_j):
137 |         """Message passing.
138 | 
139 |         :param x_i: Node n_i.
140 |         :type x_i: Tensor
141 |         :param x_j: Node n_j.
142 |         :type x_j: Tensor
143 |         :return: Message m_ij
144 |         :rtype: Tensor
145 |         """
146 |         m_ij = self.edge_mlp(torch.cat([x_i, x_j], dim=-1))
147 |         return m_ij
148 | 
149 |     def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
150 |         """Overall propagation within the message passing.
151 | 
152 |         :param edge_index: Edge indices.
153 |         :type edge_index: Adj
154 |         :return: Updated node features.
155 |         :rtype: Tensor
156 |         """
157 |         # get input tensors
158 |         size = self._check_input(edge_index, size)
159 |         coll_dict = self._collect(self._user_args, edge_index, size, kwargs)
160 |         msg_kwargs = self.inspector.distribute("message", coll_dict)
161 |         aggr_kwargs = self.inspector.distribute("aggregate", coll_dict)
162 |         update_kwargs = self.inspector.distribute("update", coll_dict)
163 | 
164 |         # get messages
165 |         m_ij = self.message(**msg_kwargs)
166 |         m_ij = self.edge_norm1(m_ij)
167 | 
168 |         # aggregate messages
169 |         m_i = self.aggregate(m_ij, **aggr_kwargs)
170 |         m_i = self.edge_norm2(m_i)
171 | 
172 |         # get updated node features
173 |         hidden_feats = self.node_norm1(kwargs["x"])
174 |         hidden_out = self.node_mlp(torch.cat([hidden_feats, m_i], dim=-1))
175 |         hidden_out = self.node_norm2(hidden_out)
176 |         hidden_out = kwargs["x"] + hidden_out
177 | 
178 |         return self.update((hidden_out), **update_kwargs)
179 | 
180 | 
181 | def fourier_encode_dist(x, num_encodings=4, include_self=True):
182 |     """Encodes Euclidean interatomic distances into Fourier features.
183 | 
184 |     :param x: Distances in Ångström.
185 |     :type x: Tensor
186 |     :param num_encodings: Number of sine and cosine functions, defaults to 4
187 |     :type num_encodings: int, optional
188 |     :param include_self: Option to include absolute distance, defaults to True
189 |     :type include_self: bool, optional
190 |     :return: Fourier features.
191 |     :rtype: Tensor
192 |     """
193 |     x = x.unsqueeze(-1)
194 |     device, dtype, orig_x = x.device, x.dtype, x
195 |     scales = 2 ** torch.arange(num_encodings, device=device, dtype=dtype)
196 |     x = x / scales
197 |     x = torch.cat([x.sin(), x.cos()], dim=-1)
198 |     x = torch.cat((x, orig_x), dim=-1) if include_self else x
199 |     return x
200 | 
201 | 
202 | class EGNN_sparse3D(MessagePassing):
203 |     """torch geometric message-passing layer for 3D molecular graphs."""
204 | 
205 |     def __init__(self, feats_dim, pos_dim=3, m_dim=32, dropout=0.1, fourier_features=16, aggr="add", **kwargs):
206 |         """Initialization of the 3D message passing layer.
207 | 
208 |         :param feats_dim: Node feature dimension.
209 |         :type feats_dim: int
210 |         :param pos_dim: Dimensionality of the coordinates, defaults to 3
211 |         :type pos_dim: int, optional
212 |         :param m_dim: Message passing feature dimension, defaults to 32
213 |         :type m_dim: int, optional
214 |         :param dropout: Dropout value, defaults to 0.1
215 |         :type dropout: float, optional
216 |         :param fourier_features: Number of Fourier features, defaults to 16
217 |         :type fourier_features: int, optional
218 |         :param aggr: Message aggregation type, defaults to "add"
219 |         :type aggr: str, optional
220 |         """
221 |         assert aggr in {
222 |             "add",
223 |             "sum",
224 |             "max",
225 |             "mean",
226 |         }, "pool method must be a valid option"
227 | 
228 |         kwargs.setdefault("aggr", aggr)
229 |         super(EGNN_sparse3D, self).__init__(**kwargs)
230 | 
231 |         # Model parameters
232 |         self.feats_dim = feats_dim
233 |         self.pos_dim = pos_dim
234 |         self.m_dim = m_dim
235 |         self.fourier_features = fourier_features
236 | 
237 |         self.edge_input_dim = (self.fourier_features * 2) + 1 + (feats_dim * 2)
238 | 
239 |         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
240 | 
241 |         self.edge_norm1 = nn.LayerNorm(m_dim)
242 |         self.edge_norm2 = nn.LayerNorm(m_dim)
243 | 
244 |         self.edge_mlp = nn.Sequential(
245 |             nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
246 |             self.dropout,
247 |             nn.SiLU(),
248 |             nn.Linear(self.edge_input_dim * 2, m_dim),
249 |             nn.SiLU(),
250 |         )
251 | 
252 |         self.node_norm1 = nn.LayerNorm(feats_dim)
253 |         self.node_norm2 = nn.LayerNorm(feats_dim)
254 | 
255 |         self.node_mlp = nn.Sequential(
256 |             nn.Linear(feats_dim + m_dim, feats_dim * 2),
257 |             self.dropout,
258 |             nn.SiLU(),
259 |             nn.Linear(feats_dim * 2, feats_dim),
260 |         )
261 | 
262 |         self.apply(self.init_)
263 | 
264 |     def init_(self, module):
265 |         if type(module) in {nn.Linear}:
266 |             nn.init.xavier_normal_(module.weight)
267 |             nn.init.zeros_(module.bias)
268 | 
269 |     def forward(
270 |         self,
271 |         x: Tensor,
272 |         edge_index: Adj,
273 |     ):
274 |         """Forward pass of the message-passing function.
275 | 
276 |         :param x: Node features with coordinates in the first pos_dim columns.
277 |         :type x: Tensor
278 |         :param edge_index: Edge indices.
279 |         :type edge_index: Adj
280 |         :return: Updated node features.
281 |         :rtype: Tensor
282 |         """
283 |         coors, feats = x[:, : self.pos_dim], x[:, self.pos_dim :]
284 |         rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
285 |         rel_dist = (rel_coors**2).sum(dim=-1, keepdim=True)
286 | 
287 |         if self.fourier_features > 0:
288 |             rel_dist = fourier_encode_dist(rel_dist, num_encodings=self.fourier_features)
289 |             rel_dist = rel_dist.squeeze(1)
290 |             # rel_dist = rearrange(rel_dist, "n () d -> n d")
291 | 
292 |         hidden_out = self.propagate(edge_index, x=feats, edge_attr=rel_dist)
293 |         return torch.cat([coors, hidden_out], dim=-1)
294 | 
295 |     def message(self, x_i, x_j, edge_attr):
296 |         """Message passing.
297 | 
298 |         :param x_i: Node n_i.
299 |         :type x_i: Tensor
300 |         :param x_j: Node n_j.
301 |         :type x_j: Tensor
302 |         :param edge_attr: Edge e_{ij}
303 |         :type edge_attr: Tensor
304 |         :return: Message m_ij
305 |         :rtype: Tensor
306 |         """
307 |         m_ij = self.edge_mlp(torch.cat([x_i, x_j, edge_attr], dim=-1))
308 |         return m_ij
309 | 
310 |     def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
311 |         """Overall propagation within the message passing.
312 | 
313 |         :param edge_index: Edge indices.
314 |         :type edge_index: Adj
315 |         :return: Updated node features.
316 |         :rtype: Tensor
317 |         """
318 |         # get input tensors
319 |         size = self._check_input(edge_index, size)
320 |         coll_dict = self._collect(self._user_args, edge_index, size, kwargs)
321 |         msg_kwargs = self.inspector.distribute("message", coll_dict)
322 |         aggr_kwargs = self.inspector.distribute("aggregate", coll_dict)
323 |         update_kwargs = self.inspector.distribute("update", coll_dict)
324 | 
325 |         # get messages
326 |         m_ij = self.message(**msg_kwargs)
327 |         m_ij = self.edge_norm1(m_ij)
328 | 
329 |         # aggregate messages
330 |         m_i = self.aggregate(m_ij, **aggr_kwargs)
331 |         m_i = self.edge_norm2(m_i)
332 | 
333 |         # get updated node features
334 |         hidden_feats = self.node_norm1(kwargs["x"])
335 |         hidden_out = self.node_mlp(torch.cat([hidden_feats, m_i], dim=-1))
336 |         hidden_out = self.node_norm2(hidden_out)
337 |         hidden_out = kwargs["x"] + hidden_out
338 | 
339 |         return self.update(hidden_out, **update_kwargs)
340 | 
--------------------------------------------------------------------------------
/lsfml/modules/pygdataset.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Code adapted from PyTorch Geometric:
5 | # https://pytorch-geometric.readthedocs.io/
6 | 
7 | from typing import List, Optional, Callable, Union, Any, Tuple
8 | 
9 | import sys
10 | import re
11 | import copy
12 | import warnings
13 | import numpy as np
14 | import os.path as osp
15 | from collections.abc import Sequence
16 | 
17 | import torch.utils.data
18 | from torch import Tensor
19 | 
20 | from torch_geometric.data import Data
21 | from torch_geometric.data.makedirs import makedirs
22 | 
23 | IndexType = Union[slice, Tensor, np.ndarray, Sequence]
24 | 
25 | 
26 | class Dataset(torch.utils.data.Dataset):
27 |     r"""Dataset base class for creating graph datasets.
28 |     See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
29 |     create_dataset.html>`__ for the accompanying tutorial.
30 | 
31 |     Args:
32 |         root (string, optional): Root directory where the dataset should be
33 |             saved. (optional: :obj:`None`)
34 |         transform (callable, optional): A function/transform that takes in an
35 |             :obj:`torch_geometric.data.Data` object and returns a transformed
36 |             version. The data object will be transformed before every access.
37 |             (default: :obj:`None`)
38 |         pre_transform (callable, optional): A function/transform that takes in
39 |             an :obj:`torch_geometric.data.Data` object and returns a
40 |             transformed version. The data object will be transformed before
41 |             being saved to disk. (default: :obj:`None`)
42 |         pre_filter (callable, optional): A function that takes in an
43 |             :obj:`torch_geometric.data.Data` object and returns a boolean
44 |             value, indicating whether the data object should be included in the
45 |             final dataset.
(default: :obj:`None`) 46 | """ 47 | 48 | @property 49 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 50 | r"""The name of the files in the :obj:`self.raw_dir` folder that must 51 | be present in order to skip downloading.""" 52 | raise NotImplementedError 53 | 54 | @property 55 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 56 | r"""The name of the files in the :obj:`self.processed_dir` folder that 57 | must be present in order to skip processing.""" 58 | raise NotImplementedError 59 | 60 | def download(self): 61 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" 62 | raise NotImplementedError 63 | 64 | def process(self): 65 | r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" 66 | raise NotImplementedError 67 | 68 | def len(self) -> int: 69 | r"""Returns the number of graphs stored in the dataset.""" 70 | raise NotImplementedError 71 | 72 | def get(self, idx: int) -> Data: 73 | r"""Gets the data object at index :obj:`idx`.""" 74 | raise NotImplementedError 75 | 76 | def __init__( 77 | self, 78 | root: Optional[str] = None, 79 | transform: Optional[Callable] = None, 80 | pre_transform: Optional[Callable] = None, 81 | pre_filter: Optional[Callable] = None, 82 | ): 83 | super().__init__() 84 | 85 | if isinstance(root, str): 86 | root = osp.expanduser(osp.normpath(root)) 87 | 88 | self.root = root 89 | self.transform = transform 90 | self.pre_transform = pre_transform 91 | self.pre_filter = pre_filter 92 | self._indices: Optional[Sequence] = None 93 | 94 | if "download" in self.__class__.__dict__: 95 | self._download() 96 | 97 | if "process" in self.__class__.__dict__: 98 | self._process() 99 | 100 | def indices(self) -> Sequence: 101 | return range(self.len()) if self._indices is None else self._indices 102 | 103 | @property 104 | def raw_dir(self) -> str: 105 | return osp.join(self.root, "raw") 106 | 107 | @property 108 | def processed_dir(self) -> str: 109 | return osp.join(self.root, "processed") 110 | 111 | @property 112 | def num_node_features(self) -> int: 113 | r"""Returns the number of features per node in the dataset.""" 114 | data = self[0] 115 | data = data[0] if isinstance(data, tuple) else data 116 | if hasattr(data, "num_node_features"): 117 | return data.num_node_features 118 | raise AttributeError(f"'{data.__class__.__name__}' object has no " f"attribute 'num_node_features'") 119 | 120 | @property 121 | def num_features(self) -> int: 122 | r"""Returns the number of features per node in the dataset. 
123 |         Alias for :py:attr:`~num_node_features`."""
124 |         return self.num_node_features
125 | 
126 |     @property
127 |     def num_edge_features(self) -> int:
128 |         r"""Returns the number of features per edge in the dataset."""
129 |         data = self[0]
130 |         data = data[0] if isinstance(data, tuple) else data
131 |         if hasattr(data, "num_edge_features"):
132 |             return data.num_edge_features
133 |         raise AttributeError(f"'{data.__class__.__name__}' object has no " f"attribute 'num_edge_features'")
134 | 
135 |     @property
136 |     def raw_paths(self) -> List[str]:
137 |         r"""The absolute filepaths that must be present in order to skip
138 |         downloading."""
139 |         files = to_list(self.raw_file_names)
140 |         return [osp.join(self.raw_dir, f) for f in files]
141 | 
142 |     @property
143 |     def processed_paths(self) -> List[str]:
144 |         r"""The absolute filepaths that must be present in order to skip
145 |         processing."""
146 |         files = to_list(self.processed_file_names)
147 |         return [osp.join(self.processed_dir, f) for f in files]
148 | 
149 |     def _download(self):
150 |         if files_exist(self.raw_paths):  # pragma: no cover
151 |             return
152 | 
153 |         makedirs(self.raw_dir)
154 |         self.download()
155 | 
156 |     def _process(self):
157 |         f = osp.join(self.processed_dir, "pre_transform.pt")
158 |         if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
159 |             warnings.warn(
160 |                 f"The `pre_transform` argument differs from the one used in "
161 |                 f"the pre-processed version of this dataset. If you want to "
162 |                 f"make use of another pre-processing technique, make sure to "
163 |                 f"delete '{self.processed_dir}' first"
164 |             )
165 | 
166 |         f = osp.join(self.processed_dir, "pre_filter.pt")
167 |         if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
168 |             warnings.warn(
169 |                 f"The `pre_filter` argument differs from the one used in the "
170 |                 f"pre-processed version of this dataset. If you want to make "
171 |                 f"use of another pre-filtering technique, make sure to delete "
172 |                 f"'{self.processed_dir}' first"
173 |             )
174 | 
175 |         if files_exist(self.processed_paths):  # pragma: no cover
176 |             return
177 | 
178 |         print("Processing...", file=sys.stderr)
179 | 
180 |         makedirs(self.processed_dir)
181 |         self.process()
182 | 
183 |         path = osp.join(self.processed_dir, "pre_transform.pt")
184 |         torch.save(_repr(self.pre_transform), path)
185 |         path = osp.join(self.processed_dir, "pre_filter.pt")
186 |         torch.save(_repr(self.pre_filter), path)
187 | 
188 |         print("Done!", file=sys.stderr)
189 | 
190 |     def __len__(self) -> int:
191 |         r"""The number of examples in the dataset."""
192 |         return len(self.indices())
193 | 
194 |     def __getitem__(
195 |         self,
196 |         idx: Union[int, np.integer, IndexType],
197 |     ) -> Union["Dataset", Data]:
198 |         r"""In case :obj:`idx` is of type integer, will return the data object
199 |         at index :obj:`idx` (and transforms it in case :obj:`transform` is
200 |         present).
201 |         In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
202 |         tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or
203 |         bool, will return a subset of the dataset at the specified indices."""
204 |         if (
205 |             isinstance(idx, (int, np.integer))
206 |             or (isinstance(idx, Tensor) and idx.dim() == 0)
207 |             or (isinstance(idx, np.ndarray) and np.isscalar(idx))
208 |         ):
209 |             data = self.get(self.indices()[idx])
210 |             data = data if self.transform is None else self.transform(data)
211 |             return data
212 | 
213 |         else:
214 |             return self.index_select(idx)
215 | 
216 |     def index_select(self, idx: IndexType) -> "Dataset":
217 |         r"""Creates a subset of the dataset from specified indices :obj:`idx`.
218 |         Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
219 |         list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
220 |         long or bool."""
221 |         indices = self.indices()
222 | 
223 |         if isinstance(idx, slice):
224 |             indices = indices[idx]
225 | 
226 |         elif isinstance(idx, Tensor) and idx.dtype == torch.long:
227 |             return self.index_select(idx.flatten().tolist())
228 | 
229 |         elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
230 |             idx = idx.flatten().nonzero(as_tuple=False)
231 |             return self.index_select(idx.flatten().tolist())
232 | 
233 |         elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
234 |             return self.index_select(idx.flatten().tolist())
235 | 
236 |         elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
237 |             idx = idx.flatten().nonzero()[0]
238 |             return self.index_select(idx.flatten().tolist())
239 | 
240 |         elif isinstance(idx, Sequence) and not isinstance(idx, str):
241 |             indices = [indices[i] for i in idx]
242 | 
243 |         else:
244 |             raise IndexError(
245 |                 f"Only slices (':'), lists, tuples, and torch.Tensor or "
246 |                 f"np.ndarray of dtype long or bool are valid indices (got "
247 |                 f"'{type(idx).__name__}')"
248 |             )
249 | 
250 |         dataset = copy.copy(self)
251 |         dataset._indices = indices
252 |         return dataset
253 | 
254 |     def shuffle(
255 |         self,
256 |         return_perm: bool = False,
257 |     ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
258 |         r"""Randomly shuffles the examples in the dataset.
259 | 
260 |         Args:
261 |             return_perm (bool, optional): If set to :obj:`True`, will also
262 |                 return the random permutation used to shuffle the dataset.
263 |                 (default: :obj:`False`)
264 |         """
265 |         perm = torch.randperm(len(self))
266 |         dataset = self.index_select(perm)
267 |         return (dataset, perm) if return_perm is True else dataset
268 | 
269 |     def __repr__(self) -> str:
270 |         arg_repr = str(len(self)) if len(self) > 1 else ""
271 |         return f"{self.__class__.__name__}({arg_repr})"
272 | 
273 | 
274 | def to_list(value: Any) -> Sequence:
275 |     if isinstance(value, Sequence) and not isinstance(value, str):
276 |         return value
277 |     else:
278 |         return [value]
279 | 
280 | 
281 | def files_exist(files: List[str]) -> bool:
282 |     # NOTE: We return `False` in case `files` is empty, leading to a
283 |     # re-processing of files on every instantiation.
284 |     return len(files) != 0 and all(osp.exists(f) for f in files)
285 | 
286 | 
287 | def _repr(obj: Any) -> str:
288 |     if obj is None:
289 |         return "None"
290 |     return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())
291 | 
--------------------------------------------------------------------------------
/lsfml/qml/config_14000.ini:
--------------------------------------------------------------------------------
1 | [PARAMS]
2 | EPOCHS = 5000
3 | BATCH_SIZE = 16
4 | LEARNING_RATE = 0.0001
5 | WEIGHT_DECAY = 1e-10
6 | FACTOR = 0.7
7 | PATIENCE = 20
8 | KERNEL_DIM = 128
9 | KERNEL_NUM = 5
10 | MLP_DIM = 256
11 | MLP_NUM = 3
12 | EDGE_DIM = 32
13 | OUTPUT_DIM = 1
14 | NETWORK = "atomic"
15 | AGGR = "mean"
16 | FOURIER = 32
--------------------------------------------------------------------------------
/lsfml/qml/model1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ETHmodlab/lsfml/d0178f1ebfedd73639cee2452fadacc500ca23e1/lsfml/qml/model1.pkl
--------------------------------------------------------------------------------
/lsfml/qml/prod.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import configparser
7 | import os
8 | 
9 | import torch
10 | 
11 | from lsfml.qml.qml_net import DeltaNetAtomic
12 | 
13 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | HERE = os.path.abspath(os.path.dirname(__file__))
15 | 
16 | 
17 | def get_model(gpu=True):
18 |     """Returns the loaded and initialized QML model.
19 | 
20 |     :param gpu: Whether to run the model on the GPU, defaults to True
21 |     :type gpu: bool, optional
22 |     :return: QML model
23 |     :rtype: DeltaNetAtomic
24 |     """
25 |     # Load config parameters
26 |     config = configparser.ConfigParser()
27 |     config.read(os.path.join(HERE, "config_14000.ini"))
28 | 
29 |     KERNEL_DIM = int(config["PARAMS"]["KERNEL_DIM"])
30 |     KERNEL_NUM = int(config["PARAMS"]["KERNEL_NUM"])
31 |     MLP_DIM = int(config["PARAMS"]["MLP_DIM"])
32 |     MLP_NUM = int(config["PARAMS"]["MLP_NUM"])
33 |     EDGE_DIM = int(config["PARAMS"]["EDGE_DIM"])
34 |     OUTPUT_DIM = int(config["PARAMS"]["OUTPUT_DIM"])
35 |     AGGR = str(config["PARAMS"]["AGGR"]).strip('"')
36 |     FOURIER = int(config["PARAMS"]["FOURIER"])
37 | 
38 |     # Load model
39 |     model = DeltaNetAtomic(
40 |         embedding_dim=KERNEL_DIM,
41 |         n_kernels=KERNEL_NUM,
42 |         n_mlp=MLP_NUM,
43 |         mlp_dim=MLP_DIM,
44 |         n_outputs=OUTPUT_DIM,
45 |         m_dim=EDGE_DIM,
46 |         initialize_weights=True,
47 |         fourier_features=FOURIER,
48 |         aggr=AGGR,
49 |     )
50 | 
51 |     model.load_state_dict(
52 |         torch.load(
53 |             os.path.join(HERE, "model1.pkl"),
54 |             map_location=torch.device("cpu"),
55 |         )
56 |     )
57 | 
58 |     if gpu:
59 |         model = model.to(DEVICE)
60 |         print(f"QML model has been sent to {DEVICE}")
61 |     else:
62 |         model = model.to("cpu")
63 |         print("QML model has been sent to cpu")
64 | 
65 |     return model
66 | 
67 | 
68 | if __name__ == "__main__":
69 |     get_model()
70 | 
--------------------------------------------------------------------------------
/lsfml/qml/qml_net.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from einops import rearrange
10 | from torch_geometric.nn import MessagePassing
11 | from torch_geometric.typing import Adj, Size, Tensor
12 | 
13 | 
14 | class DeltaNetAtomic(nn.Module):
15 |     """Atomistic graph neural network (aGNN) for partial charge predictions."""
16 | 
17 |     def __init__(
18 |         self,
19 |         embedding_dim=128,
20 |         n_kernels=2,
21 |         n_mlp=3,
22 |         mlp_dim=256,
23 |         n_outputs=1,
24 |         m_dim=64,
25 |         initialize_weights=True,
26 |         fourier_features=4,
27 |         aggr="mean",
28 |     ):
29 |         """Initialization of the aGNN.
30 | 
31 |         :param embedding_dim: Embedding dimension of the input features (atomic features), defaults to 128
32 |         :type embedding_dim: int, optional
33 |         :param n_kernels: Number of message passing functions, defaults to 2
34 |         :type n_kernels: int, optional
35 |         :param n_mlp: Number of multi layer perceptrons, defaults to 3
36 |         :type n_mlp: int, optional
37 |         :param mlp_dim: Feature dimension of multi layer perceptrons, defaults to 256
38 |         :type mlp_dim: int, optional
39 |         :param n_outputs: Output dimension, defaults to 1
40 |         :type n_outputs: int, optional
41 |         :param m_dim: Message passing feature dimension, defaults to 64
42 |         :type m_dim: int, optional
43 |         :param initialize_weights: Initialize weights before training, defaults to True
44 |         :type initialize_weights: bool, optional
45 |         :param fourier_features: Number of Fourier features, defaults to 4
46 |         :type fourier_features: int, optional
47 |         :param aggr: Message aggregation type, defaults to "mean"
48 |         :type aggr: str, optional
49 |         """
50 |         super(DeltaNetAtomic, self).__init__()
51 | 
52 |         self.pos_dim = 3
53 |         self.m_dim = m_dim
54 |         self.embedding_dim = embedding_dim
55 |         self.n_kernels = n_kernels
56 |         self.n_mlp = n_mlp
57 |         self.mlp_dim = mlp_dim
58 |         self.n_outputs = n_outputs
59 |         self.initialize_weights = initialize_weights
60 |         self.fourier_features = fourier_features
61 |         self.aggr = aggr
62 | 
63 |         # Embedding
64 |         self.embedding = nn.Embedding(num_embeddings=11, embedding_dim=self.embedding_dim)
65 | 
66 |         # Kernel
67 |         self.kernel_dim = self.embedding_dim
68 |         self.kernels = nn.ModuleList()
69 |         for _ in range(self.n_kernels):
70 |             self.kernels.append(
71 |                 EGNN_sparse(
72 |                     feats_dim=self.kernel_dim,
73 |                     pos_dim=self.pos_dim,
74 |                     m_dim=self.m_dim,
75 |                     fourier_features=self.fourier_features,
76 |                     aggr=self.aggr,
77 |                 )
78 |             )
79 | 
80 |         # MLP
81 |         self.fnn = nn.ModuleList()
82 |         input_fnn = self.kernel_dim * (self.n_kernels + 1)
83 |         self.fnn.append(nn.Linear(input_fnn, mlp_dim))
84 |         for _ in range(self.n_mlp - 1):
85 |             self.fnn.append(nn.Linear(self.mlp_dim, self.mlp_dim))
86 |         self.fnn.append(nn.Linear(self.mlp_dim, self.n_outputs))
87 | 
88 |         # Initialize weights
89 |         if self.initialize_weights:
90 |             self.kernels.apply(weights_init)
91 |             self.fnn.apply(weights_init)
92 |             nn.init.xavier_uniform_(self.embedding.weight)
93 | 
94 |     def forward(self, g_batch):
95 |         """Forward pass of the aGNN.
96 | 
97 |         :param g_batch: Batch of molecular graphs (atomids, coords, edge_index).
98 |         :type g_batch: torch_geometric.data.Data
99 |         :return: Predicted atomic partial charges.
100 |         :rtype: Tensor
101 |         """
102 |         # Embedding
103 |         features = self.embedding(g_batch.atomids)
104 |         features = torch.cat([g_batch.coords, features], dim=1)
105 | 
106 |         # Kernel
107 |         feature_list = []
108 |         feature_list.append(features[:, self.pos_dim :])
109 | 
110 |         for kernel in self.kernels:
111 |             features = kernel(
112 |                 x=features,
113 |                 edge_index=g_batch.edge_index,
114 |             )
115 |             feature_list.append(features[:, self.pos_dim :])
116 | 
117 |         # Concat
118 |         features = F.silu(torch.cat(feature_list, dim=1))
119 | 
120 |         # MLP hidden layers
121 |         for mlp in self.fnn[:-1]:
122 |             features = F.silu(mlp(features))
123 | 
124 |         # Output layer
125 |         features = self.fnn[-1](features).squeeze(1)
126 | 
127 |         return features
128 | 
129 | 
130 | def weights_init(m):
131 |     """Xavier uniform weight initialization.
132 | 
133 |     :param m: A list of learnable linear PyTorch modules.
134 |     :type m: [torch.nn.modules.linear.Linear]
135 |     """
136 |     if isinstance(m, nn.Linear):
137 |         nn.init.xavier_uniform_(m.weight)
138 |         nn.init.zeros_(m.bias)
139 | 
140 | 
141 | def fourier_encode_dist(x, num_encodings=4, include_self=True):
142 |     """Encoding of Euclidean interatomic distances into Fourier features.
143 | 
144 |     :param x: Distances (here, squared Euclidean distances in Å²).
145 |     :type x: Tensor
146 |     :param num_encodings: Number of sine and cosine functions, defaults to 4
147 |     :type num_encodings: int, optional
148 |     :param include_self: Option to include the absolute distance, defaults to True
149 |     :type include_self: bool, optional
150 |     :return: Fourier features.
151 |     :rtype: Tensor
152 |     """
153 |     x = x.unsqueeze(-1)
154 |     device, dtype, orig_x = x.device, x.dtype, x
155 |     scales = 2 ** torch.arange(num_encodings, device=device, dtype=dtype)
156 |     x = x / scales
157 |     x = torch.cat([x.sin(), x.cos()], dim=-1)
158 |     x = torch.cat((x, orig_x), dim=-1) if include_self else x
159 |     return x
160 | 
161 | 
162 | class EGNN_sparse(MessagePassing):
163 |     """torch geometric message-passing layer for 3D molecular graphs."""
164 | 
165 |     def __init__(
166 |         self, feats_dim, pos_dim=3, edge_attr_dim=0, m_dim=32, dropout=0.1, fourier_features=32, aggr="mean", **kwargs
167 |     ):
168 |         """Initialization of the 3D message passing layer.
169 | 
170 |         :param feats_dim: Node feature dimension.
171 |         :type feats_dim: int
172 |         :param pos_dim: Dimensionality of the coordinates, defaults to 3
173 |         :type pos_dim: int, optional
174 |         :param edge_attr_dim: Dimension of additional edge features (unused in this implementation), defaults to 0
175 |         :type edge_attr_dim: int, optional
176 |         :param m_dim: Message passing feature dimension, defaults to 32
177 |         :type m_dim: int, optional
178 |         :param dropout: Dropout value, defaults to 0.1
179 |         :type dropout: float, optional
180 |         :param fourier_features: Number of Fourier features, defaults to 32
181 |         :type fourier_features: int, optional
182 |         :param aggr: Message aggregation type, defaults to "mean"
183 |         :type aggr: str, optional
184 |         """
185 |         assert aggr in {
186 |             "add",
187 |             "sum",
188 |             "max",
189 |             "mean",
190 |         }, "pool method must be a valid option"
191 | 
192 |         kwargs.setdefault("aggr", aggr)
193 |         super(EGNN_sparse, self).__init__(**kwargs)
194 | 
195 |         # Model parameters
196 |         self.feats_dim = feats_dim
197 |         self.pos_dim = pos_dim
198 |         self.m_dim = m_dim
199 |         self.fourier_features = fourier_features
200 | 
201 |         self.edge_input_dim = (self.fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2)
202 | 
203 |         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
204 | 
205 |         # Edge layers
206 |         self.edge_norm1 = nn.LayerNorm(m_dim)
207 |         self.edge_norm2 = nn.LayerNorm(m_dim)
208 | 
209 |         self.edge_mlp = nn.Sequential(
210 |             nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
211 |             self.dropout,
212 |             nn.SiLU(),
213 |             nn.Linear(self.edge_input_dim * 2, m_dim),
214 |             nn.SiLU(),
215 |         )
216 | 
217 |         # Node layers
218 |         self.node_norm1 = nn.LayerNorm(feats_dim)
219 |         self.node_norm2 = nn.LayerNorm(feats_dim)
220 | 
221 |         self.node_mlp = nn.Sequential(
222 |             nn.Linear(feats_dim + m_dim, feats_dim * 2),
223 |             self.dropout,
224 |             nn.SiLU(),
225 |             nn.Linear(feats_dim * 2, feats_dim),
226 |         )
227 | 
228 |         # Initialization
229 |         self.apply(self.init_)
230 | 
231 |     def init_(self, module):
232 |         if type(module) in {nn.Linear}:
233 |             nn.init.xavier_normal_(module.weight)
234 |             nn.init.zeros_(module.bias)
235 | 
236 |     def forward(
237 |         self,
238 |         x: Tensor,
239 |         edge_index: Adj,
240 |     ):
241 |         """Forward pass of the message-passing function.
242 | 
243 |         :param x: Node features with coordinates in the first pos_dim columns.
244 |         :type x: Tensor
245 |         :param edge_index: Edge indices.
246 |         :type edge_index: Adj
247 |         :return: Updated node features.
248 |         :rtype: Tensor
249 |         """
250 |         coors, feats = x[:, : self.pos_dim], x[:, self.pos_dim :]
251 |         rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
252 |         rel_dist = (rel_coors**2).sum(dim=-1, keepdim=True)
253 | 
254 |         if self.fourier_features > 0:
255 |             rel_dist = fourier_encode_dist(rel_dist, num_encodings=self.fourier_features)
256 |             rel_dist = rearrange(rel_dist, "n () d -> n d")
257 | 
258 |         hidden_out = self.propagate(
259 |             edge_index,
260 |             x=feats,
261 |             edge_attr=rel_dist,
262 |             coors=coors,
263 |             rel_coors=rel_coors,
264 |         )
265 | 
266 |         return torch.cat([coors, hidden_out], dim=-1)
267 | 
268 |     def message(self, x_i, x_j, edge_attr):
269 |         """Message passing.
270 | 
271 |         :param x_i: Node n_i.
272 |         :type x_i: Tensor
273 |         :param x_j: Node n_j.
274 |         :type x_j: Tensor
275 |         :param edge_attr: Edge e_{ij}
276 |         :type edge_attr: Tensor
277 |         :return: Message m_ij
278 |         :rtype: Tensor
279 |         """
280 |         m_ij = self.edge_mlp(torch.cat([x_i, x_j, edge_attr], dim=-1))
281 |         return m_ij
282 | 
283 |     def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
284 |         """Overall propagation within the message passing.
285 | 
286 |         :param edge_index: Edge indices.
287 |         :type edge_index: Adj
288 |         :return: Updated node features.
289 |         :rtype: Tensor
290 |         """
291 |         # get input tensors
292 |         size = self._check_input(edge_index, size)
293 |         coll_dict = self._collect(self._user_args, edge_index, size, kwargs)
294 |         msg_kwargs = self.inspector.distribute("message", coll_dict)
295 |         aggr_kwargs = self.inspector.distribute("aggregate", coll_dict)
296 |         update_kwargs = self.inspector.distribute("update", coll_dict)
297 | 
298 |         # get messages
299 |         m_ij = self.message(**msg_kwargs)
300 |         m_ij = self.edge_norm1(m_ij)
301 | 
302 |         # aggregate messages
303 |         m_i = self.aggregate(m_ij, **aggr_kwargs)
304 |         m_i = self.edge_norm1(m_i)  # NOTE: edge_norm1 is reused here; edge_norm2 is defined but never applied
305 | 
306 |         # get updated node features
307 |         hidden_feats = self.node_norm1(kwargs["x"])
308 |         hidden_out = self.node_mlp(torch.cat([hidden_feats, m_i], dim=-1))
309 |         hidden_out = self.node_norm2(hidden_out)
310 |         hidden_out = kwargs["x"] + hidden_out
311 | 
312 |         return self.update(hidden_out, **update_kwargs)
313 | 
--------------------------------------------------------------------------------
/lsfml/qml/test.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | import glob
7 | 
8 | import networkx as nx
9 | import numpy as np
10 | import torch
11 | from rdkit import Chem
12 | from torch_geometric.data import Data
13 | from torch_geometric.utils import add_self_loops
14 | from torch_geometric.utils.undirected import to_undirected
15 | 
16 | from lsfml.qml.prod import get_model
17 | from lsfml.utils import get_dict_for_embedding, QML_ATOMTYPES
18 | 
19 | QMLMODEL = get_model(gpu=False)
20 | 
21 | QML_ATOMTYPE_DICT = get_dict_for_embedding(QML_ATOMTYPES)
22 | 
23 | 
24 | def compare_charges(sdf):
25 |     """Main function to compare predicted partial charges with DFT-calculated ones, using the reference structures in test_mols/.
26 | 
27 |     :param sdf: Path to the SDF file of a molecule.
28 |     :type sdf: str
29 |     :return: Mean absolute error between the predicted and calculated partial charges.
30 |     :rtype: numpy.float64
31 |     """
32 |     mol = next(Chem.SDMolSupplier(sdf, removeHs=False))
33 | 
34 |     # props
35 |     props = mol.GetPropsAsDict()
36 | 
37 |     # read charges
38 |     dft = props["DFT:MULLIKEN_CHARGES"].split("|")
39 |     dft = [float(x) for x in dft]
40 |     dft = np.array(dft)
41 | 
42 |     # get atomids and xyz coords
43 |     qml_atomids = []
44 |     crds_3d = []
45 | 
46 |     for idx, i in enumerate(mol.GetAtoms()):
47 |         qml_atomids.append(QML_ATOMTYPE_DICT[i.GetSymbol()])
48 |         crds_3d.append(list(mol.GetConformer().GetAtomPosition(idx)))
49 | 
50 |     qml_atomids = np.array(qml_atomids)
51 |     crds_3d = np.array(crds_3d)
52 | 
53 |     # 3D graph for qml prediction
54 |     qml_atomids = torch.LongTensor(qml_atomids)
55 |     xyzs = torch.FloatTensor(crds_3d)
56 |     edge_index = np.array(nx.complete_graph(qml_atomids.size(0)).edges())
57 |     edge_index = to_undirected(torch.from_numpy(edge_index).t().contiguous())
58 |     edge_index, _ = add_self_loops(edge_index, num_nodes=crds_3d.shape[0])
59 | 
60 |     qml_graph = Data(
61 |         atomids=qml_atomids,
62 |         coords=xyzs,
63 |         edge_index=edge_index,
64 |         num_nodes=qml_atomids.size(0),
65 |     )
66 | 
67 |     # prediction
68 |     charges = QMLMODEL(qml_graph).detach().numpy()
69 | 
70 |     # mae calculation
71 |     return np.mean(np.abs(charges - dft))
72 | 
73 | 
74 | if __name__ == "__main__":
75 |     mol_files = sorted(glob.glob("test_mols/*sdf"))
76 | 
77 |     maes = []
78 | 
79 |     print("Partial charge prediction is conducted with the following mean absolute errors:")
80 | 
81 |     for sdf in mol_files:
82 |         mae = compare_charges(sdf)
83 |         maes.append(mae)
84 |         print(sdf.split("/")[-1][:-4], mae)
85 | 
86 |     print("Mean MAE over all test molecules:", np.mean(maes))
87 | 
--------------------------------------------------------------------------------
/lsfml/utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8
3 | #
4 | # Copyright (©) 2023, ETH Zurich
5 | 
6 | from rdkit import Chem
7 | from rdkit.Chem import rdMolDescriptors
8 | import numpy as np
9 | from collections import Counter
10 | import torch.nn.functional as F
11 | import os
12 | 
13 | UTILS_PATH = os.path.dirname(__file__)
14 | 
15 | 
16 | def mae_loss(x, y):
17 |     """Calculates the MAE loss.
18 | 
19 |     :param x: Predicted values.
20 |     :type x: Tensor
21 |     :param y: True values.
22 |     :type y: Tensor
23 |     :return: Calculated MAE
24 |     :rtype: numpy.float64
25 |     """
26 |     return F.l1_loss(x, y).item()
27 | 
28 | 
29 | def get_dict_for_embedding(list):
30 |     """Creates a dictionary from a list of strings, mapping each unique string to an integer from 0 to N - 1, where N is the number of unique strings.
31 | 
32 |     :param list: List of strings.
33 |     :type list: list[str]
34 |     :return: Dictionary mapping each string to an integer.
35 |     :rtype: dict[str, int]
36 |     """
37 | 
38 |     list_dict = {}
39 |     list_counter = Counter(list)
40 | 
41 |     for idx, x in enumerate(list_counter):
42 |         list_dict[x] = idx
43 | 
44 |     return list_dict
45 | 
46 | 
47 | def get_fp_from_smi(smi):
48 |     """Calculates the ECFP fingerprint from a SMILES string.
49 | 
50 |     :param smi: SMILES string
51 |     :type smi: str
52 |     :return: ECFP fingerprint vector
53 |     :rtype: np.ndarray
54 |     """
55 |     mol_no_Hs = Chem.MolFromSmiles(smi)
56 |     mol = Chem.AddHs(mol_no_Hs)
57 | 
58 |     return np.array(rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 2, nBits=256))
59 | 
60 | 
61 | HYBRIDISATIONS = [
62 |     "SP3",
63 |     "SP2",
64 |     "SP",
65 |     "UNSPECIFIED",
66 |     "S",
67 | ]
68 | 
69 | AROMATOCITY = [
70 |     "True",
71 |     "False",
72 | ]
73 | 
74 | IS_RING = [
75 |     "True",
76 |     "False",
77 | ]
78 | 
79 | ATOMTYPES = [
80 |     "H",
81 |     "C",
82 |     "N",
83 |     "O",
84 |     "F",
85 |     "P",
86 |     "S",
87 |     "Cl",
88 |     "Br",
89 |     "I",
90 | ]
91 | 
92 | QML_ATOMTYPES = [
93 |     "X",
94 |     "H",
95 |     "C",
96 |     "N",
97 |     "O",
98 |     "F",
99 |     "P",
100 |     "S",
101 |     "Cl",
102 |     "Br",
103 |     "I",
104 | ]
105 | 
--------------------------------------------------------------------------------
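For orientation, the snippet below sketches how the modules above fit together to predict atomic partial charges for a single molecule. It mirrors the graph construction in `lsfml/qml/test.py` (complete graph over all atoms, undirected edges plus self-loops, 3D coordinates); the ethanol SMILES and the RDKit conformer embedding are illustrative assumptions rather than part of the repository.

```
# Minimal usage sketch; the SMILES and the RDKit-embedded conformer are
# illustrative assumptions, the graph construction follows lsfml/qml/test.py.
import networkx as nx
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from torch_geometric.utils.undirected import to_undirected

from lsfml.qml.prod import get_model
from lsfml.utils import get_dict_for_embedding, QML_ATOMTYPES

model = get_model(gpu=False)
atomtype_dict = get_dict_for_embedding(QML_ATOMTYPES)

# Build a molecule with explicit hydrogens and an (assumed) embedded 3D conformer.
mol = Chem.AddHs(Chem.MolFromSmiles("CCO"))  # illustrative molecule
AllChem.EmbedMolecule(mol, randomSeed=42)

atomids = torch.LongTensor([atomtype_dict[a.GetSymbol()] for a in mol.GetAtoms()])
coords = torch.FloatTensor(
    [list(mol.GetConformer().GetAtomPosition(i)) for i in range(mol.GetNumAtoms())]
)

# Complete graph over all atoms, made undirected and extended by self-loops,
# exactly as in compare_charges().
edge_index = np.array(nx.complete_graph(atomids.size(0)).edges())
edge_index = to_undirected(torch.from_numpy(edge_index).t().contiguous())
edge_index, _ = add_self_loops(edge_index, num_nodes=atomids.size(0))

graph = Data(atomids=atomids, coords=coords, edge_index=edge_index, num_nodes=atomids.size(0))
charges = model(graph).detach().numpy()  # one predicted partial charge per atom
print(charges)
```

Because the model operates on a complete graph with Fourier-encoded interatomic distances, no bond information is required; any molecule with 3D coordinates and the supported element types can be processed this way.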