├── CHANGELOG.md
├── LICENSE
├── README.md
├── aizynthtrain
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── uspto.py
│   ├── modelling
│   │   ├── __init__.py
│   │   ├── chemformer
│   │   │   └── create_dataset_split.py
│   │   └── expansion_policy
│   │       ├── __init__.py
│   │       ├── create_template_lib.py
│   │       ├── eval_multi_step.py
│   │       ├── eval_one_step.py
│   │       ├── featurize.py
│   │       ├── split_data.py
│   │       └── training.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── chemformer_data_prep_pipeline.py
│   │   ├── chemformer_train_pipeline.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── autotag_data_extraction.yaml
│   │   │   ├── chemformer_data_prep_pipeline.yaml
│   │   │   ├── chemformer_product_tagging_pipeline.yaml
│   │   │   ├── reaction_validation_pipeline.yaml
│   │   │   ├── stereo_checks_pipeline.yaml
│   │   │   └── template_validation_pipeline.yaml
│   │   ├── disconnection_chemformer_data_prep_pipeline.py
│   │   ├── expansion_model_pipeline.py
│   │   ├── notebooks
│   │   │   ├── __init__.py
│   │   │   ├── expansion_model_val.py
│   │   │   ├── reaction_selection.py
│   │   │   ├── stereo_selection.py
│   │   │   ├── stereo_template_selection.py
│   │   │   └── template_selection.py
│   │   ├── stereo_template_pipeline.py
│   │   └── template_pipeline.py
│   └── utils
│       ├── __init__.py
│       ├── configs.py
│       ├── data_utils.py
│       ├── files.py
│       ├── keras_utils.py
│       ├── onnx_converter.py
│       ├── reporting.py
│       ├── stereo_flipper.py
│       └── template_runner.py
├── configs
│   └── uspto
│       ├── expansion_model_pipeline_config.yml
│       ├── expansion_pipeline.sh
│       ├── ringbreaker_model_pipeline_config.yml
│       ├── ringbreaker_pipeline.sh
│       ├── routes_for_eval.json
│       ├── setup_folder.py
│       ├── smiles_for_eval.txt
│       ├── stock_for_eval_recov.txt
│       ├── template_pipeline.sh
│       └── template_pipeline_config.yml
├── env-dev.yml
├── poetry.lock
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── data
    │   └── test_keras.hdf5
    ├── test_configs.py
    ├── test_data_utils.py
    ├── test_file_utils.py
    └── test_onnx_converter.py
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # CHANGELOG
2 |
3 | ## Version 0.3.0 - 2024-05-28
4 |
5 | ### Features
6 |
7 | - Adding a pipeline for creating a dataset for the disconnection-aware Chemformer model.
8 | - Fine-tuning of Chemformer has been updated to support Hydra configuration.
9 |
10 | ### Bug-fixes
11 |
12 | - Fixed an issue in the template pipeline when sibling reaction information is not available.
13 |
14 | ## Version 0.2.0 - 2024-01-09
15 |
16 | ### Features
17 |
18 | - Adding pipelines for preparing and fine-tuning the Chemformer model
19 | - Adding pipelines for training a retrosynthesis model for stereocontrolled reactions
20 | - Adding updates to the reaction clean-up pipeline
21 |
22 | ### Trivial changes
23 |
24 | - Adding support for aizynthfinder 4.0.0
25 | - Adding support for making ONNX models
26 |
27 | ## Version 0.1.0 - 2022-11
28 |
29 | First public release
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2022 Molecular AI, AstraZeneca Sweden AB 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # aizynthtrain
2 |
3 | aizynthtrain is a collection of routines, configurations and pipelines for training synthesis prediction models for the AiZynthFinder software.
4 |
5 | ## Prerequisites
6 |
7 | Before you begin, ensure you have met the following requirements:
8 |
9 | * Linux, Windows or macOS platforms are supported - as long as the dependencies are supported on these platforms.
10 |
11 | * You have installed [anaconda](https://www.anaconda.com/) or [miniconda](https://docs.conda.io/en/latest/miniconda.html) with python 3.9 - 3.10
12 |
13 | The tool has been developed on a Linux platform.
14 |
15 | ## Installation
16 |
17 | First clone the repository using Git.
18 |
19 | Then execute the following commands in the root of the repository:
20 |
21 |     conda env create -f env-dev.yml
22 |     conda activate aizynthtrain
23 |     poetry install
24 |
25 | The `aizynthtrain` package is now installed in editable mode.
26 |
27 | Now, you need to create a Jupyter kernel for this environment:
28 |
29 |     conda activate aizynthtrain
30 |     ipython kernel install --name "aizynthtrain" --user
31 |
32 | ## Usage
33 |
34 | There are a number of example configurations and SLURM scripts in the `configs/uspto` folder that
35 | were used to retrain the USPTO-based expansion model for AiZynthFinder.
36 |
37 | The example configurations can be used as-is to reproduce the modelling, whereas the SLURM scripts
38 | might have to be adjusted to your system.
39 |
40 | To use the given configurations, follow this procedure:
41 |
42 | For the training and validation of the expansion models, you need a number of data files that can be downloaded from Zenodo.
43 | In addition, the pipelines create a lot of artifacts, and therefore you will start by setting up a folder for your modelling.
44 |
45 | 1. Create a new folder where the pipelines will be executed using
46 |
47 |     python configs/uspto/setup_folder.py PATH_TO_YOUR_FOLDER
48 |
49 | 2. cd PATH_TO_YOUR_FOLDER
50 |
51 | 3. Perform the pipelines from the `rxnutils` package to download and prepare the USPTO datasets for modelling: https://molecularai.github.io/reaction_utils/uspto.html
52 |
53 | The `rxnutils` package pipelines should have produced a file `uspto_data_mapped.csv` that will be used as the starting point of the template-extraction pipeline.
54 |
55 | 4. Adapt and execute the `configs/uspto/template_pipeline.sh` SLURM script on your machine.
56 |
57 | The template-extraction pipeline will produce a number of artifacts, the most important being:
58 | - `reaction_selection_report.html` that is the report of the reaction selection
59 | - `template_selection_report.html` that is the report of the template selection
60 | - `uspto_template_library.csv` and `uspto_ringbreaker_template_library.csv` that are the final reactions and reaction templates that will be used to train the expansion model.
61 |
62 | 5. Adapt and execute the `configs/uspto/expansion_pipeline.sh` SLURM script on your machine.
63 |
64 | The expansion model pipeline will produce these important artifacts:
65 |
66 | - `uspto_expansion_model_report.html` that is the report of the model training and validation
67 | - `uspto_keras_model.hdf5` that is the trained Keras model
68 | - `uspto_unique_templates.csv.gz` that is the template library for AiZynthFinder
69 |
70 | You can also execute the `configs/uspto/ringbreaker_pipeline.sh` script to train a RingBreaker model.
71 |
72 | ### Testing
73 |
74 | Tests use the ``pytest`` package, which is installed by `poetry`.
75 |
76 | Run the tests using:
77 |
78 |     pytest -v
79 |
80 |
81 | ## Contributing
82 |
83 | We welcome contributions in the form of issues or pull requests.
84 |
85 | If you have a question or want to report a bug, please submit an issue.
86 |
87 |
88 | To contribute code to the project, follow these steps:
89 |
90 | 1. Fork this repository.
91 | 2. Create a branch: `git checkout -b <branch_name>`.
92 | 3. Make your changes and commit them: `git commit -m '<commit_message>'`
93 | 4. Push to the remote branch: `git push`
94 | 5. Create the pull request.
95 |
96 | Please use the ``black`` package for formatting, and follow the ``pep8`` style guide.
97 |
98 |
99 | ## Contributors
100 |
101 | * [@SGenheden](https://www.github.com/SGenheden)
102 | * [@lakshidaa](https://github.com/lakshidaa)
103 | * [@anniewesterlund](https://www.github.com/anniewesterlund)
104 |
105 | The contributors have limited time for support questions, but please do not hesitate to submit an issue (see above).
106 |
107 | ## License
108 |
109 | The software is licensed under the Apache 2.0 license (see LICENSE file), and is free and provided as-is.
110 |
111 | ## References
112 |
113 | 1. Genheden S, Norrby PO, Engkvist O (2022) AiZynthTrain: robust, reproducible, and extensible pipelines for training synthesis prediction models. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv-2022-kls5q
114 | 2. Kannas C, Thakkar A, Bjerrum E, Genheden S (2022) rxnutils – A Cheminformatics Python Library for Manipulating Chemical Reaction Data. ChemRxiv. https://doi.org/10.26434/chemrxiv-2022-wt440-v2
115 | 3. Genheden S, Thakkar A, Chadimova V, et al (2020) AiZynthFinder: a fast, robust and flexible open-source software for retrosynthetic planning. J. Cheminf. https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-00472-1
116 | 4. Thakkar A, Kogej T, Reymond J-L, et al (2019) Datasets and their influence on the development of computer assisted synthesis planning tools in the pharmaceutical domain. Chem Sci. https://doi.org/10.1039/C9SC04944D
117 |
--------------------------------------------------------------------------------
/aizynthtrain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/__init__.py
--------------------------------------------------------------------------------
/aizynthtrain/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/data/__init__.py
--------------------------------------------------------------------------------
/aizynthtrain/data/uspto.py:
--------------------------------------------------------------------------------
1 | """ Module containing class to transform USPTO data prepared by reaction_utils pipelines
2 | """
3 | import pandas as pd
4 |
5 |
6 | class UsptoImporter:
7 |     def __init__(self, data_path: str, confidence_cutoff: float = 0.1) -> None:
8 |         self.data = self.transform_data(data_path, confidence_cutoff)
9 |
10 |     @staticmethod
11 |     def transform_data(filename: str, confidence_cutoff: float) -> pd.DataFrame:
12 |         df = pd.read_csv(
13 |             filename, sep="\t", usecols=["ID", "Year", "mapped_rxn", "confidence"]
14 |         )
15 |         sel = df["confidence"] < confidence_cutoff
16 |         print(
17 |             f"Removing {sel.sum()} reactions with confidence of mapping less than {confidence_cutoff}"
18 |         )
19 |         df = df[~sel]
20 |         return pd.DataFrame(
21 |             {
22 |                 "id": df["ID"],
23 |                 "source": ["uspto"] * len(df),
24 |                 "date": [f"{year}-01-01" for year in df["Year"]],
25 |                 "rsmi": df["mapped_rxn"],
26 |                 "classification": ["0.0 Unrecognized"] * len(df),
27 |             }
28 |         )
29 |
30 |
31 | if __name__ == "__main__":
32 |     import argparse
33 |
34 |     parser = argparse.ArgumentParser("Prepare USPTO data")
35 |     parser.add_argument("--data_path", required=True)
36 |     parser.add_argument("--transformed_path")
37 |     args = parser.parse_args()
38 |
39 |     importer = UsptoImporter(args.data_path)
40 |     if args.transformed_path:
41 |         importer.data.to_csv(args.transformed_path, sep="\t", index=False)
42 |
--------------------------------------------------------------------------------
/aizynthtrain/modelling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/modelling/__init__.py
--------------------------------------------------------------------------------
/aizynthtrain/modelling/chemformer/create_dataset_split.py:
--------------------------------------------------------------------------------
1 | """Module for splitting a Chemformer dataset into train / validation / test sets"""
2 | import argparse
3 | import pandas as pd
4 |
5 | from aizynthtrain.utils.data_utils import split_data, extract_route_reactions
6 | from aizynthtrain.utils.configs import (
7 |     load_config,
8 |     ChemformerDataPrepConfig,
9 | )
10 | from typing import Sequence, Optional
11 |
12 |
13 | def split_dataset(config: ChemformerDataPrepConfig, dataset: pd.DataFrame) -> None:
14 |     if config.routes_to_exclude:
15 |         reaction_hashes = extract_route_reactions(config.routes_to_exclude)
16 |         print(
17 |             f"Found {len(reaction_hashes)} unique reactions given routes. Will make these the test set",
18 |             flush=True,
19 |         )
20 |         is_external = dataset[config.reaction_hash_col].isin(reaction_hashes)
21 |         subdata = dataset[~is_external]
22 |     else:
23 |         is_external = None
24 |         subdata = dataset
25 |     train_indices = []
26 |     val_indices = []
27 |     test_indices = []
28 |
29 |     subdata.apply(
30 |         split_data,
31 |         train_frac=config.training_fraction,
32 |         random_seed=config.random_seed,
33 |         train_indices=train_indices,
34 |         val_indices=val_indices,
35 |         test_indices=test_indices,
36 |     )
37 |
38 |     subdata[config.set_col] = "train"
39 |     subdata.loc[val_indices, config.set_col] = "val"
40 |     subdata.loc[test_indices, config.set_col] = "test"
41 |
42 |     if is_external is None:
43 |         subdata.to_csv(config.chemformer_data_path, sep="\t", index=False)
44 |         return
45 |
46 |     dataset.loc[~is_external, config.set_col] = subdata.set.values
47 |     dataset.loc[is_external, config.set_col] = "test"
48 |
49 |     dataset[config.is_external_col] = False
50 |     dataset.loc[is_external, config.is_external_col] = True
51 |     dataset.to_csv(config.chemformer_data_path, sep="\t", index=False)
52 |
53 |
54 | def main(args: Optional[Sequence[str]] = None):
55 |     """Command-line interface to the routines"""
56 |     parser = argparse.ArgumentParser("Create dataset split for Chemformer data.")
57 |     parser.add_argument("--config_path", help="The path to the configuration file")
58 |     args = parser.parse_args(args=args)
59 |     config: ChemformerDataPrepConfig = load_config(
60 |         args.config_path, "chemformer_data_prep"
61 |     )
62 |
63 |     dataset = pd.read_csv(config.reaction_components_path, sep="\t")
64 |     split_dataset(config, dataset)
65 |     return
66 |
67 |
68 | if __name__ == "__main__":
69 |     main()
70 |
--------------------------------------------------------------------------------
/aizynthtrain/modelling/expansion_policy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/modelling/expansion_policy/__init__.py
--------------------------------------------------------------------------------
/aizynthtrain/modelling/expansion_policy/create_template_lib.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains scripts to generate metadata artifacts for a template library
3 |
4 | * A one-column CSV file with generated template codes, used for featurization
5 | * A gzipped CSV file with unique templates and metadata to be used with aizynthfinder
6 | * A JSON file with extensive metadata and lookup to reference ID underlying the templates
7 | * A NPZ file with indices of the training, validation and testing partitions
8 |
9 | """
10 | import argparse
11 | import json
12 | from collections import defaultdict
13 | from typing import Sequence, Optional, Dict, List
14 |
15 | import pandas as pd
16 | import numpy as np
17 | from sklearn.preprocessing import LabelEncoder
18 |
19 | from aizynthtrain.utils.configs import (
20 |     ExpansionModelPipelineConfig,
21 |     TemplateLibraryColumnsConfig,
22 |     load_config,
23 | )
24 | from aizynthtrain.utils.data_utils import split_data, extract_route_reactions
25 |
26 |
27 | def _add_and_save_template_code(
28 |     dataset: pd.DataFrame, config: ExpansionModelPipelineConfig
29 | ) -> None:
30 |     columns = config.library_config.columns
31 |     template_labels = LabelEncoder()
32 |     dataset.loc[:, columns.template_code] = template_labels.fit_transform(
33 |         dataset[columns.template_hash]
) 35 | 36 | print("Saving template codes...", flush=True) 37 | dataset[columns.template_code].to_csv(config.filename("template_code"), index=False) 38 | 39 | 40 | def _make_template_dataset( 41 | dataset: pd.DataFrame, config: ExpansionModelPipelineConfig 42 | ) -> None: 43 | def group_apply( 44 | group: pd.DataFrame, 45 | id_lookup: Dict[str, List[str]], 46 | columns: TemplateLibraryColumnsConfig, 47 | template_set: str, 48 | metadata_columns: List[str], 49 | ) -> pd.Series: 50 | if id_lookup: 51 | refs = sorted( 52 | id_ 53 | for pseudo_hash in group[columns.reaction_hash] 54 | for id_ in id_lookup[pseudo_hash] 55 | ) 56 | else: 57 | refs = [] 58 | dict_ = { 59 | "_id": group[columns.template_hash].iloc[0], 60 | "count": len(refs), 61 | "dimer_only": False, 62 | "index": group[columns.template_code].iloc[0], 63 | "intra_only": False, 64 | "necessary_reagent": "", 65 | "reaction_smarts": group[columns.retro_template].iloc[0], 66 | "references": refs, 67 | "template_set": template_set, 68 | } 69 | for column in metadata_columns: 70 | dict_[column] = group[column].iloc[0] 71 | return pd.Series(dict_) 72 | 73 | if config.selected_ids_path: 74 | with open(config.selected_ids_path, "r") as fileobj: 75 | pseudo_hash_to_id = json.load(fileobj) 76 | else: 77 | pseudo_hash_to_id = None 78 | 79 | columns = config.library_config.columns 80 | lookup = dataset.groupby(columns.template_hash).apply( 81 | group_apply, 82 | id_lookup=pseudo_hash_to_id, 83 | columns=columns, 84 | template_set=config.library_config.template_set, 85 | metadata_columns=config.library_config.metadata_columns, 86 | ) 87 | return lookup 88 | 89 | 90 | def _save_lookup( 91 | template_dataset: pd.DataFrame, config: ExpansionModelPipelineConfig 92 | ) -> None: 93 | """ 94 | Save a JSON file with template data and reference to original reaction. 95 | This was intended for an internal platform originally, hence the formatting. 
96 | """ 97 | print("Creating template hash lookup...", flush=True) 98 | template_dataset2 = template_dataset.drop( 99 | columns=config.library_config.metadata_columns 100 | ) 101 | with open(config.filename("template_lookup"), "w") as fileobj: 102 | json.dump( 103 | template_dataset2.to_dict(orient="records"), 104 | fileobj, 105 | ) 106 | 107 | 108 | def _save_unique_templates( 109 | template_dataset: pd.DataFrame, config: ExpansionModelPipelineConfig 110 | ) -> None: 111 | """ 112 | Save a gzipped CSV file with template code, template SMARTS and metadata 113 | that can be used by AiZynthFinder 114 | """ 115 | print("Creating unique template library...", flush=True) 116 | columns = config.library_config.columns 117 | template_dataset2 = template_dataset[ 118 | ["index", "reaction_smarts"] 119 | + config.library_config.metadata_columns 120 | + ["count"] 121 | ] 122 | template_dataset2 = template_dataset2.rename( 123 | columns={ 124 | "index": columns.template_code, 125 | "reaction_smarts": columns.retro_template, 126 | "count": columns.library_occurrence, 127 | }, 128 | ) 129 | template_dataset2.to_csv(config.filename("unique_templates"), index=False, sep="\t") 130 | 131 | 132 | def _save_split_indices( 133 | dataset: pd.DataFrame, config: ExpansionModelPipelineConfig 134 | ) -> None: 135 | """Perform a stratified split and save the indices to disc""" 136 | print("Creating split of template library...", flush=True) 137 | columns = config.library_config.columns 138 | 139 | if config.routes_to_exclude: 140 | reaction_hashes = extract_route_reactions(config.routes_to_exclude) 141 | print( 142 | f"Found {len(reaction_hashes)} unique reactions given routes. Will make these test set", 143 | flush=True, 144 | ) 145 | sel = dataset[columns.reaction_hash].isin(reaction_hashes) 146 | test_indices = dataset[sel].index.to_list() 147 | subdata = dataset[~sel] 148 | else: 149 | test_indices = [] 150 | subdata = dataset 151 | train_indices = [] 152 | val_indices = [] 153 | 154 | subdata.groupby(columns.template_hash).apply( 155 | split_data, 156 | train_frac=config.training_fraction, 157 | random_seed=config.random_seed, 158 | train_indices=train_indices, 159 | val_indices=val_indices, 160 | test_indices=None if test_indices else test_indices, 161 | ) 162 | 163 | print( 164 | f"Selecting {len(train_indices)} ({len(train_indices)/len(dataset)*100:.2f}%) records as training set" 165 | ) 166 | print( 167 | f"Selecting {len(val_indices)} ({len(val_indices)/len(dataset)*100:.2f}%) records as validation set" 168 | ) 169 | print( 170 | f"Selecting {len(test_indices)} ({len(test_indices)/len(dataset)*100:.2f}%) records as test set", 171 | flush=True, 172 | ) 173 | 174 | np.savez( 175 | config.filename("split_indices"), 176 | train=train_indices, 177 | val=val_indices, 178 | test=test_indices, 179 | ) 180 | 181 | 182 | def main(args: Optional[Sequence[str]] = None) -> None: 183 | """Command-line interface ot the routines""" 184 | parser = argparse.ArgumentParser("Create template library metadata") 185 | parser.add_argument("config", default="The path to the configuration file") 186 | args = parser.parse_args(args=args) 187 | 188 | config: ExpansionModelPipelineConfig = load_config( 189 | args.config, "expansion_model_pipeline" 190 | ) 191 | 192 | dataset = pd.read_csv( 193 | config.filename("library"), 194 | sep="\t", 195 | ) 196 | 197 | _add_and_save_template_code(dataset, config) 198 | template_dataset = _make_template_dataset(dataset, config) 199 | 200 | _save_lookup(template_dataset, config) 201 | 
_save_unique_templates(template_dataset, config) 202 | 203 | _save_split_indices(dataset, config) 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /aizynthtrain/modelling/expansion_policy/eval_multi_step.py: -------------------------------------------------------------------------------- 1 | """Module routines for evaluating one-step retrosynthesis model""" 2 | import argparse 3 | import tempfile 4 | import json 5 | import subprocess 6 | from collections import defaultdict 7 | from typing import Sequence, Optional, Dict, Any 8 | 9 | import yaml 10 | import pandas as pd 11 | import numpy as np 12 | from route_distances.ted.reactiontree import ReactionTreeWrapper 13 | 14 | from aizynthtrain.utils.configs import ( 15 | ExpansionModelEvaluationConfig, 16 | load_config, 17 | ) 18 | 19 | 20 | def _create_config( 21 | model_path: str, templates_path: str, stock_path: str, properties: Dict[str, Any] 22 | ) -> str: 23 | _, filename = tempfile.mkstemp(suffix=".yaml") 24 | dict_ = { 25 | "expansion": {"default": [model_path, templates_path]}, 26 | "stock": {"default": stock_path}, 27 | "search": properties, 28 | } 29 | with open(filename, "w") as fileobj: 30 | yaml.dump(dict_, fileobj) 31 | return filename 32 | 33 | 34 | def _eval_finding( 35 | config: ExpansionModelEvaluationConfig, finder_config_path: str 36 | ) -> None: 37 | output_path = config.filename("finder_output").replace(".hdf5", "_finding.hdf5") 38 | _run_finder(config.target_smiles, finder_config_path, output_path) 39 | 40 | finder_output = pd.read_hdf(output_path, "table") 41 | stats = { 42 | "target": [str(x) for x in finder_output["target"].to_list()], 43 | "first solution time": [ 44 | float(x) for x in finder_output["first_solution_time"].to_list() 45 | ], 46 | "is solved": [bool(x) for x in finder_output["is_solved"].to_list()], 47 | } 48 | 49 | return stats 50 | 51 | 52 | def _eval_recovery( 53 | config: ExpansionModelEvaluationConfig, finder_config_path: str 54 | ) -> None: 55 | with open(config.reference_routes, "r") as fileobj: 56 | ref_trees = json.load(fileobj) 57 | smiles = [tree["smiles"] for tree in ref_trees] 58 | 59 | _, smiles_filename = tempfile.mkstemp(suffix=".txt") 60 | with open(smiles_filename, "w") as fileobj: 61 | fileobj.write("\n".join(smiles)) 62 | 63 | output_path = config.filename("finder_output").replace(".hdf5", "_recovery.hdf5") 64 | _run_finder(smiles_filename, finder_config_path, output_path) 65 | 66 | finder_output = pd.read_hdf(output_path, "table") 67 | stats = defaultdict(list) 68 | for ref_tree, (_, row) in zip(ref_trees, finder_output.iterrows()): 69 | ref_wrapper = ReactionTreeWrapper(ref_tree) 70 | dists = [ 71 | ReactionTreeWrapper(dict_).distance_to(ref_wrapper) for dict_ in row.trees 72 | ] 73 | min_dists = float(min(dists)) 74 | stats["target"].append(ref_tree["smiles"]) 75 | stats["is solved"].append(bool(row.is_solved)) 76 | stats["found reference"].append(min_dists == 0.0) 77 | stats["closest to reference"].append(min_dists) 78 | stats["rank of closest"].append(float(np.argmin(dists))) 79 | 80 | return stats 81 | 82 | 83 | def _run_finder( 84 | smiles_filename: str, finder_config_path: str, output_path: str 85 | ) -> None: 86 | subprocess.run( 87 | [ 88 | "aizynthcli", 89 | "--smiles", 90 | smiles_filename, 91 | "--config", 92 | finder_config_path, 93 | "--output", 94 | output_path, 95 | ] 96 | ) 97 | 98 | 99 | def main(args: Optional[Sequence[str]] = None) -> None: 100 | """Command-line 
interface to multi-step evaluation""" 101 | parser = argparse.ArgumentParser( 102 | "Tool evaluate a multi-step retrosynthesis model using AiZynthFinder" 103 | ) 104 | parser.add_argument( 105 | "--model_path", help="overrides the model path from the config file" 106 | ) 107 | parser.add_argument( 108 | "--templates_path", help="overrides the templates path from the config file" 109 | ) 110 | parser.add_argument("config", help="the filename to a configuration file") 111 | args = parser.parse_args(args) 112 | 113 | config: ExpansionModelEvaluationConfig = load_config( 114 | args.config, "expansion_model_evaluation" 115 | ) 116 | 117 | all_stats = {} 118 | 119 | if config.stock_for_finding and config.target_smiles: 120 | finder_config_path = _create_config( 121 | args.model_path or config.filename("onnx_model"), 122 | args.templates_path or config.filename("unique_templates"), 123 | config.stock_for_finding, 124 | config.search_properties_for_finding, 125 | ) 126 | 127 | stats = _eval_finding(config, finder_config_path) 128 | all_stats["finding"] = stats 129 | 130 | if config.stock_for_recovery and config.reference_routes: 131 | finder_config_path = _create_config( 132 | args.model_path or config.filename("onnx_model"), 133 | args.templates_path or config.filename("unique_templates"), 134 | config.stock_for_recovery, 135 | config.search_properties_for_recovery, 136 | ) 137 | 138 | stats = _eval_recovery(config, finder_config_path) 139 | all_stats["recovery"] = stats 140 | 141 | with open(config.filename("multistep_report"), "w") as fileobj: 142 | json.dump(all_stats, fileobj) 143 | 144 | if "finding" in all_stats: 145 | print("\nEvaluation of route finding capabilities:") 146 | pd_stats = pd.DataFrame(all_stats["finding"]) 147 | print( 148 | f"Average first solution time: {pd_stats['first solution time'].mean():.2f}" 149 | ) 150 | print( 151 | f"Average number of solved target: {pd_stats['is solved'].mean()*100:.2f}%" 152 | ) 153 | 154 | if "recovery" in all_stats: 155 | print("\nEvaluation of route recovery capabilities:") 156 | pd_stats = pd.DataFrame(all_stats["recovery"]) 157 | print( 158 | f"Average number of solved target: {pd_stats['is solved'].mean()*100:.2f}%" 159 | ) 160 | print(f"Average found reference: {pd_stats['found reference'].mean()*100:.2f}%") 161 | print( 162 | f"Average closest to reference: {pd_stats['closest to reference'].mean():.2f}" 163 | ) 164 | print(f"Average rank of closest: {pd_stats['rank of closest'].mean():.2f}") 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /aizynthtrain/modelling/expansion_policy/eval_one_step.py: -------------------------------------------------------------------------------- 1 | """ Module routines for evaluating one-step retrosynthesis model 2 | """ 3 | import argparse 4 | import math 5 | import random 6 | import tempfile 7 | import json 8 | from collections import defaultdict 9 | from typing import Sequence, Optional, Dict, List, Any 10 | 11 | import pandas as pd 12 | from tqdm import tqdm 13 | from rdkit import Chem 14 | from rdkit.Chem import AllChem 15 | from rdkit.Chem.rdMolDescriptors import CalcNumRings 16 | from aizynthfinder.aizynthfinder import AiZynthExpander 17 | 18 | from aizynthtrain.utils.configs import ( 19 | ExpansionModelEvaluationConfig, 20 | load_config, 21 | ) 22 | 23 | YAML_TEMPLATE = """expansion: 24 | default: 25 | - {} 26 | - {} 27 | """ 28 | 29 | 30 | def _create_config(model_path: str, templates_path: str) -> 
str: 31 | _, filename = tempfile.mkstemp(suffix=".yaml") 32 | with open(filename, "w") as fileobj: 33 | fileobj.write(YAML_TEMPLATE.format(model_path, templates_path)) 34 | return filename 35 | 36 | 37 | def _create_test_reaction(config: ExpansionModelEvaluationConfig) -> str: 38 | """ 39 | Selected reactions for testing. Group reactions by parent reaction classification 40 | and then selected random reactions from each group. 41 | """ 42 | 43 | def transform_rsmi(row): 44 | reactants, _, products = row[config.columns.reaction_smiles].split(">") 45 | rxn = AllChem.ReactionFromSmarts(f"{reactants}>>{products}") 46 | AllChem.RemoveMappingNumbersFromReactions(rxn) 47 | return AllChem.ReactionToSmiles(rxn) 48 | 49 | test_lib_path = config.filename("library", "testing") 50 | data = pd.read_csv(test_lib_path, sep="\t") 51 | if config.columns.ring_breaker not in data.columns: 52 | data[config.columns.ring_breaker] = [False] * len(data) 53 | trunc_class = data[config.columns.classification].apply( 54 | lambda x: ".".join(x.split(" ")[0].split(".")[:2]) 55 | ) 56 | 57 | class_to_idx = defaultdict(list) 58 | for idx, val in enumerate(trunc_class): 59 | class_to_idx[val].append(idx) 60 | n_per_class = math.ceil(config.n_test_reactions / len(class_to_idx)) 61 | 62 | random.seed(1789) 63 | selected_idx = [] 64 | for indices in class_to_idx.values(): 65 | if len(indices) > n_per_class: 66 | selected_idx.extend(random.sample(indices, k=n_per_class)) 67 | else: 68 | selected_idx.extend(indices) 69 | 70 | data_sel = data.iloc[selected_idx] 71 | rsmi = data_sel.apply(transform_rsmi, axis=1) 72 | filename = test_lib_path.replace(".csv", "_selected.csv") 73 | pd.DataFrame( 74 | { 75 | config.columns.reaction_smiles: rsmi, 76 | config.columns.ring_breaker: data_sel[config.columns.ring_breaker], 77 | "original_index": selected_idx, 78 | } 79 | ).to_csv(filename, index=False, sep="\t") 80 | return filename 81 | 82 | 83 | def _eval_expander( 84 | expander_output: List[Dict[str, Any]], 85 | ref_reactions_path: str, 86 | config: ExpansionModelEvaluationConfig, 87 | ) -> None: 88 | ref_reactions = pd.read_csv(ref_reactions_path, sep="\t") 89 | 90 | stats = defaultdict(list) 91 | for (_, row), output in zip(ref_reactions.iterrows(), expander_output): 92 | reactants, _, product = row[config.columns.reaction_smiles].split(">") 93 | nrings_prod = CalcNumRings(Chem.MolFromSmiles(product)) 94 | reactants_inchis = set( 95 | Chem.MolToInchiKey(Chem.MolFromSmiles(smi)) for smi in reactants.split(".") 96 | ) 97 | found = False 98 | found_idx = None 99 | ring_broken = False 100 | for idx, outcome in enumerate(output["outcome"]): 101 | outcome_inchis = set( 102 | Chem.MolToInchiKey(Chem.MolFromSmiles(smi)) 103 | for smi in outcome.split(".") 104 | ) 105 | if outcome_inchis == reactants_inchis: 106 | found = True 107 | found_idx = idx + 1 108 | nrings_reactants = sum( 109 | CalcNumRings(Chem.MolFromSmiles(smi)) for smi in outcome.split(".") 110 | ) 111 | if nrings_reactants < nrings_prod: 112 | ring_broken = True 113 | stats["target"].append(product) 114 | stats["expected reactants"].append(reactants) 115 | stats["found expected"].append(found) 116 | stats["rank of expected"].append(found_idx) 117 | if row[config.columns.ring_breaker]: 118 | stats["ring broken"].append(ring_broken) 119 | else: 120 | stats["ring broken"].append(None) 121 | stats["ring breaking"].append(bool(row[config.columns.ring_breaker])) 122 | stats["non-applicable"].append(output["non-applicable"]) 123 | 124 | with open(config.filename("onestep_report"), 
"w") as fileobj: 125 | json.dump(stats, fileobj) 126 | 127 | stats = pd.DataFrame(stats) 128 | print(f"Evaluated {len(ref_reactions)} reactions") 129 | print(f"Average found expected: {stats['found expected'].mean()*100:.2f}%") 130 | print(f"Average rank of expected: {stats['rank of expected'].mean():.2f}") 131 | print(f"Average ring broken when expected: {stats['ring broken'].mean()*100:.2f}%") 132 | print(f"Percentage of ring reactions: {stats['ring breaking'].mean()*100:.2f}%") 133 | print(f"Average non-applicable (in top-50): {stats['non-applicable'].mean():.2f}") 134 | 135 | 136 | def _run_expander( 137 | ref_reactions_path: str, config_path: str, config: ExpansionModelEvaluationConfig 138 | ) -> List[Dict[str, Any]]: 139 | ref_reactions = pd.read_csv(ref_reactions_path, sep="\t") 140 | targets = [ 141 | rxn.split(">")[-1] for rxn in ref_reactions[config.columns.reaction_smiles] 142 | ] 143 | 144 | expander = AiZynthExpander(configfile=config_path) 145 | expander.expansion_policy.select("default") 146 | expander_output = [] 147 | for target in tqdm(targets): 148 | outcome = expander.do_expansion(target, config.top_n) 149 | outcome_list = [ 150 | item.reaction_smiles().split(">>")[1] 151 | for item_list in outcome 152 | for item in item_list 153 | ] 154 | expander_output.append( 155 | { 156 | "outcome": outcome_list, 157 | "non-applicable": expander.stats["non-applicable"], 158 | } 159 | ) 160 | 161 | with open(config.filename("expander_output"), "w") as fileobj: 162 | json.dump(expander_output, fileobj, indent=4) 163 | 164 | return expander_output 165 | 166 | 167 | def main(args: Optional[Sequence[str]] = None) -> None: 168 | """Command-line interface to multi-step evaluation""" 169 | parser = argparse.ArgumentParser( 170 | "Tool to evaluate a one-step retrosynthesis model using AiZynthExpander" 171 | ) 172 | parser.add_argument( 173 | "--model_path", help="overrides the model path from the config file" 174 | ) 175 | parser.add_argument( 176 | "--templates_path", help="overrides the templates path from the config file" 177 | ) 178 | parser.add_argument( 179 | "--test_library", help="overrides the test_library from the config file" 180 | ) 181 | parser.add_argument("config", help="the filename to a configuration file") 182 | args = parser.parse_args(args) 183 | 184 | config: ExpansionModelEvaluationConfig = load_config( 185 | args.config, "expansion_model_evaluation" 186 | ) 187 | 188 | ref_reactions_path = args.test_library or _create_test_reaction(config) 189 | 190 | config_path = _create_config( 191 | args.model_path or config.filename("onnx_model"), 192 | args.templates_path or config.filename("unique_templates"), 193 | ) 194 | 195 | expander_output = _run_expander(ref_reactions_path, config_path, config) 196 | 197 | _eval_expander(expander_output, ref_reactions_path, config) 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /aizynthtrain/modelling/expansion_policy/featurize.py: -------------------------------------------------------------------------------- 1 | """Module that featurizes a template library for training an expansion model""" 2 | import argparse 3 | from typing import Sequence, Optional, Tuple 4 | 5 | import numpy as np 6 | from scipy import sparse 7 | from rdkit import Chem, DataStructs 8 | from rdkit.Chem import AllChem 9 | from rxnutils.data.batch_utils import read_csv_batch 10 | 11 | from aizynthtrain.utils.configs import ( 12 | ExpansionModelPipelineConfig, 13 | 
load_config, 14 | ) 15 | 16 | 17 | def _smiles_to_fingerprint( 18 | args: Sequence[str], fp_radius: int, fp_length: int, chirality: bool = False 19 | ) -> np.ndarray: 20 | mol = Chem.MolFromSmiles(args[0]) 21 | bitvect = AllChem.GetMorganFingerprintAsBitVect( 22 | mol, fp_radius, fp_length, useChirality=chirality 23 | ) 24 | array = np.zeros((1,)) 25 | DataStructs.ConvertToNumpyArray(bitvect, array) 26 | return array 27 | 28 | 29 | def _make_inputs( 30 | config: ExpansionModelPipelineConfig, batch: Tuple[int, int, int] = None 31 | ) -> None: 32 | print("Generating inputs...", flush=True) 33 | smiles_dataset = read_csv_batch( 34 | config.filename("library"), 35 | batch=batch, 36 | sep="\t", 37 | usecols=[config.library_config.columns.reaction_smiles], 38 | ) 39 | products = np.asarray( 40 | [smiles.split(">")[-1] for smiles in smiles_dataset.squeeze("columns")] 41 | ) 42 | 43 | inputs = np.apply_along_axis( 44 | _smiles_to_fingerprint, 45 | 0, 46 | [products], 47 | fp_length=config.model_hyperparams.fingerprint_length, 48 | fp_radius=config.model_hyperparams.fingerprint_radius, 49 | chirality=config.model_hyperparams.chirality, 50 | ) 51 | inputs = sparse.lil_matrix(inputs.T).tocsr() 52 | 53 | filename = config.filename("model_inputs") 54 | if batch is not None: 55 | filename = filename.replace(".npz", f".{batch[0]}.npz") 56 | sparse.save_npz(filename, inputs, compressed=True) 57 | 58 | 59 | def _make_labels( 60 | config: ExpansionModelPipelineConfig, batch: Tuple[int, int, int] = None 61 | ) -> None: 62 | print("Generating labels...", flush=True) 63 | 64 | # Find out the maximum template code 65 | with open(config.filename("template_code"), "r") as fileobj: 66 | nlabels = max(int(code) for idx, code in enumerate(fileobj) if idx > 0) + 1 67 | 68 | template_code_data = read_csv_batch(config.filename("template_code"), batch=batch) 69 | template_codes = template_code_data.squeeze("columns").to_numpy() 70 | 71 | labels = sparse.lil_matrix((len(template_codes), nlabels), dtype=np.int8) 72 | labels[np.arange(len(template_codes)), template_codes] = 1 73 | labels = labels.tocsr() 74 | 75 | filename = config.filename("model_labels") 76 | if batch is not None: 77 | filename = filename.replace(".npz", f".{batch[0]}.npz") 78 | sparse.save_npz(filename, labels, compressed=True) 79 | 80 | 81 | def main(args: Optional[Sequence[str]] = None) -> None: 82 | """Command-line interface for the featurization tool""" 83 | parser = argparse.ArgumentParser( 84 | "Tool to featurize a template library to be used in training a expansion network policy" 85 | ) 86 | parser.add_argument("config", default="The path to the configuration file") 87 | parser.add_argument("--batch", type=int, nargs=3, help="the batch specification") 88 | args = parser.parse_args(args) 89 | 90 | config: ExpansionModelPipelineConfig = load_config( 91 | args.config, "expansion_model_pipeline" 92 | ) 93 | 94 | _make_inputs(config, batch=args.batch) 95 | _make_labels(config, batch=args.batch) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /aizynthtrain/modelling/expansion_policy/split_data.py: -------------------------------------------------------------------------------- 1 | """Module routines for splitting data for expansion model""" 2 | import argparse 3 | from typing import Sequence, Optional, Union 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from scipy import sparse 8 | 9 | 10 | from aizynthtrain.utils.configs import ( 11 | 
ExpansionModelPipelineConfig, 12 | load_config, 13 | ) 14 | 15 | 16 | def _split_and_save_data( 17 | data: Union[pd.DataFrame, np.ndarray, sparse.csr_matrix], 18 | data_label: str, 19 | config: ExpansionModelPipelineConfig, 20 | train_arr: np.ndarray, 21 | val_arr: np.ndarray, 22 | test_arr: np.ndarray, 23 | ) -> None: 24 | array_dict = {"training": train_arr, "validation": val_arr, "testing": test_arr} 25 | for subset, arr in array_dict.items(): 26 | filename = config.filename(data_label, subset=subset) 27 | if isinstance(data, pd.DataFrame): 28 | data.iloc[arr].to_csv( 29 | filename, 30 | sep="\t", 31 | index=False, 32 | ) 33 | elif isinstance(data, np.ndarray): 34 | np.savez(filename, data[arr]) 35 | else: 36 | sparse.save_npz(filename, data[arr], compressed=True) 37 | 38 | 39 | def main(args: Optional[Sequence[str]] = None) -> None: 40 | """Command-line interface for the splitting routines""" 41 | parser = argparse.ArgumentParser( 42 | "Tool to split data to be used in training a expansion network policy" 43 | ) 44 | parser.add_argument("config", help="the filename to a configuration file") 45 | args = parser.parse_args(args) 46 | 47 | config: ExpansionModelPipelineConfig = load_config( 48 | args.config, "expansion_model_pipeline" 49 | ) 50 | 51 | split_indices = np.load(config.filename("split_indices")) 52 | train_arr = split_indices["train"] 53 | val_arr = split_indices["val"] 54 | test_arr = split_indices["test"] 55 | 56 | print("Splitting template library...", flush=True) 57 | dataset = pd.read_csv( 58 | config.filename("library"), 59 | sep="\t", 60 | ) 61 | _split_and_save_data(dataset, "library", config, train_arr, val_arr, test_arr) 62 | 63 | print("Splitting labels...", flush=True) 64 | data = sparse.load_npz(config.filename("model_labels")) 65 | _split_and_save_data(data, "model_labels", config, train_arr, val_arr, test_arr) 66 | 67 | print("Splitting inputs...", flush=True) 68 | data = sparse.load_npz(config.filename("model_inputs")) 69 | _split_and_save_data(data, "model_inputs", config, train_arr, val_arr, test_arr) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /aizynthtrain/modelling/expansion_policy/training.py: -------------------------------------------------------------------------------- 1 | """Module routines for training an expansion model""" 2 | import argparse 3 | from typing import Sequence, Optional, Tuple 4 | 5 | import numpy as np 6 | from tensorflow.keras.layers import Dense, Dropout 7 | from tensorflow.keras.models import Sequential 8 | from tensorflow.keras import regularizers 9 | from sklearn.utils import shuffle 10 | 11 | 12 | from aizynthtrain.utils.configs import ( 13 | ExpansionModelPipelineConfig, 14 | load_config, 15 | ) 16 | from aizynthtrain.utils.keras_utils import ( 17 | InMemorySequence, 18 | setup_callbacks, 19 | train_keras_model, 20 | top10_acc, 21 | top50_acc, 22 | ) 23 | 24 | 25 | class ExpansionModelSequence(InMemorySequence): 26 | """ 27 | Custom sequence class to keep sparse, pre-computed matrices in memory. 
28 | Batches are created dynamically by slicing the in-memory arrays 29 | The data will be shuffled on each epoch end 30 | 31 | :ivar output_dim: the output size (number of templates) 32 | """ 33 | 34 | def __init__( 35 | self, input_filename: str, output_filename: str, batch_size: int 36 | ) -> None: 37 | super().__init__(input_filename, output_filename, batch_size) 38 | self.output_dim = self.label_matrix.shape[1] 39 | 40 | def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]: 41 | idx_ = self._make_slice(idx) 42 | return self.input_matrix[idx_].toarray(), self.label_matrix[idx_].toarray() 43 | 44 | def on_epoch_end(self) -> None: 45 | self.input_matrix, self.label_matrix = shuffle( 46 | self.input_matrix, self.label_matrix, random_state=0 47 | ) 48 | 49 | 50 | def main(args: Optional[Sequence[str]] = None) -> None: 51 | """Command-line tool for training the model""" 52 | parser = argparse.ArgumentParser("Tool to training an expansion network policy") 53 | parser.add_argument("config", help="the filename to a configuration file") 54 | args = parser.parse_args(args) 55 | 56 | config: ExpansionModelPipelineConfig = load_config( 57 | args.config, "expansion_model_pipeline" 58 | ) 59 | 60 | train_seq = ExpansionModelSequence( 61 | config.filename("model_inputs", "training"), 62 | config.filename("model_labels", "training"), 63 | config.model_hyperparams.batch_size, 64 | ) 65 | valid_seq = ExpansionModelSequence( 66 | config.filename("model_inputs", "validation"), 67 | config.filename("model_labels", "validation"), 68 | config.model_hyperparams.batch_size, 69 | ) 70 | 71 | model = Sequential() 72 | model.add( 73 | Dense( 74 | config.model_hyperparams.hidden_nodes, 75 | input_shape=(train_seq.input_dim,), 76 | activation="elu", 77 | kernel_regularizer=regularizers.l2(0.001), 78 | ) 79 | ) 80 | model.add(Dropout(config.model_hyperparams.dropout)) 81 | model.add(Dense(train_seq.output_dim, activation="softmax")) 82 | 83 | callbacks = setup_callbacks( 84 | config.filename("training_log"), config.filename("training_checkpoint") 85 | ) 86 | 87 | train_keras_model( 88 | model, 89 | train_seq, 90 | valid_seq, 91 | "categorical_crossentropy", 92 | ["accuracy", "top_k_categorical_accuracy", top10_acc, top50_acc], 93 | callbacks, 94 | config.model_hyperparams.epochs, 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/pipelines/__init__.py -------------------------------------------------------------------------------- /aizynthtrain/pipelines/chemformer_data_prep_pipeline.py: -------------------------------------------------------------------------------- 1 | """Module containing preparation of a dataset for Chemformer training / evaluation""" 2 | from pathlib import Path 3 | 4 | from metaflow import FlowSpec, Parameter, step 5 | from rxnutils.pipeline.runner import main as pipeline_runner 6 | from rxnutils.data.batch_utils import create_csv_batches, combine_csv_batches 7 | 8 | from aizynthtrain.modelling.chemformer.create_dataset_split import ( 9 | main as create_dataset_split, 10 | ) 11 | from aizynthtrain.utils.configs import ChemformerDataPrepConfig, load_config 12 | 13 | 14 | class ChemformerDataPrepFlow(FlowSpec): 15 | config_path = Parameter("config", 
required=True) 16 | 17 | @step 18 | def start(self): 19 | """Loading configuration""" 20 | self.config: ChemformerDataPrepConfig = load_config( 21 | self.config_path, "chemformer_data_prep" 22 | ) 23 | self.next(self.reaction_components_setup) 24 | 25 | @step 26 | def reaction_components_setup(self): 27 | """Preparing splits of the reaction data for extracting reaction components""" 28 | self.reaction_partitions = self._create_batches( 29 | self.config.selected_reactions_path, self.config.reaction_components_path 30 | ) 31 | self.next(self.reaction_components, foreach="reaction_partitions") 32 | 33 | @step 34 | def reaction_components(self): 35 | """ 36 | Running pipeline for specific steps: 37 | 1. Extracting reaction components (reactants / reagents / product) 38 | 2. Removing atom-mapping from components. 39 | """ 40 | pipeline_path = str( 41 | Path(__file__).parent / "data" / "chemformer_data_prep_pipeline.yaml" 42 | ) 43 | idx, start, end = self.input 44 | if idx > -1: 45 | pipeline_runner( 46 | [ 47 | "--pipeline", 48 | pipeline_path, 49 | "--data", 50 | self.config.selected_reactions_path, 51 | "--output", 52 | f"{self.config.reaction_components_path}.{idx}", 53 | "--max-workers", 54 | "1", 55 | "--batch", 56 | str(start), 57 | str(end), 58 | "--no-intermediates", 59 | ] 60 | ) 61 | self.next(self.reaction_components_join) 62 | 63 | @step 64 | def reaction_components_join(self, inputs): 65 | """Joining split reactions""" 66 | self.config = inputs[0].config 67 | self._combine_batches(self.config.reaction_components_path) 68 | self.next(self.dataset_split) 69 | 70 | @step 71 | def dataset_split(self): 72 | """Splitting data into train / validation / test sets""" 73 | create_dataset_split( 74 | [ 75 | "--config_path", 76 | self.config_path, 77 | ] 78 | ) 79 | self.next(self.end) 80 | 81 | @step 82 | def end(self): 83 | print( 84 | f"Processed Chemformer dataset is located here: {self.config.chemformer_data_path}" 85 | ) 86 | 87 | def _combine_batches(self, filename): 88 | combine_csv_batches(filename, self.config.nbatches) 89 | 90 | def _create_batches(self, input_filename, output_filename): 91 | return create_csv_batches(input_filename, self.config.nbatches, output_filename) 92 | 93 | 94 | if __name__ == "__main__": 95 | ChemformerDataPrepFlow() 96 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/chemformer_train_pipeline.py: -------------------------------------------------------------------------------- 1 | """Module containing training a synthesis Chemformer model""" 2 | import os 3 | import subprocess 4 | 5 | from metaflow import FlowSpec, Parameter, step 6 | 7 | from aizynthtrain.utils.configs import ChemformerTrainConfig, load_config 8 | 9 | 10 | class ChemformerTrainFlow(FlowSpec): 11 | config_path = Parameter("config", required=True) 12 | 13 | @step 14 | def start(self): 15 | """Loading configuration""" 16 | self.config: ChemformerTrainConfig = load_config( 17 | self.config_path, "chemformer_train" 18 | ) 19 | self.next(self.train_model) 20 | 21 | @step 22 | def train_model(self): 23 | """Running Chemformer fine-tuning using the chemformer environment""" 24 | args = self._prompt_from_config() 25 | 26 | cmd = f"python -m molbart.fine_tune {args}" 27 | 28 | if "CHEMFORMER_ENV" in os.environ: 29 | cmd = f"conda run -p {os.environ['CHEMFORMER_ENV']} " + cmd 30 | 31 | if "CONDA_PATH" in os.environ: 32 | cmd = os.environ["CONDA_PATH"] + cmd 33 | 34 | subprocess.check_call(cmd, shell=True) 35 | self.next(self.end) 36 | 37 | 
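    # Added explanatory note, not part of the original source: as an illustration,
    # with CHEMFORMER_ENV pointing to a dedicated Chemformer conda environment, the
    # command assembled in train_model above resembles
    #
    #   conda run -p /path/to/chemformer-env python -m molbart.fine_tune \
    #       'model_path=model.ckpt' 'data_path=selected_reactions.csv' ...
    #
    # where each 'key=value' pair is a field of ChemformerTrainConfig ('data_path'
    # and the file names here are hypothetical examples, while 'model_path' is the
    # field special-cased in _prompt_from_config so that any '=' in its value is
    # escaped). If CONDA_PATH is also set, it is prepended verbatim so that the
    # conda executable can be located.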
@step 38 | def end(self): 39 | print( 40 | f"Fine-tuned Chemformer model located here: {self.config.output_directory}" 41 | ) 42 | 43 | def _prompt_from_config(self) -> str: 44 | """Get argument string for running fine-tuning python script from config.""" 45 | args = [] 46 | for param in self.config: 47 | if param[0] == "model_path": 48 | value = param[1].replace("=", "\=") 49 | args.append(f"'{param[0]}={value}'") 50 | continue 51 | args.append(f"'{param[0]}={param[1]}'") 52 | 53 | prompt = " ".join(args) 54 | return prompt 55 | 56 | 57 | if __name__ == "__main__": 58 | ChemformerTrainFlow() 59 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/pipelines/data/__init__.py -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/autotag_data_extraction.yaml: -------------------------------------------------------------------------------- 1 | drop_columns: 2 | columns: 3 | - reactants 4 | rename_columns: 5 | in_columns: 6 | - products_untagged 7 | out_columns: 8 | - reactants -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/chemformer_data_prep_pipeline.yaml: -------------------------------------------------------------------------------- 1 | remove_atom_mapping: 2 | in_column: RxnSmilesClean 3 | out_column: RxnSmilesNoAtomMap 4 | split_reaction: 5 | in_column: RxnSmilesNoAtomMap 6 | out_columns: 7 | - reactants 8 | - reagents 9 | - products 10 | smiles_sanitizable: 11 | in_column: reactants 12 | out_column: is_reactant_sanitizable 13 | smiles_sanitizable2: 14 | in_column: reagents 15 | out_column: is_reagent_sanitizable 16 | smiles_sanitizable3: 17 | in_column: products 18 | out_column: is_product_sanitizable 19 | drop_rows: 20 | indicator_columns: 21 | - is_reactant_sanitizable 22 | - is_reagent_sanitizable 23 | - is_product_sanitizable 24 | drop_columns: 25 | columns: 26 | - RxnSmilesNoAtomMap 27 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/chemformer_product_tagging_pipeline.yaml: -------------------------------------------------------------------------------- 1 | atom_map_tag_disconnection_site: 2 | in_column: RxnSmilesClean 3 | out_column: products_atom_map_tagged 4 | convert_atom_map_disconnection_tag: 5 | in_column: products_atom_map_tagged 6 | out_column_tagged: products_tagged 7 | out_column_reconstructed: products_reconstructed 8 | smiles_sanitizable: 9 | in_column: products_reconstructed 10 | out_column: is_product_sanitizable 11 | smiles_sanitizable2: 12 | in_column: reactants 13 | out_column: is_reactant_sanitizable 14 | drop_rows: 15 | indicator_columns: 16 | - is_product_sanitizable 17 | - is_reactant_sanitizable 18 | drop_columns: 19 | columns: 20 | - is_product_sanitizable 21 | - is_reactant_sanitizable 22 | - products_reconstructed 23 | rename_columns: 24 | in_columns: 25 | - products_tagged 26 | - products 27 | out_columns: 28 | - products 29 | - products_untagged -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/reaction_validation_pipeline.yaml: -------------------------------------------------------------------------------- 1 | isotope_info: 2 | in_column: rsmi 3 | 
out_column: rsmi_processed 4 | remove_unsanitizable: 5 | in_column: rsmi_processed 6 | out_column: rsmi_processed 7 | reagents2reactants: 8 | in_column: rsmi_processed 9 | out_column: rsmi_processed 10 | reactants2reagents: 11 | in_column: rsmi_processed 12 | out_column: rsmi_processed 13 | remove_extra_atom_mapping: 14 | in_column: rsmi_processed 15 | out_column: rsmi_processed 16 | neutralize_molecules: 17 | in_column: rsmi_processed 18 | out_column: rsmi_processed 19 | remove_unsanitizable2: 20 | in_column: rsmi_processed 21 | out_column: rsmi_processed 22 | bad_column: BadMolecules2 23 | remove_unchanged_products: 24 | in_column: rsmi_processed 25 | out_column: rsmi_processed 26 | count_components: 27 | in_column: rsmi_processed 28 | pseudo_reaction_hash: 29 | in_column: rsmi_processed 30 | no_reagents: True 31 | count_elements: 32 | in_column: rsmi_processed 33 | productsize: 34 | in_column: rsmi_processed 35 | product_atommapping_stats: 36 | in_column: rsmi_processed 37 | hasunmappedradicalatom: 38 | in_column: rsmi_processed 39 | unsanitizablereactants: 40 | rsmi_column: rsmi_processed 41 | bad_columns: 42 | - BadMolecules 43 | - BadMolecules2 44 | maxrings: 45 | in_column: rsmi_processed 46 | ringnumberchange: 47 | in_column: rsmi_processed 48 | ringbondmade: 49 | in_column: rsmi_processed 50 | ringmadesize: 51 | in_column: rsmi_processed 52 | cgr_created: 53 | in_column: rsmi_processed 54 | cgr_dynamic_bonds: 55 | in_column: rsmi_processed -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/stereo_checks_pipeline.yaml: -------------------------------------------------------------------------------- 1 | has_stereo_info: 2 | in_column: rsmi_processed 3 | query_dataframe: 4 | query: HasStereo == True 5 | stereo_centre_changes: 6 | in_column: rsmi_processed 7 | stereo_chiral_reagent: 8 | in_column: rsmi_processed 9 | stereo_centre_created: 10 | stereo_centre_removed: 11 | potential_stereo_center: 12 | in_column: rsmi_processed 13 | stereo_centre_outside: 14 | in_column: rsmi_processed 15 | meso_product: 16 | in_column: rsmi_processed -------------------------------------------------------------------------------- /aizynthtrain/pipelines/data/template_validation_pipeline.yaml: -------------------------------------------------------------------------------- 1 | count_template_components: 2 | in_column: RetroTemplate 3 | retro_template_reproduction: 4 | template_column: RetroTemplate 5 | smiles_column: RxnSmilesClean -------------------------------------------------------------------------------- /aizynthtrain/pipelines/disconnection_chemformer_data_prep_pipeline.py: -------------------------------------------------------------------------------- 1 | """Module containing preparation of a dataset for Chemformer training / evaluation""" 2 | from pathlib import Path 3 | 4 | from metaflow import FlowSpec, Parameter, step 5 | from rxnutils.pipeline.runner import main as pipeline_runner 6 | 7 | from aizynthtrain.utils.configs import ChemformerDataPrepConfig, load_config 8 | from rxnutils.data.batch_utils import combine_csv_batches, create_csv_batches 9 | 10 | 11 | class DisconnectionChemformerDataPrepFlow(FlowSpec): 12 | config_path = Parameter("config", required=True) 13 | 14 | @step 15 | def start(self): 16 | """Loading configuration""" 17 | self.config: ChemformerDataPrepConfig = load_config( 18 | self.config_path, "chemformer_data_prep" 19 | ) 20 | self.next(self.tag_products_setup) 21 | 22 | @step 23 | def tag_products_setup(self): 
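        # The setup/foreach/join trios in this flow follow the same batching pattern as
        # ChemformerDataPrepFlow above: _create_batches splits the input CSV into
        # self.config.nbatches (idx, start, end) slices, each foreach branch runs the rxnutils
        # pipeline on its slice and writes "<output>.<idx>", and the join step concatenates
        # the partial files again with combine_csv_batches.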
24 | """Preparing splits of the preprocessed reaction data for tagging disconnection sites.""" 25 | self.reaction_partitions = self._create_batches( 26 | self.config.chemformer_data_path, self.config.disconnection_aware_data_path 27 | ) 28 | self.next(self.tag_products, foreach="reaction_partitions") 29 | 30 | @step 31 | def tag_products(self): 32 | """ 33 | Running pipeline for tagging disconnection sites in products. 34 | """ 35 | pipeline_path = str( 36 | Path(__file__).parent / "data" / "chemformer_product_tagging_pipeline.yaml" 37 | ) 38 | idx, start, end = self.input 39 | if idx > -1: 40 | pipeline_runner( 41 | [ 42 | "--pipeline", 43 | pipeline_path, 44 | "--data", 45 | self.config.chemformer_data_path, 46 | "--output", 47 | f"{self.config.disconnection_aware_data_path}.{idx}", 48 | "--max-workers", 49 | "1", 50 | "--batch", 51 | str(start), 52 | str(end), 53 | "--no-intermediates", 54 | ] 55 | ) 56 | self.next(self.tag_products_join) 57 | 58 | @step 59 | def tag_products_join(self, inputs): 60 | """Joining split reactions""" 61 | self.config = inputs[0].config 62 | self._combine_batches(self.config.disconnection_aware_data_path) 63 | self.next(self.create_autotag_dataset_setup) 64 | 65 | @step 66 | def create_autotag_dataset_setup(self): 67 | """Preparing splits of the preprocessed reaction data for tagging disconnection sites.""" 68 | self.reaction_partitions = self._create_batches( 69 | self.config.disconnection_aware_data_path, self.config.autotag_data_path 70 | ) 71 | self.next(self.create_autotag_dataset, foreach="reaction_partitions") 72 | 73 | @step 74 | def create_autotag_dataset(self): 75 | 76 | pipeline_path = str( 77 | Path(__file__).parent / "data" / "autotag_data_extraction.yaml" 78 | ) 79 | idx, start, end = self.input 80 | if idx > -1: 81 | pipeline_runner( 82 | [ 83 | "--pipeline", 84 | pipeline_path, 85 | "--data", 86 | self.config.disconnection_aware_data_path, 87 | "--output", 88 | f"{self.config.autotag_data_path}.{idx}", 89 | "--max-workers", 90 | "1", 91 | "--no-intermediates", 92 | "--batch", 93 | str(start), 94 | str(end), 95 | ] 96 | ) 97 | self.next(self.create_autotag_dataset_join) 98 | 99 | @step 100 | def create_autotag_dataset_join(self, inputs): 101 | """Joining split reactions""" 102 | self.config = inputs[0].config 103 | self._combine_batches(self.config.autotag_data_path) 104 | self.next(self.end) 105 | 106 | @step 107 | def end(self): 108 | print( 109 | f"Disconnection-aware dataset is located here: {self.config.disconnection_aware_data_path}" 110 | ) 111 | print(f"Autotag dataset is located here: {self.config.autotag_data_path}") 112 | 113 | def _combine_batches(self, filename): 114 | combine_csv_batches(filename, self.config.nbatches) 115 | 116 | def _create_batches(self, input_filename, output_filename): 117 | return create_csv_batches(input_filename, self.config.nbatches, output_filename) 118 | 119 | 120 | if __name__ == "__main__": 121 | DisconnectionChemformerDataPrepFlow() 122 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/expansion_model_pipeline.py: -------------------------------------------------------------------------------- 1 | """Module containing a pipeline for training an expansion model""" 2 | from pathlib import Path 3 | 4 | from metaflow import FlowSpec, Parameter, step 5 | from rxnutils.data.batch_utils import combine_sparse_matrix_batches, create_csv_batches 6 | 7 | from aizynthtrain.utils.onnx_converter import main as convert_to_onnx 8 | from 
aizynthtrain.modelling.expansion_policy.create_template_lib import ( 9 | main as create_template_lib, 10 | ) 11 | from aizynthtrain.modelling.expansion_policy.eval_multi_step import ( 12 | main as eval_multi_step, 13 | ) 14 | from aizynthtrain.modelling.expansion_policy.eval_one_step import main as eval_one_step 15 | from aizynthtrain.modelling.expansion_policy.featurize import main as featurize 16 | from aizynthtrain.modelling.expansion_policy.split_data import main as split_data 17 | from aizynthtrain.modelling.expansion_policy.training import main as training_runner 18 | from aizynthtrain.utils.configs import ExpansionModelPipelineConfig, load_config 19 | from aizynthtrain.utils.reporting import main as report_runner 20 | 21 | 22 | class ExpansionModelFlow(FlowSpec): 23 | config_path = Parameter("config", required=True) 24 | 25 | @step 26 | def start(self): 27 | """Loading configuration""" 28 | self.config: ExpansionModelPipelineConfig = load_config( 29 | self.config_path, "expansion_model_pipeline" 30 | ) 31 | self.next(self.create_template_metadata) 32 | 33 | @step 34 | def create_template_metadata(self): 35 | """Preprocess the template library for model training""" 36 | if not Path(self.config.filename("unique_templates")).exists(): 37 | create_template_lib([self.config_path]) 38 | self.next(self.featurization_setup) 39 | 40 | @step 41 | def featurization_setup(self): 42 | """Preparing splits of the reaction data for feauturization""" 43 | self.reaction_partitions = self._create_batches( 44 | self.config.filename("template_code"), self.config.filename("model_labels") 45 | ) 46 | self.next(self.featurization, foreach="reaction_partitions") 47 | 48 | @step 49 | def featurization(self): 50 | """Featurization of reaction data""" 51 | idx, start, end = self.input 52 | if idx > -1: 53 | featurize( 54 | [ 55 | self.config_path, 56 | "--batch", 57 | str(idx), 58 | str(start), 59 | str(end), 60 | ] 61 | ) 62 | self.next(self.featurization_join) 63 | 64 | @step 65 | def featurization_join(self, inputs): 66 | """Joining featurized data""" 67 | self.config = inputs[0].config 68 | self._combine_batches(self.config.filename("model_labels")) 69 | self._combine_batches(self.config.filename("model_inputs")) 70 | self.next(self.split_data) 71 | 72 | @step 73 | def split_data(self): 74 | """Split featurized data into training, validation and testing""" 75 | if not Path(self.config.filename("model_labels", "training")).exists(): 76 | split_data([self.config_path]) 77 | self.next(self.model_training) 78 | 79 | @step 80 | def model_training(self): 81 | """Train the expansion model""" 82 | if not Path(self.config.filename("training_checkpoint")).exists(): 83 | training_runner([self.config_path]) 84 | self.next(self.onnx_converter) 85 | 86 | @step 87 | def onnx_converter(self): 88 | """Convert the trained Keras model to ONNX""" 89 | convert_to_onnx( 90 | [ 91 | self.config.filename("training_checkpoint"), 92 | self.config.filename("onnx_model"), 93 | ] 94 | ) 95 | self.next(self.model_validation) 96 | 97 | @step 98 | def model_validation(self): 99 | """Validate the trained model""" 100 | eval_one_step([self.config_path]) 101 | eval_multi_step([self.config_path]) 102 | notebook_path = str( 103 | Path(__file__).parent / "notebooks" / "expansion_model_val.py" 104 | ) 105 | report_runner( 106 | [ 107 | "--notebook", 108 | notebook_path, 109 | "--report_path", 110 | self.config.filename("report"), 111 | "--python_kernel", 112 | self.config.python_kernel, 113 | "--validation_metrics_filename", 114 | 
self.config.filename("training_log"), 115 | "--onestep_report", 116 | self.config.filename("onestep_report"), 117 | "--multistep_report", 118 | self.config.filename("multistep_report"), 119 | ] 120 | ) 121 | self.next(self.end) 122 | 123 | @step 124 | def end(self): 125 | print( 126 | f"Report on trained model is located here: {self.config.filename('report')}" 127 | ) 128 | 129 | def _combine_batches(self, filename): 130 | combine_sparse_matrix_batches(filename, self.config.nbatches) 131 | 132 | def _create_batches(self, input_filename, output_filename): 133 | return create_csv_batches(input_filename, self.config.nbatches, output_filename) 134 | 135 | 136 | if __name__ == "__main__": 137 | ExpansionModelFlow() 138 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/pipelines/notebooks/__init__.py -------------------------------------------------------------------------------- /aizynthtrain/pipelines/notebooks/expansion_model_val.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import json 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import matplotlib.pylab as plt 7 | from IPython.display import Markdown 8 | 9 | pd.options.display.float_format = "{:,.2f}".format 10 | print_ = lambda x: display(Markdown(x)) 11 | 12 | # %% tags=["parameters"] 13 | validation_metrics_filename = "" 14 | onestep_report = "" 15 | multistep_report = "" 16 | 17 | # %% [markdown] 18 | """ 19 | ## Statistics on expansion model training 20 | """ 21 | 22 | # %% 23 | val_data = pd.read_csv(validation_metrics_filename) 24 | val_data.tail() 25 | 26 | # %% 27 | print_("Convergence of validation loss and accuracy") 28 | fig = plt.figure() 29 | ax = fig.gca() 30 | ax2 = ax.twinx() 31 | val_data.plot(x="epoch", y="val_loss", ax=ax, legend=False) 32 | val_data.plot(x="epoch", y="val_accuracy", style="g", ax=ax2, legend=False) 33 | _ = fig.legend(loc="center left", bbox_to_anchor=(1.0, 0.5)) 34 | 35 | # %% 36 | print_("Metrics at the last epoch") 37 | for key, val in val_data.iloc[-1].to_dict().items(): 38 | if key == "epoch": 39 | continue 40 | print_(f"- {key} = {val:.2f}") 41 | 42 | # %% [markdown] 43 | """ 44 | ## Evaluation of multi-step retrosynthesis capabilities 45 | """ 46 | 47 | # %% 48 | if multistep_report: 49 | with open(multistep_report, "r") as fileobj: 50 | multistep_stats = json.load(fileobj) 51 | else: 52 | multistep_stats = None 53 | 54 | # %% [markdown] 55 | """ 56 | ### Route finding capabilities 57 | """ 58 | 59 | # %% 60 | if multistep_stats: 61 | pd_stats = pd.DataFrame(multistep_stats["finding"]) 62 | print_(f"Average first solution time: {pd_stats['first solution time'].mean():.2f}") 63 | print_(f"Average number of solved target: {pd_stats['is solved'].mean()*100:.2f}%") 64 | else: 65 | print_("Route finding capabilities not evaluated") 66 | 67 | 68 | # %% [markdown] 69 | """ 70 | ### Route recovery capabilities 71 | """ 72 | 73 | # %% 74 | if multistep_stats: 75 | pd_stats = pd.DataFrame(multistep_stats["recovery"]) 76 | display(pd_stats) 77 | 78 | print_(f"Average number of solved target: {pd_stats['is solved'].mean()*100:.2f}%") 79 | print_(f"Average found reference: {pd_stats['found reference'].mean()*100:.2f}%") 80 | print_( 81 | f"Average closest to reference: 
{pd_stats['closest to reference'].mean():.2f}" 82 | ) 83 | print_(f"Average rank of closest: {pd_stats['rank of closest'].mean():.2f}") 84 | else: 85 | print_("Route recovery capabilities not evaluated") 86 | 87 | 88 | # %% [markdown] 89 | """ 90 | ## Evaluation of one-step retrosynthesis capabilities 91 | """ 92 | 93 | # %% 94 | stats = pd.read_json(onestep_report) 95 | display(stats) 96 | 97 | print_(f"Average found expected: {stats['found expected'].mean()*100:.2f}%") 98 | print_(f"Average rank of expected: {stats['rank of expected'].mean():.2f}") 99 | print_(f"Average ring broken when expected: {stats['ring broken'].mean()*100:.2f}%") 100 | print_(f"Percentage of ring reactions: {stats['ring breaking'].mean()*100:.2f}%") 101 | print_(f"Average non-applicable (in top-50): {stats['non-applicable'].mean():.2f}") 102 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/notebooks/reaction_selection.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import json 3 | from collections import defaultdict, Counter 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from mendeleev import element as Element 8 | from IPython.display import Markdown 9 | 10 | pd.options.display.float_format = "{:,.2f}".format 11 | print_ = lambda x: display(Markdown(x)) 12 | 13 | # %% tags=["parameters"] 14 | input_filename = "" 15 | output_filename = "" 16 | 17 | # %% 18 | data = pd.read_csv(input_filename, sep="\t") 19 | 20 | 21 | # %% 22 | def make_stats_by_sources(data): 23 | sources = sorted(set(data.source)) 24 | source_stats = defaultdict(list) 25 | for source in sources: 26 | source_stats["source"].append(source) 27 | df = data[data.source == source] 28 | source_stats["nreactions"].append(len(df)) 29 | source_stats["% of total"].append(len(df) / len(data) * 100) 30 | unique_hashes = set(df["PseudoHash"]) 31 | source_stats["n unique reactions"].append(len(unique_hashes)) 32 | source_stats["% unique"].append(len(unique_hashes) / len(df) * 100) 33 | source_stats["% unclassified"].append( 34 | df["classification"].str.startswith("0.0 ").mean() * 100 35 | ) 36 | df2 = df[~df.date.isna()] 37 | dates = pd.to_datetime(df2.date, infer_datetime_format=True) 38 | source_stats["from"].append(dates.min()) 39 | source_stats["to"].append(dates.max()) 40 | 41 | if len(sources) > 1: 42 | source_stats["source"].append("all") 43 | source_stats["nreactions"].append(sum(source_stats["nreactions"])) 44 | source_stats["% of total"].append("100") 45 | source_stats["n unique reactions"].append(len(set(data["PseudoHash"]))) 46 | source_stats["% unique"].append( 47 | source_stats["n unique reactions"][-1] 48 | / source_stats["nreactions"][-1] 49 | * 100 50 | ) 51 | source_stats["% unclassified"].append( 52 | data["classification"].str.startswith("0.0 ").mean() * 100 53 | ) 54 | source_stats["from"].append(min(source_stats["from"])) 55 | source_stats["to"].append(max(source_stats["to"])) 56 | return source_stats 57 | 58 | 59 | # %% [markdown] 60 | """ 61 | ## Statistics on extracted and validated reactions 62 | """ 63 | 64 | # %% 65 | sel = data["classification"].str.startswith("0.0 ") 66 | print_(f"Total number of extracted reactions = {len(data)}") 67 | print_(f"Number of unrecognized reactions = {sel.sum()} ({sel.mean()*100:.2f}%)") 68 | nisotopes = (~data["Isotope"].isna()).sum() 69 | print_(f"Number of reactions with isotope information = {nisotopes}") 70 | ndeuterium = (data["Isotope"] == "2H").sum() 71 | print_(f"Number of 
reactions with deuterium = {ndeuterium}") 72 | 73 | bad_molecules = list() 74 | sel = ~data.BadMolecules.isna() 75 | for mol_list in data[sel].BadMolecules: 76 | bad_molecules.extend(mol_list.split(",")) 77 | sel2 = ~data.BadMolecules2.isna() 78 | for mol_list in data[sel2].BadMolecules2: 79 | bad_molecules.extend(mol_list.split(",")) 80 | bad_molecules = set(bad_molecules) 81 | print_( 82 | f"Identified {len(bad_molecules)} molecules that could not be sanitized by RDKit this affected {sel.sum() + sel2.sum()} reactions" 83 | ) 84 | print("") 85 | pd.DataFrame(make_stats_by_sources(data)) 86 | 87 | # %% 88 | ax = data.NReactants.value_counts().sort_index().plot.bar() 89 | _ = ax.set_title("Distribution of number of reactants") 90 | 91 | # %% 92 | ax = data.NReagents.value_counts().sort_index().plot.bar() 93 | _ = ax.set_title("Distribution of number of reagents") 94 | 95 | # %% 96 | ax = data.NProducts.value_counts().sort_index().plot.bar() 97 | _ = ax.set_title("Distribution of number of products") 98 | 99 | # %% [markdown] 100 | """ 101 | ## Removing reactions based on these filters 102 | 103 | - no. products should be 1 104 | - no. reactants should be between 1 and 5 105 | - no. un-mapped product atoms should be less than 3 106 | - no. of widow atoms should be less than 5 107 | - all reactants should be sanitizable 108 | """ 109 | 110 | # %% 111 | NREACTANTS_LIMIT = 5 112 | UNMAPPED_PROD_LIMIT = 3 113 | WIDOWS_LIMIT = 5 114 | 115 | sel_too_many_prod = data["NProducts"] > 1 116 | sel_too_few_prod = (data["NProducts"] == 0) | (data["NMappedProducts"] == 0) 117 | sel_too_few_react = (data["NReactants"] == 0) | (data["NMappedReactants"] == 0) 118 | sel_too_many_reactants = data["NMappedReactants"] > NREACTANTS_LIMIT 119 | sel_too_many_unmapped = data["UnmappedProdAtoms"] > UNMAPPED_PROD_LIMIT 120 | sel_too_many_widows = data["WidowAtoms"] > WIDOWS_LIMIT 121 | sel_unsani_react = data["HasUnsanitizableReactants"] 122 | 123 | print_(f"Removing {sel_too_many_prod.sum()} reactions with more than one product") 124 | print_(f"Removing {sel_too_few_prod.sum()} reactions without any product") 125 | print_(f"Removing {sel_too_few_react.sum()} reactions without any reactant") 126 | print_( 127 | f"Removing {sel_too_many_reactants.sum()} reactions with more than {NREACTANTS_LIMIT} reactants" 128 | ) 129 | print_( 130 | f"Removing {sel_too_many_unmapped.sum()} reactions with more than {UNMAPPED_PROD_LIMIT} unmapped product atoms" 131 | ) 132 | print_( 133 | f"Removing {sel_too_many_widows.sum()} reactions with more than {WIDOWS_LIMIT} widow atoms" 134 | ) 135 | print_(f"Removing {sel_unsani_react.sum()} reactions with unsanitizable reactants") 136 | 137 | data = data[ 138 | (~sel_too_many_prod) 139 | & (~sel_too_few_prod) 140 | & (~sel_too_few_react) 141 | & (~sel_too_many_reactants) 142 | & (~sel_too_many_unmapped) 143 | & (~sel_too_many_widows) 144 | & (~sel_unsani_react) 145 | ] 146 | 147 | # %% [markdown] 148 | """ 149 | ## Removing reactions based on these filters 150 | 151 | - the reactants must be different than the products 152 | - no wild card atoms 153 | - the no the product atoms should be between the 3% and 97% percentile 154 | - there should not be any un-mapped radicals 155 | - the elements in the reactants should be likely 156 | - the reaction class should not be 11. or 12. 
157 | - the CGR should be possible to create 158 | - the number of dynamic bonds should not be more than 10 159 | - the reaction should not have more than 2 sibling reactions 160 | - the reaction should not have more than 12 rings 161 | """ 162 | 163 | 164 | # %% 165 | def unchanged_reactions(row): 166 | reactants, products = row["PseudoHash"].split(">>") 167 | return reactants == products 168 | 169 | 170 | sel_unchanged = data.apply(unchanged_reactions, axis=1) 171 | print_(f"Removing {sel_unchanged.sum()} reactions with the same reactants and products") 172 | 173 | sel_wildcard_atom = data["rsmi_processed"].str.contains("*", regex=False) 174 | print_(f"Removing {sel_wildcard_atom.sum()} reactions with wild card atoms") 175 | 176 | percentile03 = data.ProductSize.quantile(0.03) 177 | percentile97 = data.ProductSize.quantile(0.97) 178 | sel_small = data.ProductSize < percentile03 179 | sel_big = data.ProductSize > percentile97 180 | print_( 181 | f"Removing {sel_small.sum()} reactions with product size smaller than {percentile03} (3% percentile)" 182 | ) 183 | print_( 184 | f"Removing {sel_big.sum()} reactions with product size larger than {percentile97} (97% percentile)" 185 | ) 186 | 187 | sel_radical = data["HasUnmappedRadicalAtom"] 188 | print_(f"Removing {sel_radical.sum()} reactions with unmapped radical") 189 | 190 | # %% 191 | LIKELIHOOD_LIMIT = 1e-5 192 | 193 | 194 | def reject_based_on_likelihood(row, likelihoods, limit): 195 | if row.classification != "0.0 Unrecognized": 196 | return False 197 | atomic_counts = json.loads(row.ElementCount) 198 | return any( 199 | likelihoods[atomic_number] < limit for atomic_number in atomic_counts.keys() 200 | ) 201 | 202 | 203 | counts_by_atomic_number = defaultdict(int) 204 | for count_str in data.ElementCount.values: 205 | for atomic_number, count in json.loads(count_str).items(): 206 | counts_by_atomic_number[atomic_number] += count 207 | 208 | sum_counts = sum(counts_by_atomic_number.values()) 209 | likelihoods = { 210 | key: value / sum_counts for key, value in counts_by_atomic_number.items() 211 | } 212 | sel_likelihood = data.apply( 213 | reject_based_on_likelihood, axis=1, likelihoods=likelihoods, limit=LIKELIHOOD_LIMIT 214 | ) 215 | print_( 216 | f"Removing {sel_likelihood.sum()} reactions unusual elements in mapped reactants" 217 | ) 218 | elements = ", ".join( 219 | sorted( 220 | { 221 | Element(int(key)).symbol 222 | for key, value in likelihoods.items() 223 | if key != "0" and value < LIKELIHOOD_LIMIT 224 | } 225 | ) 226 | ) 227 | print_(f"Elements considered to be unlikely: {elements}") 228 | 229 | # %% 230 | unwanted_classes = data.classification.str.startswith(("11.", "12.")) 231 | print_( 232 | f"Removing {unwanted_classes.sum()} reactions with unwanted reaction classes (11., 12.)" 233 | ) 234 | 235 | sel_cgr_creation = ~data.CGRCreated 236 | print_( 237 | f"Removing {sel_cgr_creation.sum()} reactions where the CGR could not be created" 238 | ) 239 | 240 | DYNAMIC_BOND_LIMIT = 10 241 | sel_dynamic_bond = data.NDynamicBonds > DYNAMIC_BOND_LIMIT 242 | print_( 243 | f"Removing {sel_dynamic_bond.sum()} reactions where the number of dynamic bonds are greater than {DYNAMIC_BOND_LIMIT}" 244 | ) 245 | 246 | # %% 247 | info = data["id"].str.extract(r"_P(?P\d)$", expand=False) 248 | prod_val = info[~info.isna()].astype(int) 249 | sib_sel = np.zeros(len(data)).astype("bool") 250 | if len(prod_val) > 0 and "document_id" in data.columns: 251 | prod_sel = prod_val > 2 252 | data_sel = data.index.isin(prod_val[prod_sel].index) 253 | 
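    # The rows flagged above (ids ending in "_P3" or higher, i.e. a third or later product
    # entry) are mapped back to their source documents below via the id prefix before
    # "_<source>_"; every reaction from such a document is then marked for removal through
    # sib_sel. This implements the "no more than 2 sibling reactions" filter listed in the
    # markdown cell above.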
unique_doc = ( 254 | data[data_sel] 255 | .apply(lambda row: row["id"].split(f"_{row['source']}_")[0], axis=1) 256 | .unique() 257 | ) 258 | sib_sel = data["document_id"].isin(unique_doc) 259 | print_(f"Removing {sib_sel.sum()} reactions with with more than 2 siblings") 260 | 261 | # %% 262 | NRING_LIMIT = 12 263 | nrings_sel = data["MaxRings"] > NRING_LIMIT 264 | print_(f"Removing {nrings_sel.sum()} reactions with with more than {NRING_LIMIT} rings") 265 | 266 | 267 | # %% 268 | data = data[ 269 | (~sel_likelihood) 270 | & (~sel_small) 271 | & (~sel_big) 272 | & (~sel_wildcard_atom) 273 | & (~sel_unchanged) 274 | & (~sel_radical) 275 | & (~unwanted_classes) 276 | & (~sel_cgr_creation) 277 | & (~sel_dynamic_bond) 278 | & (~sib_sel) 279 | & (~nrings_sel) 280 | ] 281 | 282 | # %% 283 | print_("Statistics by source after filtering") 284 | pd.DataFrame(make_stats_by_sources(data)) 285 | 286 | # %% [markdown] 287 | """ 288 | ## Ring breaker 289 | Finding reactions to be treated with ring breaker 290 | """ 291 | 292 | # %% 293 | data["RingBreaker"] = ( 294 | data["RingBondMade"] 295 | & (data["NRingChange"] >= 1) 296 | & (data["RingMadeSize"] > 2) 297 | & (data["RingMadeSize"] < 8) 298 | ) 299 | nringbreaker = data["RingBreaker"].sum() 300 | print_( 301 | f"Marked {nringbreaker} ({data['RingBreaker'].mean()*100:.2f}) reactions for ring breaker" 302 | ) 303 | nring_sizes = [ 304 | (data[data["RingBreaker"]]["RingMadeSize"] == n).sum() for n in range(3, 8) 305 | ] 306 | pd.DataFrame( 307 | { 308 | "size of ring formed": range(3, 8), 309 | "n reactions": nring_sizes, 310 | "% of total": [n / nringbreaker * 100 for n in nring_sizes], 311 | } 312 | ) 313 | 314 | # %% [markdown] 315 | """ 316 | ## Finalizing 317 | """ 318 | 319 | # %% 320 | filename = output_filename.replace(".csv", "_all.csv") 321 | print_(f"Saving all selected reactions with all columns to {filename}") 322 | data.to_csv(filename, index=False, sep="\t") 323 | 324 | # %% 325 | hash_to_id = data.groupby("PseudoHash")["id"].apply(lambda x: list(x.values)).to_dict() 326 | filename = output_filename.replace(".csv", "_ids.json") 327 | with open(filename, "w") as fileobj: 328 | json.dump(hash_to_id, fileobj) 329 | print_(f"Saving translation from reaction hash to reaction connect id to {filename}") 330 | 331 | # %% 332 | print_("Replacing classification with most common classification for each reaction") 333 | most_common_classification_by_hash = ( 334 | data.groupby("PseudoHash")["classification"] 335 | .apply(lambda x: Counter(x.values).most_common(1)[0][0]) 336 | .to_dict() 337 | ) 338 | data["classification"] = data.apply( 339 | lambda row: most_common_classification_by_hash[row["PseudoHash"]], axis=1 340 | ) 341 | 342 | 343 | # %% 344 | data.rename(columns={"rsmi_processed": "RxnSmilesClean"}, inplace=True) 345 | data = data[["RxnSmilesClean", "PseudoHash", "classification", "RingBreaker"]] 346 | data = data.drop_duplicates(subset="PseudoHash") 347 | print_( 348 | f"Saving clean reaction SMILES, reaction hash, classification and ring breaker flag for selected reactions to {output_filename}" 349 | ) 350 | data.to_csv(output_filename, index=False, sep="\t") 351 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/notebooks/stereo_selection.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import json 3 | from collections import defaultdict, Counter 4 | 5 | import pandas as pd 6 | from IPython.display import Markdown 7 | 8 | 
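# Like reaction_selection.py above, this is a notebook-style script in percent-cell format:
# the cell tagged "parameters" below only holds placeholder values, which the report runner
# (aizynthtrain.utils.reporting, presumably via papermill or a similar mechanism) replaces
# with the filenames supplied by the pipeline before the notebook is executed and rendered
# into a report.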
pd.options.display.float_format = "{:,.2f}".format 9 | print_ = lambda x: display(Markdown(x)) 10 | 11 | # %% tags=["parameters"] 12 | input_filename = "" 13 | output_filename = "" 14 | 15 | # %% 16 | data = pd.read_csv(input_filename, sep="\t") 17 | 18 | # %% 19 | def make_stats_by_sources(data, extra_columns=None): 20 | extra_columns = extra_columns or [] 21 | sources = sorted(set(data.source)) 22 | source_stats = defaultdict(list) 23 | for source in sources: 24 | source_stats["source"].append(source) 25 | df = data[data.source == source] 26 | source_stats["nreactions"].append(len(df)) 27 | source_stats["% of total"].append(len(df) / len(data) * 100) 28 | unique_hashes = set(df["PseudoHash"]) 29 | source_stats["n unique reactions"].append(len(unique_hashes)) 30 | source_stats["% unique"].append(len(unique_hashes) / len(df) * 100) 31 | has_changes = ~df.StereoChanges.isna() 32 | source_stats["n reactions with stereo changes in centre"].append( 33 | has_changes.sum() 34 | ) 35 | source_stats["% with stereo changes in centre"].append(has_changes.mean() * 100) 36 | source_stats["% unclassified"].append( 37 | df["classification"].str.startswith("0.0 ").mean() * 100 38 | ) 39 | for column in extra_columns: 40 | source_stats[column].append(df[column].sum()) 41 | df2 = df[~df.date.isna()] 42 | dates = pd.to_datetime(df2.date, infer_datetime_format=True) 43 | source_stats["from"].append(dates.min()) 44 | source_stats["to"].append(dates.max()) 45 | 46 | if len(sources) > 1: 47 | source_stats["source"].append("all") 48 | source_stats["nreactions"].append(sum(source_stats["nreactions"])) 49 | source_stats["% of total"].append("100") 50 | source_stats["n unique reactions"].append(len(set(data["PseudoHash"]))) 51 | source_stats["% unique"].append( 52 | source_stats["n unique reactions"][-1] 53 | / source_stats["nreactions"][-1] 54 | * 100 55 | ) 56 | has_changes = ~data.StereoChanges.isna() 57 | source_stats["n reactions with stereo changes in centre"].append( 58 | has_changes.sum() 59 | ) 60 | source_stats["% with stereo changes in centre"].append(has_changes.mean() * 100) 61 | for column in extra_columns: 62 | source_stats[column].append(data[column].sum()) 63 | source_stats["% unclassified"].append( 64 | data["classification"].str.startswith("0.0 ").mean() * 100 65 | ) 66 | source_stats["from"].append(min(source_stats["from"])) 67 | source_stats["to"].append(max(source_stats["to"])) 68 | return source_stats 69 | 70 | 71 | # %% [markdown] 72 | """ 73 | ## Statistics on extracted and validated reactions 74 | """ 75 | 76 | # %% 77 | print_(f"Total number of extracted reactions = {len(data)}") 78 | print("") 79 | pd.DataFrame( 80 | make_stats_by_sources( 81 | data, 82 | extra_columns=[ 83 | "HasChiralReagent", 84 | "StereoCentreCreated", 85 | "StereoCentreRemoved", 86 | "PotentialStereoCentre", 87 | "StereoOutside", 88 | "MesoProduct", 89 | ], 90 | ) 91 | ) 92 | 93 | 94 | # %% 95 | sel = data.StereoChanges.isna() 96 | print( 97 | f"Removing {sel.sum()} reactions where a stereo centre did not change during the reaction" 98 | ) 99 | data = data[~sel] 100 | 101 | # %% [markdown] 102 | """ 103 | ## Categories of stereochemical reactions 104 | Finding reactions for buckets: 105 | - Stereospecific 106 | - Reagent controlled 107 | - Substrate controlled 108 | """ 109 | 110 | # %% 111 | ss_sel = ~(data["StereoCentreCreated"] | data["StereoCentreRemoved"]) 112 | ss_sel = ss_sel & (~data["MesoProduct"]) 113 | 114 | rc_sel = ( 115 | data["StereoCentreCreated"] 116 | & (~data["PotentialStereoCentre"]) 117 | & 
(~data["StereoOutside"]) 118 | & (data["HasChiralReagent"]) 119 | ) 120 | rc_sel = rc_sel & (~data["MesoProduct"]) 121 | 122 | sc_sel = ( 123 | data["StereoCentreCreated"] 124 | & (~data["PotentialStereoCentre"]) 125 | & (data["StereoOutside"]) 126 | & (~data["HasChiralReagent"]) 127 | ) 128 | sc_sel = sc_sel & (~data["MesoProduct"]) 129 | 130 | data = data.assign( 131 | IsStereoSpecific=ss_sel, 132 | IsReagentControlled=rc_sel, 133 | IsSubstrateControlled=sc_sel, 134 | StereoBucket="", 135 | ) 136 | data.loc[ss_sel, "StereoBucket"] = "stereospecific" 137 | data.loc[rc_sel, "StereoBucket"] = "reagent controlled" 138 | data.loc[sc_sel, "StereoBucket"] = "substrate controlled" 139 | 140 | # %% 141 | sel = data["StereoBucket"] == "" 142 | print(f"Removing {sel.sum()} reactions that does not fit into any category") 143 | data = data[~sel] 144 | 145 | # %% 146 | print_("Statistics by source after filtering and categorizing") 147 | pd.DataFrame( 148 | make_stats_by_sources( 149 | data, 150 | extra_columns=[ 151 | "IsStereoSpecific", 152 | "IsReagentControlled", 153 | "IsSubstrateControlled", 154 | ], 155 | ) 156 | ) 157 | 158 | # %% [markdown] 159 | """ 160 | ## Finalizing 161 | """ 162 | 163 | # %% 164 | filename = output_filename.replace(".csv", "_all.csv") 165 | print_(f"Saving all selected reactions with all columns to {filename}") 166 | data.to_csv(filename, index=False, sep="\t") 167 | 168 | # %% 169 | hash_to_id = data.groupby("PseudoHash")["id"].apply(lambda x: list(x.values)).to_dict() 170 | filename = output_filename.replace(".csv", "_ids.json") 171 | with open(filename, "w") as fileobj: 172 | json.dump(hash_to_id, fileobj) 173 | print_(f"Saving translation from reaction hash to reaction connect id to {filename}") 174 | 175 | # %% 176 | print_("Replacing classification with most common classification for each reaction") 177 | most_common_classification_by_hash = ( 178 | data.groupby("PseudoHash")["classification"] 179 | .apply(lambda x: Counter(x.values).most_common(1)[0][0]) 180 | .to_dict() 181 | ) 182 | data["classification"] = data.apply( 183 | lambda row: most_common_classification_by_hash[row["PseudoHash"]], axis=1 184 | ) 185 | 186 | 187 | # %% 188 | data.rename(columns={"rsmi_processed": "RxnSmilesClean"}, inplace=True) 189 | data = data[ 190 | ["RxnSmilesClean", "PseudoHash", "classification", "RingBreaker", "StereoBucket"] 191 | ] 192 | data = data.drop_duplicates(subset="PseudoHash") 193 | print_( 194 | f"Saving clean reaction SMILES, reaction hash, classification, " 195 | f"ring breaker flag and stereo bucket for selected reactions to {output_filename}", 196 | ) 197 | data.to_csv(output_filename, index=False, sep="\t") 198 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/notebooks/stereo_template_selection.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from collections import Counter, defaultdict 3 | 4 | import pandas as pd 5 | import matplotlib.pylab as plt 6 | from IPython.display import Markdown 7 | 8 | from aizynthtrain.utils.files import prefix_filename 9 | from aizynthtrain.utils.configs import ( 10 | load_config, 11 | TemplateLibraryColumnsConfig, 12 | ) 13 | 14 | pd.options.display.float_format = "{:,.2f}".format 15 | print_ = lambda x: display(Markdown(x)) 16 | 17 | # %% tags=["parameters"] 18 | input_filename = "" 19 | output_prefix = "" 20 | output_postfix = "" 21 | min_occurrence = 3 22 | config_filename = "" 23 | 24 | # %% 25 | data = 
pd.read_csv(input_filename, sep="\t") 26 | col_config: TemplateLibraryColumnsConfig = load_config( 27 | config_filename, "template_library_columns" 28 | ) 29 | 30 | # %% [markdown] 31 | """ 32 | ## Statistics on extracted and validated templates 33 | """ 34 | 35 | # %% 36 | sel = data["TemplateError"].isna() 37 | print_(f"Total number of extracted templates = {sel.sum()} ({sel.mean():.2f}%)") 38 | print_(f"Number of unique templates = {len(set(data[sel]['TemplateHash']))}") 39 | print_( 40 | f"Number of templates failures = {len(data)-sel.sum()} ({1-sel.sum()/len(data):.5f}%)" 41 | ) 42 | print_( 43 | f"Number of templates with flipped chirality = {data[sel]['FlippedStereo'].sum()}" 44 | ) 45 | 46 | counts = Counter(data[data["RetroTemplate"].isna()]["TemplateError"]) 47 | cat_counts = defaultdict(int) 48 | for key, count in counts.items(): 49 | if "too many unmapped atoms" in key: 50 | cat_counts["Too many unmapped atoms"] += count 51 | elif "consistent tetrahedral mapping" in key: 52 | cat_counts["Consistent tetrahedral mapping"] += count 53 | else: 54 | key = ( 55 | key.replace("Template generation failed: ", "") 56 | .replace("Template generation failed with message:", "") 57 | .replace( 58 | "Template generation failed with message: Template generation failed:", 59 | "", 60 | ) 61 | .replace("\n", "\t") 62 | ) 63 | cat_counts[key.strip()] += count 64 | print_("Errors and counts: ") 65 | for key, count in cat_counts.items(): 66 | print_(f"- {key}: {count}") 67 | 68 | print("") 69 | 70 | # %% 71 | counts = ( 72 | data[data["TemplateError"].isna()] 73 | .groupby("TemplateHash", sort=False) 74 | .size() 75 | .value_counts() 76 | .to_dict() 77 | ) 78 | ranges = [ 79 | (0, 5), 80 | (5, 10), 81 | (10, 20), 82 | (20, 50), 83 | (50, 100), 84 | (100, 500), 85 | (500, max(counts.keys())), 86 | ] 87 | count_ranges = [] 88 | for start, end in ranges: 89 | count_ranges.append(sum(counts.get(idx, 0) for idx in range(start, end))) 90 | labels = [f"{start}-{end}" for start, end in ranges[:-1]] + [f"{ranges[-1][0]}-"] 91 | plt.gca().bar(height=count_ranges, x=range(1, len(ranges) + 1)) 92 | plt.gca().set_title("Distribution of template occurence") 93 | plt.gca().set_xticks(range(1, len(ranges) + 1)) 94 | _ = plt.gca().set_xticklabels(labels) 95 | 96 | # %% 97 | ax = ( 98 | data[data["TemplateError"].isna()] 99 | .TemplateGivesNOutcomes.value_counts() 100 | .sort_index() 101 | .plot.bar() 102 | ) 103 | _ = ax.set_title("Distribution of number of outcomes") 104 | 105 | # %% 106 | ncorrect = data[data["TemplateError"].isna()].TemplateGivesTrueReactants 107 | print_( 108 | f"Number of templates that reproduces the correct reactants: {ncorrect.sum()} ({ncorrect.mean()*100:.2f}%)" 109 | ) 110 | nother = data[data["TemplateError"].isna()].TemplateGivesOtherReactants 111 | print_( 112 | f"Number of templates that produces other reactants: {nother.sum()} ({nother.mean()*100:.2f}%)" 113 | ) 114 | 115 | # %% [markdown] 116 | """ 117 | ## Removing templates based on filters 118 | """ 119 | 120 | # %% 121 | sel_too_many_prod = data["TemplateError"].isna() & (data.nproducts > 1) 122 | sel_failures = data["RetroTemplate"].isna() 123 | sel_no_reprod = data["TemplateError"].isna() & (~data.TemplateGivesTrueReactants) 124 | 125 | rejected_data = data[sel_too_many_prod | sel_failures | sel_no_reprod] 126 | filename = input_filename.replace(".csv", "_rejected.csv") 127 | rejected_data.to_csv(filename, index=False, sep="\t") 128 | 129 | data = data[(~sel_too_many_prod) & (~sel_failures) & (~sel_no_reprod)] 130 | filename = 
input_filename.replace(".csv", "_checked.csv") 131 | data.to_csv(filename, index=False, sep="\t") 132 | 133 | print_( 134 | f"Removing {sel_too_many_prod.sum()} reactions with more than one product in template" 135 | ) 136 | print_(f"Removing {sel_failures.sum()} reactions with failed template generation") 137 | print_( 138 | f"Removing {sel_no_reprod.sum()} reactions with template unable to reproduce correct reactants" 139 | ) 140 | 141 | 142 | # %% 143 | print_( 144 | f"Number of unique templates before occurence filtering = {len(set(data['TemplateHash']))}" 145 | ) 146 | print_(f"Removing all templates with occurrence less than {min_occurrence}") 147 | template_group = data.groupby("TemplateHash") 148 | template_group = template_group.size().sort_values(ascending=False) 149 | min_index = template_group[template_group >= int(min_occurrence)].index 150 | final_sel = data["TemplateHash"].isin(min_index) 151 | 152 | # %% [markdown] 153 | """ 154 | ## Finalizing 155 | """ 156 | 157 | # %% 158 | data = data[ 159 | [ 160 | "RxnSmilesClean", 161 | "PseudoHash", 162 | "classification", 163 | "RingBreaker", 164 | "StereoBucket", 165 | "FlippedStereo", 166 | "RetroTemplate", 167 | "TemplateHash", 168 | ] 169 | ] 170 | data.rename( 171 | columns={ 172 | "RxnSmilesClean": col_config.reaction_smiles, 173 | "PseudoHash": col_config.reaction_hash, 174 | "classification": col_config.classification, 175 | "RingBreaker": col_config.ring_breaker, 176 | "StereoBucket": col_config.stereo_bucket, 177 | "FlippedStereo": col_config.flipped_stereo, 178 | "RetroTemplate": col_config.retro_template, 179 | "TemplateHash": col_config.template_hash, 180 | }, 181 | inplace=True, 182 | ) 183 | # %% 184 | discarded_data = data[~final_sel] 185 | data = data[final_sel] 186 | print_(f"Total number of selected templates = {len(data)}") 187 | print_(f"Total number of discarded templates = {len(discarded_data)}") 188 | print_(f"Number of unique templates = {len(set(data[col_config.template_hash]))}") 189 | print_( 190 | f"Number of discarded unique templates = {len(set(discarded_data[col_config.template_hash]))}" 191 | ) 192 | 193 | buckets = ["stereospecific", "reagent controlled", "substrate controlled"] 194 | ntemplates = [(data[col_config.stereo_bucket] == bucket).sum() for bucket in buckets] 195 | nunique = [ 196 | len(set(data[data[col_config.stereo_bucket] == bucket][col_config.template_hash])) 197 | for bucket in buckets 198 | ] 199 | display( 200 | pd.DataFrame( 201 | { 202 | "category": buckets, 203 | "n templates": ntemplates, 204 | "% of total": [n / len(data) * 100 for n in ntemplates], 205 | "n unique templates": nunique, 206 | } 207 | ) 208 | ) 209 | 210 | # %% 211 | filename = prefix_filename(output_prefix, output_postfix) 212 | filname_discarded = filename.replace(".csv", "_discarded.csv") 213 | print_(f"Saving selected templates to {filename} and discarded to {filname_discarded}.") 214 | data.to_csv(filename, index=False, sep="\t") 215 | discarded_data.to_csv(filname_discarded, index=False, sep="\t") 216 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/notebooks/template_selection.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from collections import Counter, defaultdict 3 | 4 | import pandas as pd 5 | import matplotlib.pylab as plt 6 | from IPython.display import Markdown 7 | 8 | from aizynthtrain.utils.files import prefix_filename 9 | from aizynthtrain.utils.configs import ( 10 | load_config, 11 | 
TemplateLibraryColumnsConfig, 12 | ) 13 | 14 | pd.options.display.float_format = "{:,.2f}".format 15 | print_ = lambda x: display(Markdown(x)) 16 | 17 | # %% tags=["parameters"] 18 | input_filename = "" 19 | output_prefix = "" 20 | output_postfix = "" 21 | min_occurrence = 3 22 | config_filename = "" 23 | 24 | # %% 25 | data = pd.read_csv(input_filename, sep="\t") 26 | col_config: TemplateLibraryColumnsConfig = load_config( 27 | config_filename, "template_library_columns" 28 | ) 29 | 30 | # %% [markdown] 31 | """ 32 | ## Statistics on extracted and validated templates 33 | """ 34 | 35 | # %% 36 | sel = data["TemplateError"].isna() 37 | print_(f"Total number of extracted templates = {sel.sum()} ({sel.mean()*100:.2f}%)") 38 | print_(f"Number of unique templates = {len(set(data[sel]['TemplateHash']))}") 39 | print_( 40 | f"Number of template failures = {len(data)-sel.sum()} ({(1-sel.sum()/len(data))*100:.2f}%)" 41 | ) 42 | 43 | counts = Counter(data[data["RetroTemplate"].isna()]["TemplateError"]) 44 | cat_counts = defaultdict(int) 45 | for key, count in counts.items(): 46 | if "too many unmapped atoms" in key: 47 | cat_counts["Too many unmapped atoms"] += count 48 | elif "consistent tetrahedral mapping" in key: 49 | cat_counts["Consistent tetrahedral mapping"] += count 50 | else: 51 | key = ( 52 | key.replace("Template generation failed: ", "") 53 | .replace("Template generation failed with message:", "") 54 | .replace( 55 | "Template generation failed with message: Template generation failed:", 56 | "", 57 | ) 58 | .replace("\n", "\t") 59 | ) 60 | cat_counts[key.strip()] += count 61 | print_("Errors and counts: ") 62 | for key, count in cat_counts.items(): 63 | print_(f"- {key}: {count}") 64 | 65 | print("") 66 | 67 | # %% 68 | counts = ( 69 | data[data["TemplateError"].isna()] 70 | .groupby("TemplateHash", sort=False) 71 | .size() 72 | .value_counts() 73 | .to_dict() 74 | ) 75 | ranges = [ 76 | (0, 5), 77 | (5, 10), 78 | (10, 20), 79 | (20, 50), 80 | (50, 100), 81 | (100, 500), 82 | (500, max(counts.keys())), 83 | ] 84 | count_ranges = [] 85 | for start, end in ranges: 86 | count_ranges.append(sum(counts.get(idx, 0) for idx in range(start, end))) 87 | labels = [f"{start}-{end}" for start, end in ranges[:-1]] + [f"{ranges[-1][0]}-"] 88 | plt.gca().bar(height=count_ranges, x=range(1, len(ranges) + 1)) 89 | plt.gca().set_title("Distribution of template occurrence") 90 | plt.gca().set_xticks(range(1, len(ranges) + 1)) 91 | _ = plt.gca().set_xticklabels(labels) 92 | 93 | # %% 94 | ax = ( 95 | data[data["TemplateError"].isna()] 96 | .TemplateGivesNOutcomes.value_counts() 97 | .sort_index() 98 | .plot.bar() 99 | ) 100 | _ = ax.set_title("Distribution of number of outcomes") 101 | 102 | # %% 103 | ncorrect = data[data["TemplateError"].isna()].TemplateGivesTrueReactants 104 | print_( 105 | f"Number of templates that reproduce the correct reactants: {ncorrect.sum()} ({ncorrect.mean()*100:.2f}%)" 106 | ) 107 | nother = data[data["TemplateError"].isna()].TemplateGivesOtherReactants 108 | print_( 109 | f"Number of templates that produce other reactants: {nother.sum()} ({nother.mean()*100:.2f}%)" 110 | ) 111 | 112 | # %% [markdown] 113 | """ 114 | ## Removing templates based on filters 115 | """ 116 | 117 | # %% 118 | sel_too_many_prod = data["TemplateError"].isna() & (data.nproducts > 1) 119 | sel_failures = data["RetroTemplate"].isna() 120 | sel_no_reprod = data["TemplateError"].isna() & (~data.TemplateGivesTrueReactants) 121 | data = data[(~sel_too_many_prod) & (~sel_failures) & (~sel_no_reprod)] 122 |
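# Note on the occurrence filter applied further down: it keeps only templates whose hash
# occurs at least `min_occurrence` times in the remaining data. A rough, equivalent pandas
# idiom (an illustrative sketch only, not part of the pipeline) would be:
#
#   keep = data.groupby("TemplateHash")["TemplateHash"].transform("size") >= int(min_occurrence)
#   data = data[keep]
#
# The notebook instead materialises the group sizes and selects the qualifying hashes with
# isin(), which yields the same rows.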
print_( 123 | f"Removing {sel_too_many_prod.sum()} reactions with more than one product in template" 124 | ) 125 | print_(f"Removing {sel_failures.sum()} reactions with failed template generation") 126 | print_( 127 | f"Removing {sel_no_reprod.sum()} reactions with template unable to reproduce correct reactants" 128 | ) 129 | 130 | 131 | # %% 132 | print_( 133 | f"Number of unique templates before occurence filtering = {len(set(data['TemplateHash']))}" 134 | ) 135 | print_(f"Removing all templates with occurrence less than {min_occurrence}") 136 | template_group = data.groupby("TemplateHash") 137 | template_group = template_group.size().sort_values(ascending=False) 138 | min_index = template_group[template_group >= int(min_occurrence)].index 139 | data = data[data["TemplateHash"].isin(min_index)] 140 | 141 | # %% [markdown] 142 | """ 143 | ## Finalizing 144 | """ 145 | 146 | # %% 147 | data = data[ 148 | [ 149 | "RxnSmilesClean", 150 | "PseudoHash", 151 | "classification", 152 | "RingBreaker", 153 | "RetroTemplate", 154 | "TemplateHash", 155 | ] 156 | ] 157 | data.rename( 158 | columns={ 159 | "RxnSmilesClean": col_config.reaction_smiles, 160 | "PseudoHash": col_config.reaction_hash, 161 | "classification": col_config.classification, 162 | "RingBreaker": col_config.ring_breaker, 163 | "RetroTemplate": col_config.retro_template, 164 | "TemplateHash": col_config.template_hash, 165 | }, 166 | inplace=True, 167 | ) 168 | # %% 169 | filename = prefix_filename(output_prefix, output_postfix) 170 | print_(f"Total number of selected templates = {len(data)}") 171 | print_(f"Number of unique templates = {len(set(data[col_config.template_hash]))}") 172 | print_(f"Saving selected templates to {filename}") 173 | data.to_csv(filename, index=False, sep="\t") 174 | 175 | # %% 176 | data = data[data[col_config.ring_breaker]] 177 | filename = prefix_filename(output_prefix, "ringbreaker_" + output_postfix) 178 | print_(f"Total number of selected ring breaker templates = {len(data)}") 179 | print_( 180 | f"Number of unique ring breaker templates = {len(set(data[col_config.template_hash]))}" 181 | ) 182 | print_(f"Saving selected ring breaker templates to {filename}") 183 | data.to_csv(filename, index=False, sep="\t") 184 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/stereo_template_pipeline.py: -------------------------------------------------------------------------------- 1 | """Module containing preparing a template library""" 2 | from pathlib import Path 3 | 4 | from metaflow import FlowSpec, Parameter, step 5 | from rxnutils.pipeline.runner import main as validation_runner 6 | from rxnutils.data.batch_utils import create_csv_batches, combine_csv_batches 7 | 8 | from aizynthtrain.utils.configs import ( 9 | load_config, 10 | TemplatePipelineConfig, 11 | ) 12 | from aizynthtrain.utils.reporting import main as selection_runner 13 | from aizynthtrain.utils.template_runner import main as template_runner 14 | from aizynthtrain.utils.stereo_flipper import main as stereo_flipper 15 | from aizynthtrain.utils.files import prefix_filename 16 | 17 | 18 | class StereoTemplatesExtractionFlow(FlowSpec): 19 | config_path = Parameter("config", required=True) 20 | 21 | @step 22 | def start(self): 23 | """Loading configuration""" 24 | self.config: TemplatePipelineConfig = load_config( 25 | self.config_path, "template_pipeline" 26 | ) 27 | self.next(self.stereo_checks_setup) 28 | 29 | @step 30 | def stereo_checks_setup(self): 31 | """Preparing splits of the reaction 
data for stereo checks""" 32 | self.reaction_partitions = self._create_batches( 33 | self.config.selected_reactions_path.replace(".csv", "_all.csv"), 34 | self.config.stereo_reactions_path, 35 | ) 36 | self.next(self.stereo_checks, foreach="reaction_partitions") 37 | 38 | @step 39 | def stereo_checks(self): 40 | """Validating reaction data""" 41 | pipeline_path = str( 42 | Path(__file__).parent / "data" / "stereo_checks_pipeline.yaml" 43 | ) 44 | idx, start, end = self.input 45 | batch_output = f"{self.config.stereo_reactions_path}.{idx}" 46 | if idx > -1 and not Path(batch_output).exists(): 47 | validation_runner( 48 | [ 49 | "--pipeline", 50 | pipeline_path, 51 | "--data", 52 | self.config.selected_reactions_path.replace(".csv", "_all.csv"), 53 | "--output", 54 | batch_output, 55 | "--max-workers", 56 | "1", 57 | "--batch", 58 | str(start), 59 | str(end), 60 | "--no-intermediates", 61 | ] 62 | ) 63 | else: 64 | print( 65 | f"Skipping template extraction for idx {idx}. File exists = {Path(batch_output).exists()}" 66 | ) 67 | self.next(self.stereo_checks_join) 68 | 69 | @step 70 | def stereo_checks_join(self, inputs): 71 | """Joining stereo reactions""" 72 | self.config = inputs[0].config 73 | self._combine_batches(self.config.stereo_reactions_path) 74 | self.next(self.stereo_selection) 75 | 76 | @step 77 | def stereo_selection(self): 78 | """Selecting stereo reactions and producing report""" 79 | notebook_path = str(Path(__file__).parent / "notebooks" / "stereo_selection.py") 80 | if not Path(self.config.selected_stereo_reactions_path).exists(): 81 | selection_runner( 82 | [ 83 | "--notebook", 84 | notebook_path, 85 | "--report_path", 86 | self.config.stereo_report_path, 87 | "--python_kernel", 88 | self.config.python_kernel, 89 | "--input_filename", 90 | self.config.stereo_reactions_path, 91 | "--output_filename", 92 | self.config.selected_stereo_reactions_path, 93 | ] 94 | ) 95 | self.next(self.template_extraction_setup) 96 | 97 | @step 98 | def template_extraction_setup(self): 99 | """Preparing splits of the reaction data for template extraction""" 100 | self.reaction_partitions = self._create_batches( 101 | self.config.selected_stereo_reactions_path, 102 | self.config.unvalidated_templates_path, 103 | ) 104 | self.next(self.template_extraction, foreach="reaction_partitions") 105 | 106 | @step 107 | def template_extraction(self): 108 | """Extracting RDChiral reaction templates""" 109 | idx, start, end = self.input 110 | batch_output = f"{self.config.unvalidated_templates_path}.{idx}" 111 | if idx > -1 and not Path(batch_output).exists(): 112 | template_runner( 113 | [ 114 | "--input_path", 115 | self.config.selected_stereo_reactions_path, 116 | "--output_path", 117 | batch_output, 118 | "--radius", 119 | "1", 120 | "--smiles_column", 121 | "RxnSmilesClean", 122 | "--ringbreaker_column", 123 | "RingBreaker", 124 | "--batch", 125 | str(start), 126 | str(end), 127 | ] 128 | ) 129 | else: 130 | print( 131 | f"Skipping template extraction for idx {idx}. 
File exists = {Path(batch_output).exists()}" 132 | ) 133 | self.next(self.template_extraction_join) 134 | 135 | @step 136 | def template_extraction_join(self, inputs): 137 | """Joining extracted templates""" 138 | self.config = inputs[0].config 139 | self._combine_batches(self.config.unvalidated_templates_path) 140 | self.next(self.template_validation_setup) 141 | 142 | @step 143 | def template_validation_setup(self): 144 | """Preparing splits of the reaction data for template validation""" 145 | self.reaction_partitions = self._create_batches( 146 | self.config.unvalidated_templates_path, self.config.validated_templates_path 147 | ) 148 | self.next(self.template_validation, foreach="reaction_partitions") 149 | 150 | @step 151 | def template_validation(self): 152 | """Validating extracted templates""" 153 | pipline_path = str( 154 | Path(__file__).parent / "data" / "template_validation_pipeline.yaml" 155 | ) 156 | idx, start, end = self.input 157 | if idx > -1: 158 | validation_runner( 159 | [ 160 | "--pipeline", 161 | pipline_path, 162 | "--data", 163 | self.config.unvalidated_templates_path, 164 | "--output", 165 | f"{self.config.validated_templates_path}.{idx}", 166 | "--max-workers", 167 | "1", 168 | "--batch", 169 | str(start), 170 | str(end), 171 | "--no-intermediates", 172 | ] 173 | ) 174 | stereo_flipper( 175 | [ 176 | "--input_path", 177 | f"{self.config.validated_templates_path}.{idx}", 178 | "--query", 179 | "StereoBucket=='reagent controlled'", 180 | ] 181 | ) 182 | self.next(self.template_validation_join) 183 | 184 | @step 185 | def template_validation_join(self, inputs): 186 | """Joining validated templates""" 187 | self.config = inputs[0].config 188 | self._combine_batches(self.config.validated_templates_path) 189 | self.next(self.template_selection) 190 | 191 | @step 192 | def template_selection(self): 193 | """Selection templates and produce report""" 194 | notebook_path = str( 195 | Path(__file__).parent / "notebooks" / "stereo_template_selection.py" 196 | ) 197 | output_path = prefix_filename( 198 | self.config.selected_templates_prefix, 199 | self.config.selected_templates_postfix, 200 | ) 201 | if not Path(output_path).exists(): 202 | selection_runner( 203 | [ 204 | "--notebook", 205 | notebook_path, 206 | "--report_path", 207 | self.config.templates_report_path, 208 | "--python_kernel", 209 | self.config.python_kernel, 210 | "--input_filename", 211 | self.config.validated_templates_path, 212 | "--output_prefix", 213 | self.config.selected_templates_prefix, 214 | "--output_postfix", 215 | self.config.selected_templates_postfix, 216 | "--min_occurrence", 217 | str(self.config.min_template_occurrence), 218 | "--config_filename", 219 | self.config_path, 220 | ] 221 | ) 222 | self.next(self.end) 223 | 224 | @step 225 | def end(self): 226 | print( 227 | f"Report on extracted reaction is located here: {self.config.stereo_report_path}" 228 | ) 229 | print( 230 | f"Report on extracted templates is located here: {self.config.templates_report_path}" 231 | ) 232 | 233 | def _combine_batches(self, filename): 234 | combine_csv_batches(filename, self.config.nbatches) 235 | 236 | def _create_batches(self, input_filename, output_filename): 237 | return create_csv_batches(input_filename, self.config.nbatches, output_filename) 238 | 239 | 240 | if __name__ == "__main__": 241 | StereoTemplatesExtractionFlow() 242 | -------------------------------------------------------------------------------- /aizynthtrain/pipelines/template_pipeline.py: 
-------------------------------------------------------------------------------- 1 | """Module containing preparing a template library""" 2 | from pathlib import Path 3 | import importlib 4 | 5 | from metaflow import FlowSpec, Parameter, step 6 | from rxnutils.pipeline.runner import main as validation_runner 7 | 8 | from aizynthtrain.utils.configs import ( 9 | load_config, 10 | TemplatePipelineConfig, 11 | ) 12 | from aizynthtrain.utils.reporting import main as selection_runner 13 | from aizynthtrain.utils.template_runner import main as template_runner 14 | from aizynthtrain.utils.files import prefix_filename 15 | from rxnutils.data.batch_utils import create_csv_batches, combine_csv_batches 16 | 17 | 18 | class TemplatesExtractionFlow(FlowSpec): 19 | config_path = Parameter("config", required=True) 20 | 21 | @step 22 | def start(self): 23 | """Loading configuration""" 24 | self.config: TemplatePipelineConfig = load_config( 25 | self.config_path, "template_pipeline" 26 | ) 27 | self.next(self.import_data) 28 | 29 | @step 30 | def import_data(self): 31 | """Importing and transforming reaction data""" 32 | if not Path(self.config.import_data_path).exists(): 33 | module_name, cls_name = self.config.data_import_class.rsplit( 34 | ".", maxsplit=1 35 | ) 36 | loaded_module = importlib.import_module(module_name) 37 | importer = getattr(loaded_module, cls_name)( 38 | **self.config.data_import_config 39 | ) 40 | importer.data.to_csv(self.config.import_data_path, sep="\t", index=False) 41 | self.next(self.reaction_validation_setup) 42 | 43 | @step 44 | def reaction_validation_setup(self): 45 | """Preparing splits of the reaction data for reaction validation""" 46 | self.reaction_partitions = self._create_batches( 47 | self.config.import_data_path, self.config.validated_reactions_path 48 | ) 49 | self.next(self.reaction_validation, foreach="reaction_partitions") 50 | 51 | @step 52 | def reaction_validation(self): 53 | """Validating reaction data""" 54 | pipeline_path = str( 55 | Path(__file__).parent / "data" / "reaction_validation_pipeline.yaml" 56 | ) 57 | idx, start, end = self.input 58 | if idx > -1: 59 | validation_runner( 60 | [ 61 | "--pipeline", 62 | pipeline_path, 63 | "--data", 64 | self.config.import_data_path, 65 | "--output", 66 | f"{self.config.validated_reactions_path}.{idx}", 67 | "--max-workers", 68 | "1", 69 | "--batch", 70 | str(start), 71 | str(end), 72 | "--no-intermediates", 73 | ] 74 | ) 75 | self.next(self.reaction_validation_join) 76 | 77 | @step 78 | def reaction_validation_join(self, inputs): 79 | """Joining validated reactions""" 80 | self.config = inputs[0].config 81 | self._combine_batches(self.config.validated_reactions_path) 82 | self.next(self.reaction_selection) 83 | 84 | @step 85 | def reaction_selection(self): 86 | """Selecting reactions and producing report""" 87 | notebook_path = str( 88 | Path(__file__).parent / "notebooks" / "reaction_selection.py" 89 | ) 90 | if not Path(self.config.selected_reactions_path).exists(): 91 | selection_runner( 92 | [ 93 | "--notebook", 94 | notebook_path, 95 | "--report_path", 96 | self.config.reaction_report_path, 97 | "--python_kernel", 98 | self.config.python_kernel, 99 | "--input_filename", 100 | self.config.validated_reactions_path, 101 | "--output_filename", 102 | self.config.selected_reactions_path, 103 | ] 104 | ) 105 | self.next(self.template_extraction_setup) 106 | 107 | @step 108 | def template_extraction_setup(self): 109 | """Preparing splits of the reaction data for template extraction""" 110 | self.reaction_partitions 
= self._create_batches( 111 | self.config.selected_reactions_path, self.config.unvalidated_templates_path 112 | ) 113 | self.next(self.template_extraction, foreach="reaction_partitions") 114 | 115 | @step 116 | def template_extraction(self): 117 | """Extracting RDChiral reaction templates""" 118 | idx, start, end = self.input 119 | batch_output = f"{self.config.unvalidated_templates_path}.{idx}" 120 | if idx > -1 and not Path(batch_output).exists(): 121 | template_runner( 122 | [ 123 | "--input_path", 124 | self.config.selected_reactions_path, 125 | "--output_path", 126 | batch_output, 127 | "--radius", 128 | "1", 129 | "--smiles_column", 130 | "RxnSmilesClean", 131 | "--ringbreaker_column", 132 | "RingBreaker", 133 | "--batch", 134 | str(start), 135 | str(end), 136 | ] 137 | ) 138 | else: 139 | print( 140 | f"Skipping template extraction for idx {idx}. File exists = {Path(batch_output).exists()}" 141 | ) 142 | self.next(self.template_extraction_join) 143 | 144 | @step 145 | def template_extraction_join(self, inputs): 146 | """Joining extracted templates""" 147 | self.config = inputs[0].config 148 | self._combine_batches(self.config.unvalidated_templates_path) 149 | self.next(self.template_validation_setup) 150 | 151 | @step 152 | def template_validation_setup(self): 153 | """Preparing splits of the reaction data for template validation""" 154 | self.reaction_partitions = self._create_batches( 155 | self.config.unvalidated_templates_path, self.config.validated_templates_path 156 | ) 157 | self.next(self.template_validation, foreach="reaction_partitions") 158 | 159 | @step 160 | def template_validation(self): 161 | """Validating extracted templates""" 162 | pipline_path = str( 163 | Path(__file__).parent / "data" / "template_validation_pipeline.yaml" 164 | ) 165 | idx, start, end = self.input 166 | if idx > -1: 167 | validation_runner( 168 | [ 169 | "--pipeline", 170 | pipline_path, 171 | "--data", 172 | self.config.unvalidated_templates_path, 173 | "--output", 174 | f"{self.config.validated_templates_path}.{idx}", 175 | "--max-workers", 176 | "1", 177 | "--batch", 178 | str(start), 179 | str(end), 180 | "--no-intermediates", 181 | ] 182 | ) 183 | self.next(self.template_validation_join) 184 | 185 | @step 186 | def template_validation_join(self, inputs): 187 | """Joining validated templates""" 188 | self.config = inputs[0].config 189 | self._combine_batches(self.config.validated_templates_path) 190 | self.next(self.template_selection) 191 | 192 | @step 193 | def template_selection(self): 194 | """Selection templates and produce report""" 195 | notebook_path = str( 196 | Path(__file__).parent / "notebooks" / "template_selection.py" 197 | ) 198 | output_path = prefix_filename( 199 | self.config.selected_templates_prefix, 200 | self.config.selected_templates_postfix, 201 | ) 202 | if not Path(output_path).exists(): 203 | selection_runner( 204 | [ 205 | "--notebook", 206 | notebook_path, 207 | "--report_path", 208 | self.config.templates_report_path, 209 | "--python_kernel", 210 | self.config.python_kernel, 211 | "--input_filename", 212 | self.config.validated_templates_path, 213 | "--output_prefix", 214 | self.config.selected_templates_prefix, 215 | "--output_postfix", 216 | self.config.selected_templates_postfix, 217 | "--min_occurrence", 218 | str(self.config.min_template_occurrence), 219 | "--config_filename", 220 | self.config_path, 221 | ] 222 | ) 223 | self.next(self.end) 224 | 225 | @step 226 | def end(self): 227 | print( 228 | f"Report on extracted reaction is located here: 
{self.config.reaction_report_path}" 229 | ) 230 | print( 231 | f"Report on extracted templates is located here: {self.config.templates_report_path}" 232 | ) 233 | 234 | def _combine_batches(self, filename): 235 | combine_csv_batches(filename, self.config.nbatches) 236 | 237 | def _create_batches(self, input_filename, output_filename): 238 | return create_csv_batches(input_filename, self.config.nbatches, output_filename) 239 | 240 | 241 | if __name__ == "__main__": 242 | TemplatesExtractionFlow() 243 | -------------------------------------------------------------------------------- /aizynthtrain/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/aizynthtrain/utils/__init__.py -------------------------------------------------------------------------------- /aizynthtrain/utils/configs.py: -------------------------------------------------------------------------------- 1 | """Module containing configuration classes""" 2 | import sys 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | import yaml 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | def load_config(filename: str, class_key: str) -> Any: 10 | """ 11 | Loading any configuration class from a YAML file 12 | 13 | The `class_key` should be a snake-case version of class name 14 | in this module, excluding the "Config" string. 15 | For example: template_pipeline -> TemplatePipelineConfig 16 | 17 | :param filename: the path to the YAML file 18 | :param class_key: the name of the configuration at the top-level 19 | """ 20 | with open(filename, "r") as fileobj: 21 | dict_ = yaml.load(fileobj.read(), Loader=yaml.SafeLoader) 22 | 23 | module = sys.modules[__name__] 24 | class_name = "".join(part.title() for part in class_key.split("_")) + "Config" 25 | if not hasattr(module, class_name): 26 | raise KeyError(f"No config class matches {class_key}") 27 | class_ = getattr(module, class_name) 28 | try: 29 | config = class_(**dict_.get(class_key, {})) 30 | except TypeError as err: 31 | raise ValueError(f"Unable to setup config class with message: {err}") 32 | 33 | return config 34 | 35 | 36 | class _FilenameMixin: 37 | """Mixin-class for providing easy access to filenames""" 38 | 39 | def filename(self, name: str, subset: str = "") -> str: 40 | name_value = getattr(self.postfixes, name) 41 | if subset: 42 | name_value = subset + "_" + name_value 43 | if self.file_prefix: 44 | return self.file_prefix + "_" + name_value 45 | return name_value 46 | 47 | 48 | class TemplatePipelineConfig(BaseModel): 49 | """Configuration class for template pipeline""" 50 | 51 | python_kernel: str 52 | data_import_class: str 53 | data_import_config: Dict[str, str] = Field(default_factory=dict) 54 | selected_templates_prefix: str = "" 55 | selected_templates_postfix: str = "template_library.csv" 56 | import_data_path: str = "imported_reactions.csv" 57 | validated_reactions_path: str = "validated_reactions.csv" 58 | selected_reactions_path: str = "selected_reactions.csv" 59 | reaction_report_path: str = "reaction_selection_report.html" 60 | stereo_reactions_path: str = "stereo_reactions.csv" 61 | selected_stereo_reactions_path: str = "selected_stereo_reactions.csv" 62 | stereo_report_path: str = "stereo_selection_report.html" 63 | unvalidated_templates_path: str = "reaction_templates_unvalidated.csv" 64 | validated_templates_path: str = "reaction_templates_validated.csv" 65 | templates_report_path: str = 
"template_selection_report.html" 66 | min_template_occurrence: int = 10 67 | nbatches: int = 200 68 | 69 | 70 | # Config classes for the chemformer model pipelines 71 | 72 | 73 | class ChemformerTrainConfig(BaseModel): 74 | """Configuration class for chemformer training""" 75 | 76 | data_path: str = "proc_selected_reactions.csv" 77 | vocabulary_path: str = "bart_vocab_disconnection_aware.json" 78 | task: str = "backward_prediction" 79 | model_path: Optional[str] = None 80 | deepspeed_config_path: str = "ds_config.json" 81 | output_directory: str = "" 82 | 83 | n_epochs: int = 70 84 | resume: bool = False 85 | n_gpus: int = 1 86 | batch_size: int = 64 87 | augmentation_probability: float = 0.5 88 | learning_rate: float = 0.001 89 | 90 | check_val_every_n_epoch: int = 1 91 | limit_val_batches: float = 1.0 92 | n_nodes: int = 1 93 | acc_batches: int = 4 94 | schedule: str = "cycle" 95 | 96 | # Arguments for model from random initialization 97 | d_model: int = 512 98 | n_layers: int = 6 99 | n_heads: int = 8 100 | d_feedforward: int = 2048 101 | 102 | 103 | class ChemformerDataPrepConfig(BaseModel): 104 | """Configuration class for chemformer data-prep pipeline""" 105 | 106 | selected_reactions_path: str = "selected_reactions.csv" 107 | reaction_components_path: str = "reaction_components.csv" 108 | routes_to_exclude: List[str] = Field(default_factory=list) 109 | training_fraction: float = 0.9 110 | random_seed: int = 11 111 | nbatches: int = 200 112 | chemformer_data_path: str = "proc_selected_reactions.csv" 113 | disconnection_aware_data_path: str = "proc_selected_reactions_disconnection.csv" 114 | autotag_data_path: str = "proc_selected_reactions_autotag.csv" 115 | 116 | reaction_hash_col: str = "reaction_hash" 117 | set_col: str = "set" 118 | is_external_col: str = "is_external" 119 | 120 | 121 | # Config classes for the expansion model pipeline 122 | 123 | 124 | class TemplateLibraryColumnsConfig(BaseModel): 125 | """Configuration class for columns in a template library file""" 126 | 127 | reaction_smiles: str = "reaction_smiles" 128 | reaction_hash: str = "reaction_hash" 129 | retro_template: str = "retro_template" 130 | template_hash: str = "template_hash" 131 | template_code: str = "template_code" 132 | library_occurrence: str = "library_occurence" 133 | classification: Optional[str] = "classification" 134 | ring_breaker: Optional[str] = "ring_breaker" 135 | stereo_bucket: Optional[str] = "stereo_bucket" 136 | flipped_stereo: Optional[str] = "flipped_stereo" 137 | 138 | 139 | class TemplateLibraryConfig(BaseModel): 140 | """Configuration class for template library generation""" 141 | 142 | columns: TemplateLibraryColumnsConfig = Field( 143 | default_factory=TemplateLibraryColumnsConfig 144 | ) 145 | metadata_columns: List[str] = Field( 146 | default_factory=lambda: ["template_hash", "classification"] 147 | ) 148 | template_set: str = "templates" 149 | 150 | 151 | class ExpansionModelPipelinePostfixes(BaseModel): 152 | """Configuration class for postfixes of files generated by expansion model pipeline""" 153 | 154 | library: str = "template_library.csv" 155 | template_code: str = "template_code.csv" 156 | unique_templates: str = "unique_templates.csv.gz" 157 | template_lookup: str = "lookup.json" 158 | split_indices: str = "split_indices.npz" 159 | model_labels: str = "labels.npz" 160 | model_inputs: str = "inputs.npz" 161 | training_log: str = "keras_training.csv" 162 | training_checkpoint: str = "keras_model.hdf5" 163 | onnx_model: str = "expansion.onnx" 164 | 165 | report: str = 
"expansion_model_report.html" 166 | finder_output: str = "model_validation_finder_output.hdf5" 167 | multistep_report: str = "model_validation_multistep_report.json" 168 | expander_output: str = "model_validation_expander_output.json" 169 | onestep_report: str = "model_validation_onestep_report.json" 170 | 171 | 172 | class ExpansionModelHyperparamsConfig(BaseModel): 173 | """Configuration class for hyper parameters for expansion model training""" 174 | 175 | fingerprint_radius: int = 2 176 | fingerprint_length: int = 2048 177 | batch_size: int = 256 178 | hidden_nodes: int = 512 179 | dropout: int = 0.4 180 | epochs: int = 100 181 | chirality: bool = False 182 | 183 | 184 | class ExpansionModelPipelineConfig(BaseModel, _FilenameMixin): 185 | """Configuration class for expansion model pipeline""" 186 | 187 | python_kernel: str 188 | file_prefix: str = "" 189 | nbatches: int = 200 190 | training_fraction: float = 0.9 191 | random_seed: int = 1689 192 | selected_ids_path: str = "selected_reactions_ids.json" 193 | routes_to_exclude: List[str] = Field(default_factory=list) 194 | library_config: TemplateLibraryConfig = Field(default_factory=TemplateLibraryConfig) 195 | postfixes: ExpansionModelPipelinePostfixes = Field( 196 | default_factory=ExpansionModelPipelinePostfixes 197 | ) 198 | model_hyperparams: ExpansionModelHyperparamsConfig = Field( 199 | default_factory=ExpansionModelHyperparamsConfig 200 | ) 201 | 202 | 203 | class ExpansionModelEvaluationConfig(BaseModel, _FilenameMixin): 204 | """Configuration class for evaluation of expansion model""" 205 | 206 | stock_for_finding: str = "" 207 | stock_for_recovery: str = "" 208 | search_properties_for_finding: Dict[str, Any] = Field( 209 | default_factory=lambda: {"return_first": True} 210 | ) 211 | search_properties_for_recovery: Dict[str, Any] = Field( 212 | default_factory=lambda: { 213 | "max_transforms": 10, 214 | "iteration_limit": 500, 215 | "time_limit": 3600, 216 | } 217 | ) 218 | reference_routes: str = "" 219 | target_smiles: str = "" 220 | file_prefix: str = "" 221 | top_n: int = 50 222 | n_test_reactions: int = 1000 223 | columns: TemplateLibraryColumnsConfig = Field( 224 | default_factory=TemplateLibraryColumnsConfig 225 | ) 226 | postfixes: ExpansionModelPipelinePostfixes = Field( 227 | default_factory=ExpansionModelPipelinePostfixes 228 | ) 229 | -------------------------------------------------------------------------------- /aizynthtrain/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """Module containing utility routines for data manipulation""" 2 | import json 3 | import random 4 | from typing import List, Set, Dict, Any, Optional 5 | 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | 9 | 10 | def extract_route_reactions( 11 | filenames: List[str], metadata_key="reaction_hash" 12 | ) -> Set[str]: 13 | """ 14 | Extract reaction hashes from given routes 15 | 16 | :param filenames: the paths of the routes 17 | :param metadata_key: the key in the metadata with the reaction hash 18 | :return: the extracted hashes 19 | """ 20 | 21 | def traverse_tree(tree_dict: Dict[str, Any], hashes: Set[str]) -> None: 22 | if tree_dict["type"] == "reaction": 23 | hashes.add(tree_dict["metadata"][metadata_key]) 24 | for child in tree_dict.get("children", []): 25 | traverse_tree(child, hashes) 26 | 27 | reaction_hashes = set() 28 | for filename in filenames: 29 | with open(filename, "r") as fileobj: 30 | routes = json.load(fileobj) 31 | for route in 
routes: 32 | traverse_tree(route, reaction_hashes) 33 | 34 | return reaction_hashes 35 | 36 | 37 | def split_data( 38 | group_data: pd.DataFrame, 39 | train_frac: float, 40 | random_seed: int, 41 | train_indices: List[int], 42 | val_indices: List[int], 43 | test_indices: Optional[List[int]] = None, 44 | ) -> None: 45 | """ 46 | Split a dataframe into training, validation and optionally test partitions 47 | 48 | The indices of the rows selected for the different partitions are added to the 49 | given lists. 50 | 51 | First split the data into a training and validation set with `train_frac` specifying 52 | the portion of the data to end up in the training partition. 53 | 54 | Second, if `test_indices` is provided take a random half of the validation as test 55 | partition. If only one row is in the validation partition after the first step, randomly 56 | assign it to test with a 50% probability. 57 | 58 | :params group_data: the data to split 59 | :params train_frac: the fractition of the data to go into the training partition 60 | :params random_seed: the seed of the random generator 61 | :params train_indices: the list of add the indices of the training partition 62 | :params val_indices: the list of add the indices of the validation partition 63 | :params test_indices: the list of add the indices of the testing partition 64 | """ 65 | if len(group_data) > 1: 66 | train_arr, val_arr = train_test_split( 67 | group_data.index, 68 | train_size=train_frac, 69 | random_state=random_seed, 70 | shuffle=True, 71 | ) 72 | else: 73 | train_arr = list(group_data.index) 74 | val_arr = [] 75 | 76 | train_indices.extend(train_arr) 77 | if test_indices is None: 78 | val_indices.extend(val_arr) 79 | return 80 | 81 | if len(val_arr) > 1: 82 | val_arr, train_arr = train_test_split( 83 | val_arr, test_size=0.5, random_state=random_seed, shuffle=True 84 | ) 85 | test_indices.extend(train_arr) 86 | val_indices.extend(val_arr) 87 | elif random.random() < 0.5: 88 | test_indices.extend(val_arr) 89 | else: 90 | val_indices.extend(val_arr) 91 | -------------------------------------------------------------------------------- /aizynthtrain/utils/files.py: -------------------------------------------------------------------------------- 1 | """Module containing various file utilities""" 2 | 3 | 4 | def prefix_filename(prefix: str, postfix: str) -> str: 5 | """ 6 | Construct pre- and post-fixed filename 7 | 8 | :param prefix: the prefix, can be empty 9 | :param postfix: the postfix 10 | :return: the concatenated string 11 | """ 12 | if prefix: 13 | return prefix + "_" + postfix 14 | return postfix 15 | -------------------------------------------------------------------------------- /aizynthtrain/utils/keras_utils.py: -------------------------------------------------------------------------------- 1 | """Module containing utility classes and routines used in training of policies""" 2 | import functools 3 | from typing import List, Any, Tuple 4 | 5 | import numpy as np 6 | 7 | from tensorflow.config import list_physical_devices 8 | from tensorflow.keras.models import Model 9 | from tensorflow.keras.optimizers import Adam 10 | from tensorflow.keras.utils import Sequence 11 | from tensorflow.keras.callbacks import ( 12 | EarlyStopping, 13 | CSVLogger, 14 | ModelCheckpoint, 15 | ReduceLROnPlateau, 16 | ) 17 | from tensorflow.keras.metrics import top_k_categorical_accuracy 18 | from scipy import sparse 19 | 20 | top10_acc = functools.partial(top_k_categorical_accuracy, k=10) 21 | top10_acc.__name__ = "top10_acc" # type: 
ignore 22 | 23 | top50_acc = functools.partial(top_k_categorical_accuracy, k=50) 24 | top50_acc.__name__ = "top50_acc" # type: ignore 25 | 26 | 27 | class InMemorySequence(Sequence): # pylint: disable=W0223 28 | """ 29 | Class for in-memory data management 30 | 31 | :param input_filname: the path to the model input data 32 | :param output_filename: the path to the model output data 33 | :param batch_size: the size of the batches 34 | """ 35 | 36 | def __init__( 37 | self, input_filename: str, output_filename: str, batch_size: int 38 | ) -> None: 39 | self.batch_size = batch_size 40 | self.input_matrix = self._load_data(input_filename) 41 | self.label_matrix = self._load_data(output_filename) 42 | self.input_dim = self.input_matrix.shape[1] 43 | 44 | def __len__(self) -> int: 45 | return int(np.ceil(self.label_matrix.shape[0] / float(self.batch_size))) 46 | 47 | def _make_slice(self, idx: int) -> slice: 48 | if idx < 0 or idx >= len(self): 49 | raise IndexError("index out of range") 50 | 51 | start = idx * self.batch_size 52 | end = (idx + 1) * self.batch_size 53 | return slice(start, end) 54 | 55 | @staticmethod 56 | def _load_data(filename: str) -> np.ndarray: 57 | try: 58 | return sparse.load_npz(filename) 59 | except ValueError: 60 | return np.load(filename)["arr_0"] 61 | 62 | 63 | def setup_callbacks( 64 | log_filename: str, checkpoint_filename: str 65 | ) -> Tuple[EarlyStopping, CSVLogger, ModelCheckpoint, ReduceLROnPlateau]: 66 | """ 67 | Setup Keras callback functions: early stopping, CSV logger, model checkpointing, 68 | and reduce LR on plateau 69 | 70 | :param log_filename: the filename of the CSV log 71 | :param checkpoint_filename: the filename of the checkpoint 72 | :return: all the callbacks 73 | """ 74 | early_stopping = EarlyStopping(monitor="val_loss", patience=10) 75 | csv_logger = CSVLogger(log_filename) 76 | checkpoint = ModelCheckpoint( 77 | checkpoint_filename, 78 | monitor="loss", 79 | save_best_only=True, 80 | ) 81 | 82 | reduce_lr = ReduceLROnPlateau( 83 | monitor="val_loss", 84 | factor=0.5, 85 | patience=5, 86 | verbose=0, 87 | mode="auto", 88 | min_delta=0.000001, 89 | cooldown=0, 90 | min_lr=0, 91 | ) 92 | return [early_stopping, csv_logger, checkpoint, reduce_lr] 93 | 94 | 95 | def train_keras_model( 96 | model: Model, 97 | train_seq: InMemorySequence, 98 | valid_seq: InMemorySequence, 99 | loss: str, 100 | metrics: List[Any], 101 | callbacks: List[Any], 102 | epochs: int, 103 | ) -> None: 104 | """ 105 | Train a Keras model, but first compiling it and then fitting 106 | it to the given data 107 | 108 | :param model: the initialized model 109 | :param train_seq: the training data 110 | :param valid_seq: the validation data 111 | :param loss: the loss function 112 | :param metrics: the metric functions 113 | :param callbacks: the callback functions 114 | :param epochs: the number of epochs to use 115 | """ 116 | print(f"Available GPUs: {list_physical_devices('GPU')}") 117 | adam = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999) 118 | 119 | model.compile( 120 | optimizer=adam, 121 | loss=loss, 122 | metrics=metrics, 123 | ) 124 | 125 | model.fit( 126 | train_seq, 127 | epochs=epochs, 128 | verbose=1, 129 | callbacks=callbacks, 130 | validation_data=valid_seq, 131 | max_queue_size=20, 132 | workers=1, 133 | use_multiprocessing=False, 134 | ) 135 | -------------------------------------------------------------------------------- /aizynthtrain/utils/onnx_converter.py: -------------------------------------------------------------------------------- 1 | """Module 
to convert Keras models to ONNX.""" 2 | 3 | import argparse 4 | from typing import Optional, Sequence 5 | 6 | import tensorflow as tf 7 | import tf2onnx 8 | from tensorflow.keras.models import load_model 9 | 10 | from aizynthtrain.utils import keras_utils 11 | 12 | 13 | def convert_to_onnx(keras_model: str, onnx_model: str) -> None: 14 | """Converts a Keras model to ONNX. 15 | 16 | :param keras_model: the filename of the Keras model 17 | :param onnx_model: the filename to save the ONNX model at 18 | """ 19 | custom_objects = { 20 | "top10_acc": keras_utils.top10_acc, 21 | "top50_acc": keras_utils.top50_acc, 22 | "tf": tf, 23 | } 24 | model = load_model(keras_model, custom_objects=custom_objects) 25 | 26 | tf2onnx.convert.from_keras(model, output_path=onnx_model) 27 | 28 | 29 | def main(args: Optional[Sequence[str]] = None) -> None: 30 | """Command-line interface to convert a Keras model to Onnx.""" 31 | parser = argparse.ArgumentParser("Tool to convert a Keras model to Onnx") 32 | parser.add_argument("keras_model", help="the Keras model file") 33 | parser.add_argument("onnx_model", help="the ONNX model file") 34 | args = parser.parse_args(args) 35 | 36 | convert_to_onnx(args.keras_model, args.onnx_model) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /aizynthtrain/utils/reporting.py: -------------------------------------------------------------------------------- 1 | """Module containing routines to create reports""" 2 | import tempfile 3 | import argparse 4 | from typing import Optional, Sequence, Any 5 | 6 | import nbconvert 7 | import nbformat 8 | import jupytext 9 | import papermill 10 | 11 | 12 | def create_html_report_from_notebook( 13 | notebook_path: str, html_path: str, python_kernel: str, **parameters: Any 14 | ) -> None: 15 | """ 16 | Execute a Jupyter notebook and create an HTML file from the output. 
17 | 18 | :param notebook_path: the path to the Jupyter notebook in py-percent format 19 | :param html_path: the path to the HTML output 20 | :param python_kernel: the python kernel to execute the notebook in 21 | :param parameters: additional parameters given to the notebook 22 | """ 23 | _, input_notebook = tempfile.mkstemp(suffix=".ipynb") 24 | _, output_notebook = tempfile.mkstemp(suffix=".ipynb") 25 | 26 | notebook = jupytext.read(notebook_path, fmt="py:percent") 27 | jupytext.write(notebook, input_notebook, fmt="ipynb") 28 | 29 | papermill.execute_notebook( 30 | input_notebook, 31 | output_notebook, 32 | kernel_name=python_kernel, 33 | language="python", 34 | parameters=parameters, 35 | ) 36 | 37 | with open(output_notebook, "r") as fileobj: 38 | notebook_nb = nbformat.read(fileobj, as_version=4) 39 | exporter = nbconvert.HTMLExporter() 40 | exporter.exclude_input = True 41 | notebook_html, _ = exporter.from_notebook_node(notebook_nb) 42 | with open(html_path, "w") as fileobj: 43 | fileobj.write(notebook_html) 44 | 45 | 46 | def main(args: Optional[Sequence[str]] = None) -> None: 47 | """Command-line interface for report generation""" 48 | parser = argparse.ArgumentParser("Run a Jupyter notebook to produce a report") 49 | parser.add_argument("--notebook", required=True) 50 | parser.add_argument("--report_path", required=True) 51 | parser.add_argument("--python_kernel", required=True) 52 | args, extra_args = parser.parse_known_args(args=args) 53 | 54 | keys = [arg[2:] for arg in extra_args[:-1:2]] 55 | kwargs = dict(zip(keys, extra_args[1::2])) 56 | 57 | create_html_report_from_notebook( 58 | args.notebook, args.report_path, args.python_kernel, **kwargs 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /aizynthtrain/utils/stereo_flipper.py: -------------------------------------------------------------------------------- 1 | """Module containing routines to flip stereo centers for specific reactions""" 2 | import argparse 3 | import re 4 | from typing import Optional, Sequence 5 | 6 | import pandas as pd 7 | from rxnutils.chem.template import ReactionTemplate 8 | from rxnutils.data.batch_utils import read_csv_batch 9 | 10 | # One or two @ characters in isolation 11 | STEREOCENTER_REGEX = r"[^@]([@]{1,2})[^@]" 12 | 13 | 14 | def _count_chiral_centres(row: pd.Series) -> int: 15 | prod_template = row.split(">>")[0] 16 | return len(re.findall(STEREOCENTER_REGEX, prod_template)) 17 | 18 | 19 | def _flip_chirality(row: pd.Series) -> pd.Series: 20 | """ 21 | Change @@ to @ and vice versa in a retrosynthesis template 22 | and then create a new template and a hash for that template 23 | """ 24 | dict_ = row.to_dict() 25 | prod_template = row["RetroTemplate"].split(">>")[0] 26 | nats = len(re.search(STEREOCENTER_REGEX, prod_template)[1]) 27 | assert nats in [1, 2] 28 | if nats == 1: 29 | dict_["RetroTemplate"] = row["RetroTemplate"].replace("@", "@@") 30 | else: 31 | dict_["RetroTemplate"] = row["RetroTemplate"].replace("@@", "@") 32 | dict_["TemplateHash"] = ReactionTemplate( 33 | dict_["RetroTemplate"], direction="retro" 34 | ).hash_from_bits() 35 | return pd.Series(dict_) 36 | 37 | 38 | def flip_stereo(data: pd.DataFrame, selection_query: str) -> pd.DataFrame: 39 | """ 40 | Find templates with one stereo center and flip it thereby creating 41 | new templates. These templates are appended onto the existing dataframe. 
42 | 43 | A column "FlippedStereo" will be added to indicate if a stereocenter 44 | was flipped. 45 | 46 | :param data: the template library 47 | :param selection_query: only flip the stereo for a subset of rows 48 | :returns: the concatenated dataframe. 49 | """ 50 | sel_data = data.query(selection_query) 51 | sel_data = sel_data[~sel_data["RetroTemplate"].isna()] 52 | 53 | chiral_centers_count = sel_data["RetroTemplate"].apply(_count_chiral_centres) 54 | sel_flipping = chiral_centers_count == 1 55 | flipped_data = sel_data[sel_flipping].apply(_flip_chirality, axis=1) 56 | 57 | existing_hashes = set(sel_data["TemplateHash"]) 58 | keep_flipped = flipped_data["TemplateHash"].apply( 59 | lambda hash_: hash_ not in existing_hashes 60 | ) 61 | flipped_data = flipped_data[keep_flipped] 62 | 63 | all_data = pd.concat([data, flipped_data]) 64 | flag_column = [False] * len(data) + [True] * len(flipped_data) 65 | return all_data.assign(FlippedStereo=flag_column) 66 | 67 | 68 | def main(args: Optional[Sequence[str]] = None) -> None: 69 | """Command-line interface for stereo flipper""" 70 | parser = argparse.ArgumentParser("Generate flipped stereo centers") 71 | parser.add_argument("--input_path", required=True) 72 | parser.add_argument("--query", required=True) 73 | parser.add_argument("--batch", type=int, nargs=2) 74 | args = parser.parse_args(args=args) 75 | 76 | data = read_csv_batch(args.input_path, sep="\t", index_col=False, batch=args.batch) 77 | data = flip_stereo( 78 | data, 79 | args.query, 80 | ) 81 | data.to_csv(args.input_path, index=False, sep="\t") 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /aizynthtrain/utils/template_runner.py: -------------------------------------------------------------------------------- 1 | """Module containing routines to extract templates from reactions""" 2 | import argparse 3 | from typing import Optional, Sequence 4 | 5 | import pandas as pd 6 | from rxnutils.chem.reaction import ChemicalReaction, ReactionException 7 | from rxnutils.data.batch_utils import read_csv_batch 8 | 9 | 10 | def generate_templates( 11 | data: pd.DataFrame, 12 | radius: int, 13 | expand_ring: bool, 14 | expand_hetero: bool, 15 | ringbreaker_column: str, 16 | smiles_column: str, 17 | ) -> None: 18 | """ 19 | Generate templates for the reaction in a given dataframe 20 | 21 | This function will add 3 columns to the dataframe 22 | * RetroTemplate: the extracted retro template 23 | * TemplateHash: a unique identifier based on fingerprint bits 24 | * TemplateError: if not None, will identicate a reason why the extraction failed 25 | 26 | :param data: the data with reactions 27 | :param radius: the radius to use, unless using Ringbreaker logic 28 | :param expand_ring: if True, will expand template with ring atoms 29 | :param expand_hetero: if True, will expand template with bonded heteroatoms 30 | :param ringbreaker_column: if given, will apply Rinbreaker logic to rows where this column is True 31 | :param smiles_column: the column with the atom-mapped reaction SMILES 32 | """ 33 | 34 | def _row_apply( 35 | row: pd.Series, 36 | column: str, 37 | radius: int, 38 | expand_ring: bool, 39 | expand_hetero: bool, 40 | ringbreaker_column: str, 41 | ) -> pd.Series: 42 | rxn = ChemicalReaction(row[column], clean_smiles=False) 43 | general_error = { 44 | "RetroTemplate": None, 45 | "TemplateHash": None, 46 | "TemplateError": "General error", 47 | } 48 | 49 | if ringbreaker_column and 
row[ringbreaker_column]: 50 | expand_ring = True 51 | expand_hetero = True 52 | radius = 0 53 | elif ringbreaker_column and not row[ringbreaker_column]: 54 | expand_ring = False 55 | expand_hetero = False 56 | try: 57 | _, retro_template = rxn.generate_reaction_template( 58 | radius=radius, expand_ring=expand_ring, expand_hetero=expand_hetero 59 | ) 60 | except ReactionException as err: 61 | general_error["TemplateError"] = str(err) 62 | return pd.Series(general_error) 63 | except Exception: 64 | general_error["TemplateError"] = "General error when generating template" 65 | return pd.Series(general_error) 66 | 67 | try: 68 | hash_ = retro_template.hash_from_bits() 69 | except Exception: 70 | general_error[ 71 | "TemplateError" 72 | ] = "General error when generating template hash" 73 | return pd.Series(general_error) 74 | 75 | return pd.Series( 76 | { 77 | "RetroTemplate": retro_template.smarts, 78 | "TemplateHash": hash_, 79 | "TemplateError": None, 80 | } 81 | ) 82 | 83 | template_data = data.apply( 84 | _row_apply, 85 | axis=1, 86 | radius=radius, 87 | expand_ring=expand_ring, 88 | expand_hetero=expand_hetero, 89 | ringbreaker_column=ringbreaker_column, 90 | column=smiles_column, 91 | ) 92 | return data.assign( 93 | **{column: template_data[column] for column in template_data.columns} 94 | ) 95 | 96 | 97 | def main(args: Optional[Sequence[str]] = None) -> None: 98 | """Command-line interface for template extraction""" 99 | parser = argparse.ArgumentParser("Generate retrosynthesis templates") 100 | parser.add_argument("--input_path", required=True) 101 | parser.add_argument("--output_path", required=True) 102 | parser.add_argument("--radius", type=int, required=True) 103 | parser.add_argument("--expand_ring", action="store_true", default=False) 104 | parser.add_argument("--expand_hetero", action="store_true", default=False) 105 | parser.add_argument("--ringbreaker_column", default="") 106 | parser.add_argument("--smiles_column", required=True) 107 | parser.add_argument("--batch", type=int, nargs=2) 108 | args = parser.parse_args(args=args) 109 | 110 | data = read_csv_batch(args.input_path, sep="\t", index_col=False, batch=args.batch) 111 | data = generate_templates( 112 | data, 113 | args.radius, 114 | args.expand_ring, 115 | args.expand_hetero, 116 | args.ringbreaker_column, 117 | args.smiles_column, 118 | ) 119 | data.to_csv(args.output_path, index=False, sep="\t") 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /configs/uspto/expansion_model_pipeline_config.yml: -------------------------------------------------------------------------------- 1 | expansion_model_pipeline: 2 | python_kernel: aizynthtrain 3 | file_prefix: uspto 4 | routes_to_exclude: 5 | - ref_routes_n1.json 6 | - ref_routes_n5.json 7 | 8 | expansion_model_evaluation: 9 | file_prefix: uspto 10 | stock_for_finding: stock_for_eval_find.hdf5 11 | target_smiles: smiles_for_eval.txt 12 | stock_for_recovery: stock_for_eval_recov.txt 13 | reference_routes: routes_for_eval.json -------------------------------------------------------------------------------- /configs/uspto/expansion_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --time=24:00:00 3 | #SBATCH --nodes=1 4 | #SBATCH --cpus-per-task=8 5 | #SBATCH --mem=64g 6 | 7 | source ~/.bashrc 8 | conda activate aizynthtrain 9 | unset JUPYTER_PATH 10 | 11 | python -m 
aizynthtrain.pipelines.expansion_model_pipeline run --config expansion_model_pipeline_config.yml --max-workers 8 --max-num-splits 200 12 | -------------------------------------------------------------------------------- /configs/uspto/ringbreaker_model_pipeline_config.yml: -------------------------------------------------------------------------------- 1 | expansion_model_pipeline: 2 | python_kernel: aizynthtrain 3 | file_prefix: uspto_ringbreaker 4 | 5 | expansion_model_evaluation: 6 | file_prefix: uspto_ringbreaker -------------------------------------------------------------------------------- /configs/uspto/ringbreaker_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --time=8:00:00 3 | #SBATCH --nodes=1 4 | #SBATCH --cpus-per-task=8 5 | #SBATCH --mem=64g 6 | 7 | source ~/.bashrc 8 | conda activate aizynthtrain 9 | unset JUPYTER_PATH 10 | 11 | python -m aizynthtrain.pipelines.expansion_model_pipeline run --config ringbreaker_model_pipeline_config.yml --max-workers 8 --max-num-splits 200 -------------------------------------------------------------------------------- /configs/uspto/setup_folder.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import sys 3 | import os 4 | from pathlib import Path 5 | 6 | import requests 7 | import tqdm 8 | 9 | 10 | FILES_TO_DOWNLOAD = [ 11 | { 12 | "url": "https://ndownloader.figshare.com/files/23086469", 13 | "filename": "stock_for_eval_find.hdf5", 14 | }, 15 | { 16 | "url": "https://zenodo.org/record/7341155/files/ref_routes_n1.json?download=1", 17 | "filename": "ref_routes_n1.json", 18 | }, 19 | { 20 | "url": "https://zenodo.org/record/7341155/files/ref_routes_n5.json?download=1", 21 | "filename": "ref_routes_n5.json", 22 | }, 23 | ] 24 | 25 | 26 | def _download_file(url: str, filename: str) -> None: 27 | with requests.get(url, stream=True) as response: 28 | response.raise_for_status() 29 | total_size = int(response.headers.get("content-length", 0)) 30 | pbar = tqdm.tqdm( 31 | total=total_size, 32 | desc=f"Downloading {os.path.basename(filename)}", 33 | unit="B", 34 | unit_scale=True, 35 | ) 36 | with open(filename, "wb") as fileobj: 37 | for chunk in response.iter_content(chunk_size=1024): 38 | fileobj.write(chunk) 39 | pbar.update(len(chunk)) 40 | pbar.close() 41 | 42 | 43 | USPTO_CONFIG_PATH = Path(__file__).parent 44 | 45 | FILES_TO_COPY = [ 46 | "expansion_model_pipeline_config.yml", 47 | "ringbreaker_model_pipeline_config.yml", 48 | "template_pipeline_config.yml", 49 | "routes_for_eval.json", 50 | "smiles_for_eval.txt", 51 | "stock_for_eval_recov.txt", 52 | ] 53 | 54 | if __name__ == "__main__": 55 | project_path = Path(sys.argv[1]) 56 | 57 | project_path.mkdir(exist_ok=True) 58 | 59 | for filename in FILES_TO_COPY: 60 | print(f"Copy {filename} to project folder") 61 | shutil.copy(USPTO_CONFIG_PATH / filename, project_path) 62 | 63 | for filespec in FILES_TO_DOWNLOAD: 64 | try: 65 | _download_file(filespec["url"], str(project_path / filespec["filename"])) 66 | except requests.HTTPError as err: 67 | print(f"Download failed with message {str(err)}") 68 | sys.exit(1) 69 | -------------------------------------------------------------------------------- /configs/uspto/smiles_for_eval.txt: -------------------------------------------------------------------------------- 1 | CSc1nnc(C(C)(C)C)c(=O)n1N=Cc1ccc(C(C)C)cc1 2 | O=C(C1=CC2c3cccc4[nH]cc(c34)C[C@H]2N(C(=O)Nc2cccs2)C1)N1CCCC1 3 | 
COc1cc(OC)c2c(c1)O[C@@]1(c3ccc(Br)cc3)[C@H](c3cccc(F)c3)C[C@H](NC(N)=O)[C@@]21O 4 | C[C@]12CC[C@H](OC=O)C[C@H]1CC[C@@H]1[C@@H]2CC[C@]2(C)[C@@H]([C@@]34C=CC(=O)O[C@@H]3O4)C[C@H]3O[C@@]312 5 | [C-]#[N+][C@@]1(C)[C@H]2CC[C@H]3[C@@H]4[C@@H]2[C@@H](C[C@H](C)[C@H]4CC[C@]3(C)N=C=S)C[C@H]1C 6 | CN(CCCC(C=O)NC(=O)OC(C)(C)C)C(=O)C(F)(F)F 7 | CC1(C)Cc2sccc2C=[N+]1[O-] 8 | CC(C)=CCC/C(C)=C/CC/C(C)=C/COCC1(O)CN2CCC1CC2 9 | O[C@H]1[C@H](O)[C@@H](O)[C@H](F)[C@@H](O)[C@@H]1O 10 | N=C(NCCO)C(Cl)(Cl)Cl 11 | Cc1cc2c3c(c1C)C1(C(=O)N3C(C)(C)CC2C)C(C#N)C(=N)OC2=CC(C)(C)CC(=O)C21 12 | O=C(C(O)O)C(O)C(=O)C(O)C(O)O 13 | CC(C)CCN1CCN(C2CSCCSC2)CC1CCO 14 | CCc1nnc(NC(=O)C(C)=NC(C)=O)s1 15 | N=CCCC(=O)O 16 | OC(O)=c1c(C(O)O)c2oc1=CC2 17 | CCCCCCN1CCc2c(c3cc(I)ccc3n2C)C1 18 | CN(C)C(=O)OC1CC2CCCC1N2 19 | Cc1sc2ncnc(Sc3nc4ccccc4s3)c2c1C 20 | Cc1ccc[n+](C)c1CCCO 21 | NCCCC(N[C@@H](CNC(=S)Nc1ccccc1)Cc1ccccc1)C(=O)NCc1ccccc1 22 | CCCC1(CCC)OOC(CCC)(CCC)OO1 23 | Fc1cc(Cl)cc(C#CCSc2nsnc2C23CN4CC2C3C4)c1 24 | CN(C)c1ccc(-n2cccc2/C=C/[N+](=O)[O-])cc1 25 | Oc1ccc2cn(Cc3ccc(F)cc3F)c(O)c2c1O 26 | FC(F)(F)c1cccc(N2CCN(C3CCN(c4ccc(-c5ncns5)nn4)CC3)CC2)c1 27 | C=CCN(c1ncccc1CNc1nc(Nc2cccc(Br)c2)ncc1C(F)(F)F)S(C)(=O)=O 28 | Cc1cc(C)c(-c2c(N)cc(Br)cc2Br)c(N)c1 29 | Cc1csc(=N)n1C(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F 30 | NC1CCC2(CCCCC2)CC1 31 | CNS(=O)(=O)CCCCN=[N+]=[N-] 32 | COc1c(C)cc(/C(=C\CCc2nnc(C)o2)c2cc(C)c3onc(OC)c3c2)cc1C(=O)SC 33 | C#CCN(C)[C@H](C[18F])CC(C)C 34 | CC1=C(C2=C(C)C(C(C)C)C3C(=O)C(=O)C(O)C3C2O)C(=O)C2C(O)C(O)C(=O)C2C1C(C)C 35 | COc1c(O)c2c(c(I)c1OC)C[C@H](C)[C@H](C)Cc1c(I)c(OC)c(OC)c(OC)c1-2 36 | C1CC(CN2C[C@@H]3CC[C@H]2CN(Cc2noc(C4CCCC4)n2)C3)C1 37 | O=C1CCC[C@@H](C2=CCCCC2=O)C1 38 | CC(C=NN(C)C(=N)N)=NN(C)C(=N)N 39 | NC(=S)c1ncnc2c1ncn2C1O[C@H](CO)[C@@H](O)[C@H]1O 40 | Cc1ccc(-n2c(O)c3c(c2O)[C@H](C2C(=O)N(C4CCCCC4)[C@H]2c2ccc(Cl)cc2)C=CC3)cc1 41 | C1=Nn2c(nnc2-c2[nH]nc3c2CCC3)SC1 42 | CC(C)C1NC(=O)C[C@H]2/C=C/CCSSCC(NC1=O)C(=O)Nc1ccccc1C(=O)NCC(=O)O2 43 | O=c1cc(-c2ccsc2)oc2cc(-c3ccc4c(c3)OCO4)ccc12 44 | ON=C(O)c1cnc(NC23CC4CC(CC(C4)C2)C3)nc1 45 | C=Cc1[nH]c(O)c(Cc2c(O)[nH]c(C=C)c2O)c1O 46 | Nc1nnc2c([nH]c3ccccc32)c1-c1ccccc1Cl 47 | O=c1[nH]c2c(F)c(F)c(F)c(F)c2[nH]c1=O 48 | N#Cc1ccc(C2CNCCNCCNCCN2)cc1 49 | CCN(Cc1ccncc1)Cc1coc(-c2ccc(OC)cc2)n1 50 | CC/C=C/C=C/C=C/C=C/C=C/OC[C@@H](O)CO 51 | -------------------------------------------------------------------------------- /configs/uspto/stock_for_eval_recov.txt: -------------------------------------------------------------------------------- 1 | FCYRSDMGOLYDHL-UHFFFAOYSA-N 2 | SDMCZCALYDCRBH-UHFFFAOYSA-N 3 | DWJPOVQCNHQSKY-UHFFFAOYSA-N 4 | YPGCWEMNNLXISK-ZETCQYMHSA-N 5 | YAPQBXQYLJRXSA-UHFFFAOYSA-N 6 | CSCPPACGZOOCGX-UHFFFAOYSA-N 7 | BDAGIHXWWSANSR-UHFFFAOYSA-N 8 | KGEXISHTCZHGFT-UHFFFAOYSA-N 9 | QSHLXVTVXQTHBS-UHFFFAOYSA-N 10 | ROSDSFDQCJNGOL-UHFFFAOYSA-N 11 | OAFSBTSMSMVVAK-UHFFFAOYSA-N 12 | CWMFRHBXRUITQE-UHFFFAOYSA-N 13 | BUDQDWGNQVEFAC-UHFFFAOYSA-N 14 | YQYGPGKTNQNXMH-UHFFFAOYSA-N 15 | LYCAIKOWRPUZTN-UHFFFAOYSA-N 16 | XZTSFVPMMQNIAJ-UHFFFAOYSA-N 17 | IVRMZWNICZWHMI-UHFFFAOYSA-N 18 | PIICEJLVQHRZGT-UHFFFAOYSA-N 19 | KRRTXVSBTPCDOS-UHFFFAOYSA-N 20 | YFNUFJPVGUIACT-UHFFFAOYSA-N 21 | QNEGDGPAXKYZHZ-UHFFFAOYSA-N 22 | YZDUNYUENNWBRE-KBPBESRZSA-N 23 | NQRYJNQNLNOLGT-UHFFFAOYSA-N 24 | MOHYOXXOKFQHDC-UHFFFAOYSA-N 25 | TZGFQIXRVUHDLE-UHFFFAOYSA-N 26 | YYROPELSRYBVMQ-UHFFFAOYSA-N 27 | JTYUIAOHIYZBPB-UHFFFAOYSA-N 28 | XTHPWXDJESJLNJ-UHFFFAOYSA-N 29 | SUSQOBVLVYHIEX-UHFFFAOYSA-N 30 | VGYVBEJDXIPSDL-UHFFFAOYSA-N 31 | GEYOCULIXLDCMW-UHFFFAOYSA-N 32 | 
QZILRKMXCVPWQS-UHFFFAOYSA-N 33 | FFTDQGOHNPADKU-UHFFFAOYSA-N 34 | PCLIMKBDDGJMGD-UHFFFAOYSA-N 35 | WFDIJRYMOXRFFG-UHFFFAOYSA-N 36 | PAQNFUWHQKGPHT-JRZJBTRGSA-N 37 | XHFUVBWCMLLKOZ-UHFFFAOYSA-N 38 | ORJRJVSMOMIYGH-UHFFFAOYSA-N 39 | AVXURJPOCDRRFD-UHFFFAOYSA-N 40 | FOZVXADQAHVUSV-UHFFFAOYSA-N 41 | XLYOFNOQVPJJNP-UHFFFAOYSA-N 42 | ZQWPRMPSCMSAJU-UHFFFAOYSA-N 43 | MEQYSZHIRFNJSS-UHFFFAOYSA-N 44 | OKFVRSHFNHZGPQ-UHFFFAOYSA-N 45 | IGJLXIUSOMMUIV-UHFFFAOYSA-N 46 | HVTICUPFWKNHNG-UHFFFAOYSA-N 47 | WEVYAHXRMPXWCK-UHFFFAOYSA-N 48 | TVSXDZNUTPLDKY-UHFFFAOYSA-N 49 | XMBWDFGMSWQBCA-UHFFFAOYSA-N 50 | YPLVPFUSXYSHJD-UHFFFAOYSA-N 51 | OAKJQQAXSVQMHS-UHFFFAOYSA-N 52 | ZMXDDKWLCZADIW-UHFFFAOYSA-N 53 | DYHSDKLCOJIUFX-UHFFFAOYSA-N 54 | LUOBOECBFPRLFH-RKDXNWHRSA-N 55 | GLUUGHFHXGJENI-UHFFFAOYSA-N 56 | XKJCHHZQLQNZHY-UHFFFAOYSA-N 57 | FWLWTILKTABGKQ-UHFFFAOYSA-N 58 | WERABQRUGJIMKQ-UHFFFAOYSA-N 59 | BCNZYOJHNLTNEZ-UHFFFAOYSA-N 60 | UGEJOEBBMPOJMT-UHFFFAOYSA-N -------------------------------------------------------------------------------- /configs/uspto/template_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --time=24:00:00 3 | #SBATCH --nodes=1 4 | #SBATCH --cpus-per-task=32 5 | #SBATCH --mem=64g 6 | 7 | source ~/.bashrc 8 | conda activate aizynthtrain 9 | unset JUPYTER_PATH 10 | 11 | python -m aizynthtrain.pipelines.template_pipeline run --config template_pipeline_config.yml --max-workers 32 --max-num-splits 200 -------------------------------------------------------------------------------- /configs/uspto/template_pipeline_config.yml: -------------------------------------------------------------------------------- 1 | template_pipeline: 2 | python_kernel: aizynthtrain 3 | data_import_class: aizynthtrain.data.uspto.UsptoImporter 4 | data_import_config: 5 | data_path: uspto_data_mapped.csv 6 | selected_templates_prefix: uspto 7 | min_template_occurrence: 3 8 | -------------------------------------------------------------------------------- /env-dev.yml: -------------------------------------------------------------------------------- 1 | name: aizynthtrain 2 | channels: 3 | - https://conda.anaconda.org/conda-forge 4 | - defaults 5 | dependencies: 6 | - python>=3.9,<3.11 7 | - poetry>=1.1.4,<2.0 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "aizynthtrain" 3 | version = "0.3.0" 4 | description = "Tools to train reaction predictions models for AiZynthFinder" 5 | authors = ["Genheden, Samuel "] 6 | license = "Apache 2.0" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.9,<3.11" 10 | jupytext = "^1.3.3" 11 | papermill = { git = "https://github.com/nteract/papermill.git" } 12 | mendeleev = "^0.10.0" 13 | nbconvert = "^7.2.3" 14 | onnx = "^1.13.1" 15 | tf2onnx = "^1.13.0" 16 | urllib3 = "<2.0" 17 | aizynthfinder = { version = "^4.0.0", extras = ["all", "tf"] } 18 | reaction-utils = "^1.3.0" 19 | onnxruntime = "<1.17.0" 20 | 21 | [tool.poetry.dev-dependencies] 22 | 23 | [build-system] 24 | requires = ["poetry-core>=1.0.0"] 25 | build-backend = "poetry.core.masonry.api" 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/tests/__init__.py 
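Before moving on to the tests, note how the SLURM scripts and YAML files above tie the pipelines together: each flow is launched with python -m aizynthtrain.pipelines.<module> run --config <yaml>, and the start step hands that YAML to load_config, which resolves the snake_case top-level key to the matching pydantic class (template_pipeline -> TemplatePipelineConfig, and so on). A minimal sketch of that round trip is shown below; the output file name is a placeholder and the values simply mirror configs/uspto/template_pipeline_config.yml.

```python
# A minimal sketch of how a pipeline YAML maps onto its pydantic config class.
# The file name is a placeholder; the values mirror configs/uspto/template_pipeline_config.yml.
import yaml

from aizynthtrain.utils.configs import load_config

config_dict = {
    "template_pipeline": {
        "python_kernel": "aizynthtrain",
        "data_import_class": "aizynthtrain.data.uspto.UsptoImporter",
        "data_import_config": {"data_path": "uspto_data_mapped.csv"},
        "selected_templates_prefix": "uspto",
        "min_template_occurrence": 3,
    }
}
with open("my_template_pipeline_config.yml", "w") as fileobj:
    yaml.dump(config_dict, fileobj)

# "template_pipeline" (snake_case) is resolved to the TemplatePipelineConfig class
config = load_config("my_template_pipeline_config.yml", "template_pipeline")
print(config.selected_templates_prefix)  # -> "uspto"
print(config.validated_templates_path)   # defaults apply for anything not given
```

Several sections can live in one file, which is how expansion_model_pipeline_config.yml carries both the expansion_model_pipeline and expansion_model_evaluation sections.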
-------------------------------------------------------------------------------- /tests/data/test_keras.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MolecularAI/aizynthtrain/55df633bbdf971378e7f031e8e89190cde1b17fb/tests/data/test_keras.hdf5 -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import pytest 3 | 4 | from aizynthtrain.utils.configs import ( 5 | load_config, 6 | TemplatePipelineConfig, 7 | ExpansionModelPipelineConfig, 8 | TemplateLibraryConfig, 9 | ExpansionModelHyperparamsConfig, 10 | ExpansionModelEvaluationConfig, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def make_config_file(tmpdir): 16 | filename = str(tmpdir / "config.yml") 17 | 18 | def wrapper(config_dict): 19 | with open(filename, "w") as fileobj: 20 | yaml.dump(config_dict, fileobj) 21 | return filename 22 | 23 | return wrapper 24 | 25 | 26 | def test_template_pipeline_config(make_config_file): 27 | dict_ = { 28 | "python_kernel": "aizynth", 29 | "data_import_class": "aizynthtrain.data.uspto.importer.UsptoImporter", 30 | } 31 | filename = make_config_file({"template_pipeline": dict_}) 32 | 33 | config = load_config(filename, "template_pipeline") 34 | 35 | assert type(config) == TemplatePipelineConfig 36 | assert config.python_kernel == dict_["python_kernel"] 37 | assert config.data_import_class == dict_["data_import_class"] 38 | 39 | 40 | def test_expansion_model_pipeline_config(make_config_file): 41 | dict_ = { 42 | "python_kernel": "aizynth", 43 | "file_prefix": "uspto", 44 | "routes_to_exclude": ["ref_routes_n1.json", "ref_routes_n5.json"], 45 | "library_config": {"template_set": "uspto_templates"}, 46 | "postfixes": {"report": "model_report.html"}, 47 | "model_hyperparams": {"epochs": 50}, 48 | } 49 | filename = make_config_file({"expansion_model_pipeline": dict_}) 50 | 51 | config = load_config(filename, "expansion_model_pipeline") 52 | 53 | assert type(config) == ExpansionModelPipelineConfig 54 | assert config.python_kernel == dict_["python_kernel"] 55 | assert config.file_prefix == dict_["file_prefix"] 56 | assert len(config.routes_to_exclude) == 2 57 | 58 | assert type(config.library_config) == TemplateLibraryConfig 59 | assert config.library_config.template_set == dict_["library_config"]["template_set"] 60 | 61 | assert type(config.model_hyperparams) == ExpansionModelHyperparamsConfig 62 | assert config.model_hyperparams.epochs == 50 63 | 64 | assert config.filename("report") == "uspto_model_report.html" 65 | assert config.filename("model_labels", "training") == "uspto_training_labels.npz" 66 | 67 | 68 | def test_expansion_model_evaluation_config(make_config_file): 69 | dict_ = { 70 | "file_prefix": "uspto", 71 | "stock_for_finding": "stock_for_eval_find.hdf5", 72 | "target_smiles": "target_smiles", 73 | "search_properties_for_finding": {}, 74 | } 75 | filename = make_config_file({"expansion_model_evaluation": dict_}) 76 | 77 | config = load_config(filename, "expansion_model_evaluation") 78 | 79 | assert type(config) == ExpansionModelEvaluationConfig 80 | assert config.file_prefix == dict_["file_prefix"] 81 | assert config.stock_for_finding == dict_["stock_for_finding"] 82 | assert config.target_smiles == dict_["target_smiles"] 83 | assert config.search_properties_for_finding == {} 84 | 85 | assert ( 86 | config.filename("multistep_report") 87 | == 
"uspto_model_validation_multistep_report.json" 88 | ) 89 | -------------------------------------------------------------------------------- /tests/test_data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | import pandas as pd 5 | 6 | from aizynthtrain.utils.data_utils import extract_route_reactions, split_data 7 | 8 | 9 | @pytest.fixture 10 | def save_trees(tmpdir): 11 | trees = [ 12 | { 13 | "type": "mol", 14 | "children": [ 15 | {"type": "reaction", "metadata": {"reaction_hash": "AAAA"}}, 16 | { 17 | "type": "reaction", 18 | "metadata": {"reaction_hash": "BBBB"}, 19 | "children": [{"type": "mol"}, {"type": "mol"}], 20 | }, 21 | ], 22 | }, 23 | { 24 | "type": "mol", 25 | "children": [ 26 | { 27 | "type": "reaction", 28 | "metadata": {"reaction_hash": "BBBB"}, 29 | "children": [ 30 | {"type": "mol"}, 31 | { 32 | "type": "mol", 33 | "children": [ 34 | { 35 | "type": "reaction", 36 | "metadata": {"reaction_hash": "CCC"}, 37 | } 38 | ], 39 | }, 40 | ], 41 | }, 42 | ], 43 | }, 44 | ] 45 | filenames = [str(tmpdir / f"tree{idx}.json") for idx in range(len(trees))] 46 | for tree, filename in zip(trees, filenames): 47 | with open(filename, "w") as fileobj: 48 | json.dump([tree], fileobj) 49 | return filenames 50 | 51 | 52 | def test_extract_route_reactions(save_trees): 53 | filenames = save_trees 54 | 55 | assert extract_route_reactions(filenames[:1]) == {"AAAA", "BBBB"} 56 | assert extract_route_reactions(filenames[1:]) == {"CCC", "BBBB"} 57 | assert extract_route_reactions(filenames) == {"AAAA", "CCC", "BBBB"} 58 | 59 | 60 | def test_split_even_in_two(): 61 | df = pd.DataFrame({"A": [1, 2, 3, 4]}) 62 | 63 | val_indices = [] 64 | train_indices = [] 65 | 66 | split_data(df, 0.5, 666, train_indices, val_indices) 67 | 68 | assert len(train_indices) == 2 69 | assert len(val_indices) == 2 70 | assert set(val_indices).intersection(train_indices) == set() 71 | 72 | 73 | def test_split_even_in_three(): 74 | df = pd.DataFrame({"A": [1, 2, 3, 4, 5, 6]}) 75 | 76 | val_indices = [] 77 | train_indices = [] 78 | test_indices = [] 79 | 80 | split_data(df, 0.34, 666, train_indices, val_indices, test_indices) 81 | 82 | assert len(train_indices) == 2 83 | assert len(val_indices) == 2 84 | assert len(test_indices) == 2 85 | assert set(val_indices).intersection(train_indices) == set() 86 | assert set(val_indices).intersection(test_indices) == set() 87 | 88 | 89 | def test_split_uneven_three(): 90 | df = pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}) 91 | 92 | val_indices = [] 93 | train_indices = [] 94 | test_indices = [] 95 | 96 | split_data(df, 0.90, 666, train_indices, val_indices, test_indices) 97 | 98 | assert len(train_indices) == 9 99 | assert len(val_indices) + len(test_indices) == 1 100 | assert set(val_indices).intersection(train_indices) == set() 101 | assert set(val_indices).intersection(test_indices) == set() 102 | 103 | 104 | def test_split_low_data(): 105 | df = pd.DataFrame({"A": [1]}) 106 | 107 | val_indices = [] 108 | train_indices = [] 109 | 110 | split_data(df, 0.99, 666, train_indices, val_indices) 111 | 112 | assert len(train_indices) == 1 113 | assert len(val_indices) == 0 114 | 115 | -------------------------------------------------------------------------------- /tests/test_file_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aizynthtrain.utils.files import prefix_filename 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ("prefix", 
"postfix", "expected"), 8 | [("a", "b", "a_b"), ("", "b", "b"), ("a", "", "a_"), ("", "", "")], 9 | ) 10 | def test_prefix_filename(prefix, postfix, expected): 11 | assert prefix_filename(prefix, postfix) == expected 12 | -------------------------------------------------------------------------------- /tests/test_onnx_converter.py: -------------------------------------------------------------------------------- 1 | """Tests the utils module to convert Keras models to ONNX.""" 2 | 3 | import pytest 4 | import pytest_mock 5 | 6 | from aizynthtrain.utils import onnx_converter 7 | 8 | 9 | @pytest.fixture 10 | def mock_tf2onnx_convert_from_keras(mocker: pytest_mock.MockerFixture): 11 | return mocker.patch( 12 | "aizynthtrain.utils.onnx_converter.tf2onnx.convert.from_keras", 13 | return_value=None, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def mock_load_keras_model(mocker: pytest_mock.MockerFixture): 19 | return mocker.patch( 20 | "aizynthtrain.utils.onnx_converter.load_model", 21 | return_value=None, 22 | ) 23 | 24 | 25 | def test_mock_convert_to_onnx(mock_tf2onnx_convert_from_keras, mock_load_keras_model) -> None: 26 | test_convert = onnx_converter.convert_to_onnx("test_model_path", "test_onnx_path") 27 | mock_tf2onnx_convert_from_keras.assert_called_once() 28 | mock_load_keras_model.assert_called_once() 29 | 30 | 31 | def test_convert_to_onnx(tmpdir, shared_datadir) -> None: 32 | test_convert = onnx_converter.convert_to_onnx( 33 | shared_datadir / "test_keras.hdf5", 34 | str(tmpdir / "test_model.onnx") 35 | ) 36 | assert (tmpdir / "test_model.onnx").exists() 37 | --------------------------------------------------------------------------------