├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── data ├── .gitkeep ├── USPTO_50k_MHN_prepro.csv.gz ├── figs │ └── overview_tikz_transp.png ├── processed │ ├── uspto_sm_historian.uspto_sm_.json.gz │ ├── uspto_sm_reactions.uspto_sm_.json.gz │ ├── uspto_sm_retro.templates.uspto_sm_.json.gz │ ├── uspto_sm_templates.df.json.gz │ ├── uspto_sm_test.appl_matrix.npz │ ├── uspto_sm_test.input.smiles.npy │ ├── uspto_sm_test.labels.classes.npy │ ├── uspto_sm_train.appl_matrix.npz │ ├── uspto_sm_train.input.smiles.npy │ ├── uspto_sm_train.labels.classes.npy │ ├── uspto_sm_valid.appl_matrix.npz │ ├── uspto_sm_valid.input.smiles.npy │ └── uspto_sm_valid.labels.classes.npy └── temprel-fortunato │ └── template-relevance-master │ ├── .gitignore │ ├── .gitlab-ci.yml │ ├── Dockerfile │ ├── Dockerfile.gpu │ ├── README.md │ ├── bin │ ├── calculate_applicabilty.py │ ├── get_uspto_50k.py │ ├── hyperopt.sh │ ├── process.py │ ├── save_model.py │ ├── test.py │ ├── train.py │ └── train_appl.py │ ├── requirements.txt │ ├── setup.py │ └── temprel │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── download.py │ └── loaders.py │ ├── evaluate │ ├── accuracy.py │ ├── diversity.py │ ├── reciprocal_rank.py │ ├── roc.py │ ├── template_popularity.py │ └── topk_appl.py │ ├── models │ ├── __init__.py │ ├── layers.py │ ├── losses.py │ ├── metrics.py │ └── models.py │ ├── rdkit.py │ └── templates │ ├── __init__.py │ ├── extract.py │ └── validate.py ├── env.yml ├── mhnreact ├── .gitkeep ├── __init__.py ├── data.py ├── inference.py ├── inspect.py ├── model.py ├── molutils.py ├── plotutils.py ├── retroeval.py ├── retrosyn.py ├── train.py ├── utils.py └── view.py ├── notebooks ├── 01_prepro_uspto_sm_lg.ipynb ├── 02_prepro_uspto_50k.ipynb ├── 03_prepro_uspto_full.ipynb ├── 04_prepro_time_split.ipynb ├── 11_training_template_relevance_prediction.ipynb ├── 12_training_single_step_retrosynthesis.ipynb ├── 20_evaluation.ipynb ├── 30_retrieval_fast_scalable.ipynb └── colab_MHNreact_demo.ipynb ├── scripts ├── .gitkeep ├── make_env.sh ├── train_ssr_mhn.sh ├── train_tr_dnn_fortunato.sh ├── train_tr_dnn_segler.sh └── train_tr_mhn.sh ├── setup.py └── tools └── docker ├── Dockerfile ├── README.md └── env.yml /.dockerignore: -------------------------------------------------------------------------------- 1 | **/.git 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | /mhnreact/.ipynb_checkpoints/ 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NeurIPS | 2021 License 2 | Copyright (c) 2021, Institute for Machine Learning, Johannes Kepler University Linz 3 | 4 | All rights reserved. 5 | 6 | Use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 10 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 11 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 12 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 13 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 14 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 15 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 16 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 17 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 18 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
19 | 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MHNreact 2 | [![arXiv](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065) 3 | [![arXiv](https://img.shields.io/badge/arXiv-2104.03279-b31b1b.svg)](https://arxiv.org/abs/2104.03279) 4 | [![License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://opensource.org/licenses/BSD-2-Clause) 5 | [![Hf Demo](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/uragankatrrin/MHN-React) 6 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb) 7 | 8 | **[Abstract](#modern-hopfield-networks-for-few-and-zero-shot-reaction-template-prediction)** 9 | | **[Environment](#environment)** 10 | | **[Data](#data-and-processing)** 11 | | **[Training](#training)** 12 | | **[Loading](#loading-in-trained-models-and-evaluation)** 13 | | **[Citation](#citation)** 14 | 15 | We adapt modern Hopfield networks [(Ramsauer et al., 2021)](#mhn) (MHN) to associate two data modalities, molecules and reaction templates, which improves predictive performance for rare templates and for single-step retrosynthesis. 16 | 17 | ![overview_image](data/figs/overview_tikz_transp.png?raw=true "Overview Figure") 18 | 19 | ## Improving Few- and Zero-Shot Reaction Template Prediction Using Modern Hopfield Networks 20 | [paper](https://pubs.acs.org/doi/10.1021/acs.jcim.1c01065) 21 | 22 | Philipp Seidl, Philipp Renz, Natalia Dyubankova, Paulo Neves, Jonas Verhoeven, Marwin Segler, Jörg K. Wegner, Sepp Hochreiter, Günter Klambauer 23 | 24 | Finding synthesis routes for molecules of interest is essential in the discovery of new drugs and materials. To find such routes, computer-assisted synthesis planning (CASP) methods are employed, which rely on a single-step model of chemical reactivity. In this study, we introduce a template-based single-step retrosynthesis model based on Modern Hopfield Networks, which learn an encoding of both molecules and reaction templates in order to predict the relevance of templates for a given molecule. The template representation allows generalization across different reactions and significantly improves the performance of template relevance prediction, especially for templates with few or zero training examples. With inference speed up to orders of magnitude faster than baseline methods, we improve or match the state-of-the-art performance for top-k exact match accuracy for k ≥ 3 in the retrosynthesis benchmark USPTO-50k. 25 | 26 | ## Minimal working example 27 | 28 | Open the following Colab notebook for a quick training example: [train-colab](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb). 29 | 30 | ## Environment 31 | 32 | ### Anaconda 33 | 34 | When using `conda`, an environment can be set up using 35 | ```bash 36 | conda env create -f env.yml 37 | ``` 38 | To activate the environment, call ```conda activate mhnreact_env```. 39 | 40 | Additionally, one needs to install [template-relevance](https://gitlab.com/mefortunato/template-relevance), which is included in this package, as well as [rdchiral](https://github.com/connorcoley/rdchiral) using pip: 41 | ```bash 42 | cd data/temprel-fortunato/template-relevance-master/ 43 | pip install -e .
44 | pip install -e "git://github.com/connorcoley/rdchiral.git#egg=rdchiral" 45 | ``` 46 | 47 | You may need to adjust the CUDA version. 48 | The code was tested with: 49 | - rdkit 2021.03.1 and 2020.03.4 50 | - python 3.7 and 3.8 51 | - pytorch 1.6 52 | - rdchiral 53 | - template-relevance (./data/temprel-fortunato) 54 | - CGRtools (only for preparing USPTO-golden) 55 | 56 | 57 | A second option is to run 58 | ```bash 59 | conda create -n mhnreact_env python=3.8 60 | eval "$(conda shell.bash hook)" 61 | conda activate mhnreact_env 62 | conda install -c conda-forge rdkit 63 | pip install torch scipy ipykernel matplotlib sklearn swifter 64 | cd data/temprel-fortunato/template-relevance-master/ 65 | pip install -e . 66 | pip install -e "git://github.com/connorcoley/rdchiral.git#egg=rdchiral" 67 | ``` 68 | which is equivalent to running the script ```bash ./scripts/make_env.sh``` 69 | 70 | ### Docker 71 | 72 | Another option is to use the provided Dockerfile within ```./tools/docker/``` via the following command. 73 | 74 | ``` 75 | DOCKER_BUILDKIT=1 docker build -t mhnreact:latest -f Dockerfile ../.. 76 | ``` 77 | 78 | ## Data and processing 79 | 80 | The preprocessed data is contained in this repository: the ```./data/processed/uspto_sm_*``` files hold the preprocessed data for template relevance prediction. 81 | 82 | For single-step retrosynthesis the preprocessed and split data can be found in ````./data/USPTO_50k_MHN_prepro.csv.gz````. 83 | 84 | All preprocessing steps can be replicated with the notebooks ````./notebooks/*_prepro_*.ipynb```` for USPTO-sm, USPTO-lg as well as for USPTO-50k and USPTO-full. 85 | USPTO-lg as well as USPTO-full are not contained due to their size, and would have to be created using the corresponding notebook. 86 | 87 | 88 | ## Training 89 | 90 | Models can be trained using ````python -m mhnreact.train```` 91 | 92 | Selected calls are documented within ````./notebooks/*_training_*.ipynb````. 93 | 94 | Arguments are documented within the module and can be retrieved by adding ```--help``` to the call. Within the ```notebooks``` folder there are notebooks containing several examples. 95 | 96 | Some main parameters are: 97 | - ``model_type``: Model type, choose from 'segler', 'fortunato', 'mhn' or 'staticQK', default: 'mhn' 98 | - ``dataset_type``: Dataset, 'sm' or 'lg', for template relevance prediction (use --csv_path for single-step retrosynthesis input) 99 | - ``fp_type``: Fingerprint type for the input molecules only, default: 'morgan', other options: 'rdk', 'ECFP', 'ECFC', 'MxFP', 'Morgan2CBF' 100 | - ``template_fp_type``: Template-fingerprint type, default: 'rdk', other options: 'MxFP', 'rdkc', 'tfidf', 'random' 101 | - ``hopf_beta``: Hopfield beta parameter, default=0.005 102 | - ``hopf_asso_dim``: association dimension, default=512 103 | - ``ssretroeval``: single-step retrosynthesis evaluation, default=False 104 | 105 | 106 | An example call for single-step retrosynthesis is: 107 | ```bash 108 | python -m mhnreact.train --model_type=mhn --fp_size=4096 --fp_type morgan --template_fp_type rdk --concat_rand_template_thresh 1 \ 109 | --exp_name test --dataset_type 50k --csv_path ./data/USPTO_50k_MHN_prepro.csv.gz --ssretroeval True --seed 0 110 | ``` 111 | 112 | ## Loading in trained Models and Evaluation 113 | 114 | How to load in trained models can be seen in ```./notebooks/20_evaluation.ipynb```. The model is then used to predict on a test set.
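For a quick look at what such a test set contains, the preprocessed arrays in ```./data/processed/``` can be loaded directly. The sketch below is illustrative only: the file names are taken from this repository, while the exact array contents (SMILES strings and template class indices) are assumptions.

```python
# Sketch: inspect the preprocessed template-relevance test split.
# File names are from ./data/processed/; array contents/dtypes are assumptions.
import numpy as np
from scipy import sparse

test_smiles = np.load("data/processed/uspto_sm_test.input.smiles.npy", allow_pickle=True)
test_labels = np.load("data/processed/uspto_sm_test.labels.classes.npy", allow_pickle=True)
test_appl = sparse.load_npz("data/processed/uspto_sm_test.appl_matrix.npz")

print(test_smiles[0], test_labels[0])  # one input SMILES and its template class index
print(test_appl.shape)                 # (number of molecules, number of templates)
```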
115 | 116 | ## Train on custom data 117 | 118 | Preprocess the data into the same format as ````./data/USPTO_50k_MHN_prepro.csv.gz```` and pass the file via the argument ```--csv_path```. 119 | 120 | ## Citation 121 | 122 | To cite this work, you can use the following BibTeX entry: 123 | ```bibtex 124 | @article{seidl2021modern, 125 | author = {Seidl, Philipp and Renz, Philipp and Dyubankova, Natalia and Neves, Paulo and Verhoeven, Jonas and Segler, Marwin and Wegner, J{\"o}rg K. and Hochreiter, Sepp and Klambauer, G{\"u}nter}, 126 | title = {Improving Few- and Zero-Shot Reaction Template Prediction Using Modern Hopfield Networks}, 127 | journal = {Journal of Chemical Information and Modeling}, 128 | volume = {62}, 129 | number = {9}, 130 | pages = {2111-2120}, 131 | institution = {Institute for Machine Learning, Johannes Kepler University, Linz}, 132 | year = {2022}, 133 | doi = {10.1021/acs.jcim.1c01065}, 134 | url = {https://doi.org/10.1021/acs.jcim.1c01065}, 135 | } 136 | ``` 137 | 138 | ## References 139 | - Ramsauer et al. (2021). Hopfield Networks is All You Need. ICLR 2021 ([pdf](https://arxiv.org/abs/2008.02217)) 140 | 141 | ## Keywords 142 | Drug Discovery, CASP, Machine Learning, Synthesis, Zero-shot, Modern Hopfield Networks 143 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/.gitkeep -------------------------------------------------------------------------------- /data/USPTO_50k_MHN_prepro.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/USPTO_50k_MHN_prepro.csv.gz -------------------------------------------------------------------------------- /data/figs/overview_tikz_transp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/figs/overview_tikz_transp.png -------------------------------------------------------------------------------- /data/processed/uspto_sm_historian.uspto_sm_.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_historian.uspto_sm_.json.gz -------------------------------------------------------------------------------- /data/processed/uspto_sm_reactions.uspto_sm_.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_reactions.uspto_sm_.json.gz -------------------------------------------------------------------------------- /data/processed/uspto_sm_retro.templates.uspto_sm_.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_retro.templates.uspto_sm_.json.gz -------------------------------------------------------------------------------- /data/processed/uspto_sm_templates.df.json.gz: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_templates.df.json.gz -------------------------------------------------------------------------------- /data/processed/uspto_sm_test.appl_matrix.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_test.appl_matrix.npz -------------------------------------------------------------------------------- /data/processed/uspto_sm_test.input.smiles.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_test.input.smiles.npy -------------------------------------------------------------------------------- /data/processed/uspto_sm_test.labels.classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_test.labels.classes.npy -------------------------------------------------------------------------------- /data/processed/uspto_sm_train.appl_matrix.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_train.appl_matrix.npz -------------------------------------------------------------------------------- /data/processed/uspto_sm_train.input.smiles.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_train.input.smiles.npy -------------------------------------------------------------------------------- /data/processed/uspto_sm_train.labels.classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_train.labels.classes.npy -------------------------------------------------------------------------------- /data/processed/uspto_sm_valid.appl_matrix.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_valid.appl_matrix.npz -------------------------------------------------------------------------------- /data/processed/uspto_sm_valid.input.smiles.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_valid.input.smiles.npy -------------------------------------------------------------------------------- /data/processed/uspto_sm_valid.labels.classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/processed/uspto_sm_valid.labels.classes.npy -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Python modules 2 | *.pyc 3 | 
*.so 4 | *.pyd 5 | 6 | # artfacts from using JupyterLab 7 | .jupyter 8 | *ipynb_checkpoints* 9 | 10 | # raw data 11 | *data/* 12 | 13 | # example 14 | *example/* 15 | 16 | .alias -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: docker:stable 2 | 3 | services: 4 | - docker:dind 5 | 6 | variables: 7 | DOCKER_HOST: tcp://docker:2375 8 | DOCKER_DRIVER: overlay2 9 | IMAGE_TAG: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG 10 | LATEST_TAG: $CI_REGISTRY_IMAGE:latest 11 | 12 | before_script: 13 | - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY 14 | 15 | build: 16 | stage: build 17 | script: 18 | - docker build -t $LATEST_TAG . 19 | - docker push $LATEST_TAG 20 | only: 21 | - /^master$/ 22 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM registry.gitlab.com/mefortunato/docker-images/rdkit:2019.03.3-tf-mpi 2 | 3 | COPY requirements.txt requirements.txt 4 | 5 | RUN apt-get install -y unzip git && \ 6 | pip install -r requirements.txt && \ 7 | pip install -e "git://github.com/connorcoley/rdchiral.git#egg=rdchiral" 8 | 9 | COPY . /usr/local/temprel 10 | 11 | RUN pip install -e /usr/local/temprel && \ 12 | cp /usr/local/temprel/bin/* /usr/local/bin 13 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM registry.gitlab.com/mefortunato/docker-images/rdkit:2019.03.3-tf-gpu 2 | 3 | COPY requirements.txt requirements.txt 4 | 5 | RUN apt-get install -y unzip && \ 6 | pip install -r requirements.txt && \ 7 | pip install -e "git://github.com/connorcoley/rdchiral.git#egg=rdchiral" 8 | 9 | COPY . /usr/local/temprel 10 | 11 | RUN pip install -e /usr/local/temprel && \ 12 | cp /usr/local/temprel/bin/* /usr/local/bin -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/README.md: -------------------------------------------------------------------------------- 1 | # Template relevance training 2 | 3 | ### Prerequisites 4 | 5 | This repository is designed to function using the Docker container: 6 | `registry.gitlab.com/mefortunato/template-relevance` 7 | which contains [RDKit](https://github.com/rdkit), [TensorFlow](https://www.tensorflow.org/), and various other dependencies, including OpenMPI in order to scale this workflow up to a much larger number of reactions. Because this is such a small data set of reactions (described below in `Step 0`), the entire workflow can be run in just a few hours on a modest machine. 
8 | 9 | In order to run these scripts using the Docker container, first [make sure Docker is installed on your machine](https://docs.docker.com/install/), then run: 10 | ``` 11 | docker pull registry.gitlab.com/mefortunato/template-relevance 12 | ``` 13 | 14 | To make it easier to execute commands in the future, it is recommended to create an alias in `~/.bashrc` such as: 15 | ``` 16 | alias rdkit="docker run -it --rm -v ${PWD}:/work -w /work registry.gitlab.com/mefortunato/template-relevance" 17 | ``` 18 | or if you're using singularity: 19 | ``` 20 | alias rdkit="singularity exec -B $(pwd -P):/work --pwd /work /path/to/singularity/image.simg" 21 | ``` 22 | 23 | which will mount the current working directory (wherever you run `rdkit` from in the future) into `/work` inside the container. Any new files that get created will then also be available from the host machine, in the current working directory. 24 | 25 | The rest of this README will assume you are running the commands from inside the container, i.e. using the above alias: 26 | ``` 27 | rdkit python myscript.py 28 | ``` 29 | 30 | In even simpler terms, preface the rest of the commands listed in this README with the alias shown above. 31 | 32 | ### Step 0 - Get the raw reaction data 33 | 34 | Command(s) to run: 35 | ``` 36 | python get_uspto_50k.py 37 | ``` 38 | 39 | This example uses a small set of reactions from [https://doi.org/10.1021/ci5006614](https://doi.org/10.1021/ci5006614): 40 | ``` 41 | Development of a Novel Fingerprint for Chemical Reactions and Its Application to Large-Scale Reaction Classification and Similarity 42 | Nadine Schneider, Daniel M. Lowe, Roger A. Sayle, Gregory A. Landrum 43 | J. Chem. Inf. Model. 2015, 55 (1), 39-53 44 | ``` 45 | 46 | The reactions were curated from the USPTO data set and provided in the SI. The script downloads the SI and puts the raw data into a new directory `data/raw`. If you plan to use a different set of reactions, it could be useful to run this script once to see the expected format of `data/raw/reactions.json.gz`, which is required for the following step. 47 | 48 | ### Step 1 - Extract templates 49 | 50 | Command(s) to run: 51 | ``` 52 | python process.py 53 | ``` 54 | 55 | The `process.py` script (which uses `template_extractor.py` from the [rdchiral Python package](https://github.com/connorcoley/rdchiral)) will extract retrosynthetic templates for each reaction and attempt to verify their accuracy (i.e., if the template is applied to the product, are the reactant(s) recovered?). Ultimately all reactions with valid templates and their associated template will be saved to disk via pandas at `data/processed/templates.df.json.gz`. 56 | 57 | The `process.py` script will also perform a train/validation/test split according to the `--calc-split` flag, which can be one of the following `['stratified', 'random']`, and will create the following files for training a machine learning model: 58 | * train.input.smiles.npy 59 | * valid.input.smiles.npy 60 | * valid.labels.classes.npy 61 | * train.labels.classes.npy 62 | * test.input.smiles.npy 63 | * test.labels.classes.npy 64 | 65 | If you would like to use your own data split, please create the appropriate corresponding files (see the short sketch below). 66 | 67 | If you are using a dataset other than the default uspto_50k, you can give your dataset a name which will show up in various file names with the `--template-set-name` flag (this defaults to `uspto_50k`).
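For a custom split, the files listed above are plain NumPy arrays. Below is a minimal sketch of writing them; the example molecules are placeholders, and the exact dtypes (string SMILES arrays, integer template-class indices aligned with the templates dataframe) are assumptions.

```python
# Sketch: write a custom train/valid/test split in the file format produced by process.py.
# SMILES and class indices below are illustrative placeholders, not real data.
import numpy as np

train_smiles = np.array(["CCO", "c1ccccc1Br"])   # product SMILES strings
train_labels = np.array([0, 17])                 # template class index per product

np.save("data/processed/train.input.smiles.npy", train_smiles)
np.save("data/processed/train.labels.classes.npy", train_labels)
# repeat analogously for valid.* and test.* with your own molecules
```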
68 | 69 | This script also prepares certain files for the ASKCOS software package in `data/processed` named `reactions.\[template_set_name\].json.gz` and `historian.\[template_set_name\].json.gz` in addition to `retro.templates.\[template_set_name\].json.gz`. 70 | 71 | ### Step 2 - Training a baseline model 72 | 73 | Command(s) to run: 74 | ``` 75 | python train.py --templates data/processed/retro.templates.uspto_50k.json.gz 76 | ``` 77 | 78 | To train a template relevance baseline model (without pre-training on template applicability), use the `train.py` script. The available command line arguments are as follows: 79 | 80 | * `--train-smiles` (str): Path to training smiles .npy file 81 | * `--valid-smiles` (str): Path to validation smiles .npy file 82 | * `--train-labels` (str): Path to training class labels .npy file 83 | * `--valid-labels` (str): Path to validation class labels .npy file 84 | * `--no-validation` (bool): Option to skip validation if no validation set exists 85 | * `--num-classes` (int): Number of classes in dataset. Either this or the `--templates` flag is required. 86 | * `--templates` (str): Path to JSON file of unique templates generated during processing step above. This or `--num-classes` is required. 87 | * `--fp-length` (int): Length of Morgan fingerprint to use for input to model (default=2048) 88 | * `--fp-radius` (int): Radius of Morgan fingerprint to use for input to model (default=2) 89 | * `--pretrain-weights` (str): Path to model weights file to use to initialize model weights 90 | * `--weight-classes` (bool): Boolean flag whether to weight contribution to loss by template popularity (default=True) 91 | * `--num-hidden` (int): Number of hidden layers to use in neural network (default=1) 92 | * `--hidden-size` (int): Size for hidden layers in neural network (default=1024) 93 | * `--num-highway` (int): Number of highway layers (default=0) 94 | * `--dropout` (float): Dropout to use after each hidden layer during training (default=0.2) 95 | * `--learning-rate` (float): Learning rate to use with Adam optimizer (default=0.001) 96 | * `--activation` (str): Type of activation to use (default='relu') 97 | * `--batch-size` (int): Batch size to use during training (default=512) 98 | * `--epochs` (int): Max number of epochs during training (default=25) 99 | * `--early-stopping` (int): Number of epochs to use as patience for early stopping (default=3) 100 | * `--model-name` (str): Name to give to model, used in various file names (default='template-relevance') 101 | * `--nproc` (int): Number of processors to use for data pre-processing; 0 means do not process in parallel (default=0) 102 | 103 | ### Step 3 - Test a baseline model 104 | 105 | Command(s) to run: 106 | ``` 107 | python test.py --templates data/processed/retro.templates.uspto_50k.json.gz --model-weights training/template-relevance-weights.hdf5 --model-name template-relevance --accuracy --reciprocal-rank 108 | ``` 109 | 110 | This will produce two output files in a new `evaluation/` folder: `template-relevance.accuracy.json` and `template-relevance.recip_rank.json`.
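The evaluation artifacts can be read back into Python for further analysis. The file names follow from the `--model-name` used above; that the metrics are grouped by template popularity is inferred from the evaluation code in `temprel/evaluate/` and is an assumption here.

```python
# Sketch: load the files written by test.py.
import json
import pandas as pd

accuracy = pd.read_json("evaluation/template-relevance.accuracy.json")
with open("evaluation/template-relevance.recip_rank.json") as f:
    recip_rank = json.load(f)

print(accuracy)            # accuracy metrics, grouped by template popularity (assumed layout)
print(recip_rank.keys())   # reciprocal-rank entries per popularity bin (assumed keys)
```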
111 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/calculate_applicabilty.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | from scipy import sparse 6 | from joblib import Parallel, delayed 7 | 8 | from rdkit import rdBase 9 | rdBase.DisableLog('rdApp.*') 10 | 11 | from rdkit import Chem 12 | from rdkit.Chem import AllChem 13 | 14 | def parse_arguments(): 15 | parser = argparse.ArgumentParser(description='Calculate pairwise template applicability between template and molecule sets') 16 | parser.add_argument('--templates', dest='templates', default='data/processed/retro.templates.uspto_50k.json.gz') 17 | parser.add_argument('--train-smiles', dest='train_smiles', default='data/processed/train.input.smiles.npy') 18 | parser.add_argument('--valid-smiles', dest='valid_smiles', default='data/processed/valid.input.smiles.npy') 19 | parser.add_argument('--test-smiles', dest='test_smiles', default='data/processed/test.input.smiles.npy') 20 | parser.add_argument('--output-prefix', dest='output_prefix', default='data/processed/') 21 | return parser.parse_args() 22 | 23 | def smiles_to_mol(smiles): 24 | return [Chem.MolFromSmiles(smi) for smi in smiles] 25 | 26 | def calc_appl(mols, templates): 27 | lil = sparse.lil_matrix((len(mols), len(templates)), dtype=int) 28 | for j, smarts in enumerate(templates): 29 | rxn = AllChem.ReactionFromSmarts(smarts) 30 | for i, mol in enumerate(mols): 31 | try: 32 | res = rxn.RunReactants([mol]) 33 | except KeyboardInterrupt: 34 | print('Interrupted') 35 | raise KeyboardInterrupt 36 | except Exception as e: 37 | print(e) 38 | res = None 39 | if res: 40 | lil[i, j] = 1 41 | return lil.tocsr() 42 | 43 | 44 | def start_timer(comm): 45 | rank = comm.Get_rank() 46 | if rank == 0: 47 | return time.time() 48 | else: 49 | return None 50 | 51 | def print_timing(comm, t0, name=None): 52 | rank = comm.Get_rank() 53 | if rank == 0: 54 | print_str = 'elapsed' 55 | if name: 56 | print_str += ' [{}]'.format(name) 57 | print_str += ': {}'.format(time.time()-t0) 58 | print(print_str) 59 | return time.time() 60 | else: 61 | return None 62 | 63 | 64 | def mpi_appl(comm, smiles, smarts, save_path): 65 | rank = comm.Get_rank() 66 | size = comm.Get_size() 67 | t0 = start_timer(comm) 68 | if rank == 0: smiles = np.array_split(smiles, size) 69 | smiles = comm.scatter(smiles, root=0) 70 | t0 = print_timing(comm, t0, name='scatter') 71 | mols = smiles_to_mol(smiles) 72 | comm.Barrier() 73 | t0 = print_timing(comm, t0, name='convert') 74 | appl = calc_appl(mols, smarts) 75 | comm.Barrier() 76 | t0 = print_timing(comm, t0, name='appl') 77 | appl = comm.gather(appl, root=0) 78 | t0 = print_timing(comm, t0, name='gather') 79 | if rank == 0: 80 | appl = sparse.vstack(appl) 81 | sparse.save_npz(save_path, appl) 82 | t0 = print_timing(comm, t0, name='save') 83 | comm.Barrier() 84 | 85 | if __name__ == '__main__': 86 | from mpi4py import MPI 87 | comm = MPI.COMM_WORLD 88 | rank = comm.Get_rank() 89 | size = comm.Get_size() 90 | 91 | args = parse_arguments() 92 | 93 | t0 = start_timer(comm) 94 | 95 | templates = pd.read_json(args.templates) 96 | templates = templates.sort_values('index') 97 | template_smarts = templates['reaction_smarts'].values 98 | 99 | if rank == 0: 100 | train_smi = np.load(args.train_smiles) 101 | val_smi = np.load(args.valid_smiles) 102 | test_smi = 
np.load(args.test_smiles) 103 | else: 104 | train_smi = None 105 | val_smi = None 106 | test_smi = None 107 | 108 | t0 = print_timing(comm, t0, name='read') 109 | 110 | mpi_appl(comm, test_smi, template_smarts, args.output_prefix+'test.appl_matrix.npz') 111 | mpi_appl(comm, val_smi, template_smarts, args.output_prefix+'valid.appl_matrix.npz') 112 | mpi_appl(comm, train_smi, template_smarts, args.output_prefix+'train.appl_matrix.npz') -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/get_uspto_50k.py: -------------------------------------------------------------------------------- 1 | from temprel.data.download import get_uspto_50k 2 | get_uspto_50k() 3 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/hyperopt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | crun="docker run -it --rm -v ${PWD}:/work -w /work -u $(id -u $USER):$(id -g $USER) registry.gitlab.com/mlpds_mit/askcos/template-relevance:latest" 4 | 5 | for LAYERS in {1..3..1}; 6 | do 7 | for HIDDEN in {1000..3000..1000}; 8 | do 9 | for HIGHWAY in 0 1 3 5; 10 | do 11 | NAME="${LAYERS}x_${HIDDEN}_h${HIGHWAY}" 12 | $crun python train.py --templates data/processed/retro.templates.uspto_50k.json.gz --num-hidden $LAYERS --hidden-size $HIDDEN --num-highway $HIGHWAY --nproc=40 --model-name $NAME > log.$NAME 13 | $crun python train_appl.py --num-hidden $LAYERS --hidden-size $HIDDEN --num-highway $HIGHWAY --nproc=40 --model-name ${NAME}_appl > log.${NAME}_appl 14 | $crun python train.py --pretrain-weights training/${NAME}_appl-weights.hdf5 --templates data/processed/retro.templates.uspto_50k.json.gz --num-hidden $LAYERS --hidden-size $HIDDEN --num-highway $HIGHWAY --nproc=40 --model-name ${NAME}_pretrained > log.${NAME}_pretrain 15 | done 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/process.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import pandas as pd 4 | from temprel.templates.extract import templates_from_reactions, process_for_training, process_for_askcos 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser(description='Process reaction smiles for template relevance training with ASKCOS') 8 | parser.add_argument('--reactions', dest='reactions', default='data/raw/reactions.json.gz') 9 | parser.add_argument('--nproc', dest='nproc', type=int, default=1) 10 | parser.add_argument('--output-prefix', dest='output_prefix', default='data/processed/') 11 | parser.add_argument('--calc-split', dest='calc_split', default='stratified') 12 | parser.add_argument('--template-set-name', dest='template_set_name', default='uspto_50k') 13 | return parser.parse_args() 14 | 15 | def print_time(task_name, t0): 16 | new_t0 = time.time() 17 | print('elapsed {}: {}'.format(task_name, new_t0-t0)) 18 | return new_t0 19 | 20 | 21 | if __name__ == '__main__': 22 | args = parse_arguments() 23 | t0 = time.time() 24 | reactions = pd.read_json(args.reactions) 25 | t0 = print_time('read', t0) 26 | templates = templates_from_reactions(reactions, nproc=args.nproc) 27 | t0 = print_time('extract', t0) 28 | process_for_training(templates, output_prefix=args.output_prefix, calc_split=args.calc_split) 29 | t0 = print_time('featurize', t0) 30 | 
process_for_askcos(templates, template_set_name=args.template_set_name, output_prefix=args.output_prefix) 31 | t0 = print_time('askcos_process', t0) -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/save_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from temprel.models import relevance 4 | import tensorflow as tf 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser(description='Train a Morgan fingerprint teplate relevance network') 8 | parser.add_argument('--num-classes', dest='num_classes', type=int) 9 | parser.add_argument('--templates', dest='templates') 10 | parser.add_argument('--fp-length', dest='fp_length', type=int, default=2048) 11 | parser.add_argument('--fp-radius', dest='fp_radius', type=int, default=2) 12 | parser.add_argument('--num-hidden', dest='num_hidden', type=int, default=1) 13 | parser.add_argument('--hidden-size', dest='hidden_size', type=int, default=1024) 14 | parser.add_argument('--num-highway', dest='num_highway', type=int, default=0) 15 | parser.add_argument('--activation', dest='activation', default='relu') 16 | parser.add_argument('--model-weights', dest='model_weights', default=None) 17 | parser.add_argument('--model-name', dest='model_name', default='template-relevance-appl') 18 | return parser.parse_args() 19 | 20 | if __name__ == '__main__': 21 | args = parse_arguments() 22 | 23 | if not args.num_classes and not args.templates: 24 | raise ValueError('Error: --num-classes or --templates required') 25 | if args.num_classes: 26 | num_classes = args.num_classes 27 | else: 28 | templates = pd.read_json(args.templates) 29 | num_classes = len(templates) 30 | 31 | model = relevance( 32 | input_shape=(args.fp_length), output_shape=num_classes, num_hidden=args.num_hidden, 33 | hidden_size=args.hidden_size, activation=args.activation, num_highway=args.num_highway, 34 | compile_model=False 35 | ) 36 | model.load_weights(args.model_weights) 37 | model.add(tf.keras.layers.Activation('softmax')) 38 | model.save(args.model_name, save_format='tf') 39 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | from scipy import sparse 7 | from joblib import Parallel, delayed 8 | 9 | from temprel.models import relevance 10 | from temprel.rdkit import smiles_to_fingerprint, templates_from_smarts_list 11 | from temprel.evaluate.diversity import diversity 12 | from temprel.evaluate.accuracy import accuracy_by_popularity 13 | from temprel.evaluate.topk_appl import topk_appl_recall_and_precision 14 | from temprel.evaluate.reciprocal_rank import reciprocal_rank_by_popularity 15 | 16 | def parse_arguments(): 17 | parser = argparse.ArgumentParser(description='Test a Morgan fingerprint teplate relevance network') 18 | parser.add_argument('--test-smiles', dest='test_smiles', default='data/processed/test.input.smiles.npy') 19 | parser.add_argument('--test-labels', dest='test_labels', default='data/processed/test.labels.classes.npy') 20 | parser.add_argument('--test-appl-labels', dest='test_appl_labels', default='data/processed/test.appl_matrix.npz') 21 | parser.add_argument('--train-labels', dest='train_labels', 
default='data/processed/train.labels.classes.npy') 22 | parser.add_argument('--templates', dest='templates_path') 23 | parser.add_argument('--topk', dest='topk', type=int, default=100) 24 | parser.add_argument('--fp-length', dest='fp_length', type=int, default=2048) 25 | parser.add_argument('--fp-radius', dest='fp_radius', type=int, default=2) 26 | parser.add_argument('--num-hidden', dest='num_hidden', type=int, default=1) 27 | parser.add_argument('--hidden-size', dest='hidden_size', type=int, default=1024) 28 | parser.add_argument('--num-highway', dest='num_highway', type=int, default=0) 29 | parser.add_argument('--activation', dest='activation', default='relu') 30 | parser.add_argument('--model-weights', dest='model_weights', default=None) 31 | parser.add_argument('--batch-size', dest='batch_size', type=int, default=512) 32 | parser.add_argument('--model-name', dest='model_name', default='baseline') 33 | parser.add_argument('--accuracy', dest='accuracy', action='store_true', default=False) 34 | parser.add_argument('--reciprocal-rank', dest='rr', action='store_true', default=False) 35 | parser.add_argument('--topk_appl', dest='topk_appl', action='store_true', default=False) 36 | parser.add_argument('--diversity', dest='diversity', action='store_true', default=False) 37 | parser.add_argument('--nproc', dest='nproc', type=int, default=1) 38 | return parser.parse_args() 39 | 40 | if __name__ == '__main__': 41 | if not os.path.exists('evaluation'): 42 | os.makedirs('evaluation') 43 | args = parse_arguments() 44 | test_smiles = np.load(args.test_smiles) 45 | test_labels = np.load(args.test_labels) 46 | train_labels = np.load(args.train_labels) 47 | if os.path.exists(args.test_appl_labels): 48 | test_appl_labels = sparse.load_npz(args.test_appl_labels) 49 | 50 | test_fps = Parallel(n_jobs=args.nproc, verbose=1)( 51 | delayed(smiles_to_fingerprint)(smi, length=args.fp_length, radius=args.fp_radius) for smi in test_smiles 52 | ) 53 | test_fps = np.array(test_fps) 54 | 55 | templates = pd.read_json(args.templates_path) 56 | 57 | model = relevance( 58 | input_shape=(args.fp_length), output_shape=len(templates), num_hidden=args.num_hidden, 59 | hidden_size=args.hidden_size, activation=args.activation, num_highway=args.num_highway 60 | ) 61 | model.load_weights(args.model_weights) 62 | 63 | if args.accuracy: 64 | acc = accuracy_by_popularity(model, test_fps, test_labels, train_labels, batch_size=args.batch_size) 65 | pd.DataFrame.from_dict(acc, orient='index', columns=model.metrics_names).to_json('evaluation/{}.accuracy.json'.format(args.model_name)) 66 | 67 | if args.rr: 68 | rr = reciprocal_rank_by_popularity(model, test_fps, test_labels, train_labels, batch_size=args.batch_size) 69 | with open('evaluation/{}.recip_rank.json'.format(args.model_name), 'w') as f: 70 | json.dump(rr, f) 71 | 72 | if args.topk_appl: 73 | topk_appl_recall, topk_appl_precision = topk_appl_recall_and_precision(model, test_fps, test_appl_labels) 74 | with open('evaluation/{}.appl_recall.json'.format(args.model_name), 'w') as f: 75 | json.dump(topk_appl_recall, f) 76 | with open('evaluation/{}.appl_precision.json'.format(args.model_name), 'w') as f: 77 | json.dump(topk_appl_precision, f) 78 | 79 | if args.diversity: 80 | templates_rxn = np.array(templates_from_smarts_list(templates['reaction_smarts'], nproc=args.nproc)) 81 | div = diversity(model, test_smiles, templates_rxn, topk=args.topk, fp_length=args.fp_length, fp_radius=args.fp_radius, nproc=args.nproc) 82 | 
np.save('evaluation/{}.diversity.npy'.format(args.model_name), div) -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | from temprel.models import relevance 6 | from temprel.data.loaders import fingerprint_training_dataset 7 | import tensorflow as tf 8 | import sklearn 9 | 10 | def parse_arguments(): 11 | parser = argparse.ArgumentParser(description='Train a Morgan fingerprint teplate relevance network') 12 | parser.add_argument('--train-smiles', dest='train_smiles', default='data/processed/train.input.smiles.npy') 13 | parser.add_argument('--valid-smiles', dest='valid_smiles', default='data/processed/valid.input.smiles.npy') 14 | parser.add_argument('--train-labels', dest='train_labels', default='data/processed/train.labels.classes.npy') 15 | parser.add_argument('--valid-labels', dest='valid_labels', default='data/processed/valid.labels.classes.npy') 16 | parser.add_argument('--no-validation', dest='no_validation', action='store_true', default=False) 17 | parser.add_argument('--num-classes', dest='num_classes', type=int) 18 | parser.add_argument('--templates', dest='templates', default='data/processed/retro.templates.json.gz') 19 | parser.add_argument('--fp-length', dest='fp_length', type=int, default=2048) 20 | parser.add_argument('--fp-radius', dest='fp_radius', type=int, default=2) 21 | parser.add_argument('--precompute-fps', dest='precompute_fps', action='store_true', default=True) 22 | parser.add_argument('--pretrain-weights', dest='pretrain_weights', default=None) 23 | parser.add_argument('--weight-classes', dest='weight_classes', action='store_true', default=True) 24 | parser.add_argument('--num-hidden', dest='num_hidden', type=int, default=1) 25 | parser.add_argument('--hidden-size', dest='hidden_size', type=int, default=1024) 26 | parser.add_argument('--num-highway', dest='num_highway', type=int, default=0) 27 | parser.add_argument('--dropout', dest='dropout', type=float, default=0.2) 28 | parser.add_argument('--learning-rate', dest='learning_rate', type=float, default=0.001) 29 | parser.add_argument('--activation', dest='activation', default='relu') 30 | parser.add_argument('--batch-size', dest='batch_size', type=int, default=512) 31 | parser.add_argument('--clipnorm', dest='clipnorm', action='store_true', default=None) 32 | parser.add_argument('--epochs', dest='epochs', type=int, default=25) 33 | parser.add_argument('--early-stopping', dest='early_stopping', type=int, default=3) 34 | parser.add_argument('--model-name', dest='model_name', default='template-relevance') 35 | parser.add_argument('--nproc', dest='nproc', type=int, default=0) 36 | return parser.parse_args() 37 | 38 | def try_read_npy(filename): 39 | if not os.path.exists(filename): 40 | raise ValueError('File does not exist: {}'.format(filename)) 41 | return np.load(filename) 42 | 43 | def shuffle_arrays(a, b): 44 | p = np.random.permutation(len(a)) 45 | return a[p], b[p] 46 | 47 | if __name__ == '__main__': 48 | args = parse_arguments() 49 | if not args.train_smiles or not args.train_labels: 50 | raise ValueError('Error: training data (--train-smiles and --train-labels) required') 51 | train_smiles = try_read_npy(args.train_smiles) 52 | train_labels = try_read_npy(args.train_labels) 53 | train_smiles, train_labels = shuffle_arrays(train_smiles, train_labels) 54 
| if not args.no_validation: 55 | valid_smiles = try_read_npy(args.valid_smiles) 56 | valid_labels = try_read_npy(args.valid_labels) 57 | valid_smiles, valid_labels = shuffle_arrays(valid_smiles, valid_labels) 58 | if not args.num_classes and not args.templates: 59 | raise ValueError('Error: --num-classes or --templates required') 60 | if args.num_classes: 61 | num_classes = args.num_classes 62 | else: 63 | templates = pd.read_json(args.templates) 64 | num_classes = len(templates) 65 | 66 | train_ds = fingerprint_training_dataset( 67 | train_smiles, train_labels, batch_size=args.batch_size, train=True, 68 | fp_length=args.fp_length, fp_radius=args.fp_radius, nproc=40, precompute=args.precompute_fps 69 | ) 70 | train_steps = np.ceil(len(train_smiles)/args.batch_size).astype(int) 71 | 72 | if not args.no_validation: 73 | valid_ds = fingerprint_training_dataset( 74 | valid_smiles, valid_labels, batch_size=args.batch_size, train=False, 75 | fp_length=args.fp_length, fp_radius=args.fp_radius, nproc=40, precompute=args.precompute_fps 76 | ) 77 | valid_steps = np.ceil(len(valid_smiles)/args.batch_size).astype(int) 78 | else: 79 | valid_ds = None 80 | valid_steps = None 81 | 82 | model = relevance( 83 | input_shape=(args.fp_length,), 84 | output_shape=num_classes, 85 | num_hidden=args.num_hidden, 86 | hidden_size=args.hidden_size, 87 | num_highway=args.num_highway, 88 | dropout=args.dropout, 89 | learning_rate=args.learning_rate, 90 | activation=args.activation, 91 | clipnorm=args.clipnorm 92 | ) 93 | if args.pretrain_weights: 94 | model.load_weights(args.pretrain_weights) 95 | 96 | if not os.path.exists('training'): 97 | os.makedirs('training') 98 | model_output = 'training/{}-weights.hdf5'.format(args.model_name) 99 | history_output = 'training/{}-history.json'.format(args.model_name) 100 | 101 | callbacks = [] 102 | if args.early_stopping: 103 | callbacks.append( 104 | tf.keras.callbacks.EarlyStopping( 105 | patience=args.early_stopping, 106 | restore_best_weights=True 107 | ) 108 | ) 109 | callbacks.append( 110 | tf.keras.callbacks.ModelCheckpoint( 111 | model_output, monitor='val_loss', save_weights_only=True 112 | ) 113 | ) 114 | 115 | if args.weight_classes: 116 | class_weight = sklearn.utils.class_weight.compute_class_weight( 117 | 'balanced', np.unique(train_labels), train_labels 118 | ) 119 | else: 120 | class_weight = sklearn.utils.class_weight.compute_class_weight( 121 | None, np.unique(train_labels), train_labels 122 | ) 123 | 124 | if args.nproc: 125 | multiproc = True 126 | nproc = args.nproc 127 | else: 128 | multiproc = False 129 | nproc = None 130 | 131 | 132 | history = model.fit( 133 | train_ds, epochs=args.epochs, steps_per_epoch=train_steps, 134 | validation_data=valid_ds, validation_steps=valid_steps, 135 | callbacks=callbacks, class_weight=class_weight, 136 | use_multiprocessing=multiproc, workers=nproc 137 | ) 138 | 139 | pd.DataFrame(history.history).to_json(history_output) -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/bin/train_appl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | from scipy import sparse 6 | from temprel.models import applicability 7 | from temprel.data.loaders import fingerprint_training_dataset 8 | import tensorflow as tf 9 | import sklearn 10 | 11 | def parse_arguments(): 12 | parser = argparse.ArgumentParser(description='Train a Morgan fingerprint 
teplate relevance network') 13 | parser.add_argument('--train-smiles', dest='train_smiles', default='data/processed/train.input.smiles.npy') 14 | parser.add_argument('--valid-smiles', dest='valid_smiles', default='data/processed/valid.input.smiles.npy') 15 | parser.add_argument('--train-labels', dest='train_labels', default='data/processed/train.appl_matrix.npz') 16 | parser.add_argument('--valid-labels', dest='valid_labels', default='data/processed/valid.appl_matrix.npz') 17 | parser.add_argument('--no-validation', dest='no_validation', action='store_true', default=False) 18 | parser.add_argument('--fp-length', dest='fp_length', type=int, default=2048) 19 | parser.add_argument('--fp-radius', dest='fp_radius', type=int, default=2) 20 | parser.add_argument('--precompute-fps', dest='precompute_fps', action='store_true', default=True) 21 | parser.add_argument('--weight-classes', dest='weight_classes', action='store_true', default=True) 22 | parser.add_argument('--num-hidden', dest='num_hidden', type=int, default=1) 23 | parser.add_argument('--hidden-size', dest='hidden_size', type=int, default=1024) 24 | parser.add_argument('--num-highway', dest='num_highway', type=int, default=0) 25 | parser.add_argument('--dropout', dest='dropout', type=float, default=0.2) 26 | parser.add_argument('--learning-rate', dest='learning_rate', type=float, default=0.001) 27 | parser.add_argument('--activation', dest='activation', default='relu') 28 | parser.add_argument('--batch-size', dest='batch_size', type=int, default=128) 29 | parser.add_argument('--clipnorm', dest='clipnorm', action='store_true', default=None) 30 | parser.add_argument('--epochs', dest='epochs', type=int, default=25) 31 | parser.add_argument('--early-stopping', dest='early_stopping', type=int, default=3) 32 | parser.add_argument('--model-name', dest='model_name', default='template-relevance-appl') 33 | parser.add_argument('--nproc', dest='nproc', type=int, default=0) 34 | return parser.parse_args() 35 | 36 | def try_read_npz(filename): 37 | if not os.path.exists(filename): 38 | raise ValueError('File does not exist: {}'.format(filename)) 39 | return sparse.load_npz(filename) 40 | 41 | def try_read_npy(filename): 42 | if not os.path.exists(filename): 43 | raise ValueError('File does not exist: {}'.format(filename)) 44 | return np.load(filename) 45 | 46 | def shuffle_arrays(a, b): 47 | p = np.random.permutation(len(a)) 48 | return a[p], b[p] 49 | 50 | if __name__ == '__main__': 51 | args = parse_arguments() 52 | if not args.train_smiles or not args.train_labels: 53 | raise ValueError('Error: training data (--train-smiles and --train-labels) required') 54 | train_smiles = try_read_npy(args.train_smiles) 55 | train_labels = try_read_npz(args.train_labels) 56 | train_smiles, train_labels = shuffle_arrays(train_smiles, train_labels) 57 | if not args.no_validation: 58 | valid_smiles = try_read_npy(args.valid_smiles) 59 | valid_labels = try_read_npz(args.valid_labels) 60 | valid_smiles, valid_labels = shuffle_arrays(valid_smiles, valid_labels) 61 | 62 | num_classes = train_labels.shape[1] 63 | 64 | if args.weight_classes: 65 | template_example_counts = train_labels.sum(axis=0).A.reshape(-1) 66 | template_example_counts[np.argwhere(template_example_counts==0).reshape(-1)] = 1 67 | template_class_weights = template_example_counts.sum()/template_example_counts.shape[0]/template_example_counts 68 | else: 69 | template_class_weights = {n: 1. 
for n in range(train_labels.shape[1])} 70 | 71 | train_ds = fingerprint_training_dataset( 72 | train_smiles, train_labels, batch_size=args.batch_size, train=True, 73 | fp_length=args.fp_length, fp_radius=args.fp_radius, nproc=40, 74 | sparse_labels=True, cache=False, precompute=args.precompute_fps 75 | ) 76 | train_steps = np.ceil(len(train_smiles)/args.batch_size).astype(int) 77 | 78 | if not args.no_validation: 79 | valid_ds = fingerprint_training_dataset( 80 | valid_smiles, valid_labels, batch_size=args.batch_size, train=False, 81 | fp_length=args.fp_length, fp_radius=args.fp_radius, nproc=40, 82 | sparse_labels=True, cache=False, precompute=args.precompute_fps 83 | ) 84 | valid_steps = np.ceil(len(valid_smiles)/args.batch_size).astype(int) 85 | else: 86 | valid_ds = None 87 | valid_steps = None 88 | 89 | model = applicability( 90 | input_shape=(args.fp_length,), 91 | output_shape=num_classes, 92 | num_hidden=args.num_hidden, 93 | hidden_size=args.hidden_size, 94 | num_highway=args.num_highway, 95 | dropout=args.dropout, 96 | learning_rate=args.learning_rate, 97 | activation=args.activation, 98 | clipnorm=args.clipnorm 99 | ) 100 | 101 | if not os.path.exists('training'): 102 | os.makedirs('training') 103 | model_output = 'training/{}-weights.hdf5'.format(args.model_name) 104 | history_output = 'training/{}-history.json'.format(args.model_name) 105 | 106 | callbacks = [] 107 | if args.early_stopping: 108 | callbacks.append( 109 | tf.keras.callbacks.EarlyStopping( 110 | patience=args.early_stopping, 111 | restore_best_weights=True 112 | ) 113 | ) 114 | callbacks.append( 115 | tf.keras.callbacks.ModelCheckpoint( 116 | model_output, monitor='val_loss', save_weights_only=True 117 | ) 118 | ) 119 | 120 | if args.nproc: 121 | multiproc = True 122 | nproc = args.nproc 123 | else: 124 | multiproc = False 125 | nproc = None 126 | 127 | history = model.fit( 128 | train_ds, epochs=args.epochs, steps_per_epoch=train_steps, 129 | validation_data=valid_ds, validation_steps=valid_steps, 130 | callbacks=callbacks, class_weight=template_class_weights, 131 | use_multiprocessing=multiproc, workers=nproc 132 | ) 133 | 134 | pd.DataFrame(history.history).to_json(history_output) -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/requirements.txt: -------------------------------------------------------------------------------- 1 | hdbscan 2 | joblib 3 | jupyter 4 | matplotlib 5 | numpy 6 | pandas 7 | pillow 8 | rdchiral 9 | requests 10 | scikit-learn 11 | scipy 12 | seaborn 13 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="temprel", 8 | version="1.0", 9 | author="Mike Fortunato", 10 | author_email="mef231@gmail.com", 11 | description="Reaction template relevance training pipeline code", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://gitlab.com/mefortunato/template-relevance/", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.5', 22 | ) 
-------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/temprel-fortunato/template-relevance-master/temprel/__init__.py -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/temprel-fortunato/template-relevance-master/temprel/data/__init__.py -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/data/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import pickle 4 | import requests 5 | import subprocess 6 | import pandas as pd 7 | 8 | def download_file(url, output_path=None): 9 | if not output_path: 10 | output_path = url.split('/')[-1] 11 | with requests.get(url, stream=True) as r: 12 | r.raise_for_status() 13 | with open(output_path, 'wb') as f: 14 | for chunk in r.iter_content(chunk_size=8192): 15 | if chunk: 16 | f.write(chunk) 17 | 18 | def get_uspto_480k(): 19 | if not os.path.exists('data'): 20 | os.mkdir('data') 21 | if not os.path.exists('data/raw'): 22 | os.mkdir('data/raw') 23 | os.chdir('data/raw') 24 | download_file( 25 | 'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/train.txt.tar.gz', 26 | 'train.txt.tar.gz' 27 | ) 28 | subprocess.run(['tar', 'zxf', 'train.txt.tar.gz']) 29 | download_file( 30 | 'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/valid.txt.tar.gz', 31 | 'valid.txt.tar.gz' 32 | ) 33 | subprocess.run(['tar', 'zxf', 'valid.txt.tar.gz']) 34 | download_file( 35 | 'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/test.txt.tar.gz', 36 | 'test.txt.tar.gz' 37 | ) 38 | subprocess.run(['tar', 'zxf', 'test.txt.tar.gz']) 39 | 40 | with open('train.txt') as f: 41 | train = [ 42 | { 43 | 'reaction_smiles': line.strip(), 44 | 'split': 'train' 45 | } 46 | for line in f.readlines() 47 | ] 48 | with open('valid.txt') as f: 49 | valid = [ 50 | { 51 | 'reaction_smiles': line.strip(), 52 | 'split': 'valid' 53 | } 54 | for line in f.readlines() 55 | ] 56 | with open('test.txt') as f: 57 | test = [ 58 | { 59 | 'reaction_smiles': line.strip(), 60 | 'split': 'test' 61 | } 62 | for line in f.readlines() 63 | ] 64 | 65 | pd.concat([ 66 | pd.DataFrame(train), 67 | pd.DataFrame(valid), 68 | pd.DataFrame(test) 69 | ]).reset_index().to_json('reactions.json.gz', compression='gzip') 70 | 71 | def get_uspto_50k(): 72 | ''' 73 | get SI from: 74 | Nadine Schneider; Daniel M. Lowe; Roger A. Sayle; Gregory A. Landrum. J. Chem. Inf. 
Model.201555139-53 75 | ''' 76 | if not os.path.exists('data'): 77 | os.mkdir('data') 78 | if not os.path.exists('data/raw'): 79 | os.mkdir('data/raw') 80 | os.chdir('data/raw') 81 | subprocess.run(['wget', 'https://pubs.acs.org/doi/suppl/10.1021/ci5006614/suppl_file/ci5006614_si_002.zip']) 82 | subprocess.run(['unzip', '-o', 'ci5006614_si_002.zip']) 83 | data = [] 84 | with gzip.open('ChemReactionClassification/data/training_test_set_patent_data.pkl.gz') as f: 85 | while True: 86 | try: 87 | data.append(pickle.load(f)) 88 | except EOFError: 89 | break 90 | reaction_smiles = [d[0] for d in data] 91 | reaction_reference = [d[1] for d in data] 92 | reaction_class = [d[2] for d in data] 93 | df = pd.DataFrame() 94 | df['reaction_smiles'] = reaction_smiles 95 | df['reaction_reference'] = reaction_reference 96 | df['reaction_class'] = reaction_class 97 | df.to_json('reactions.json.gz', compression='gzip') 98 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/data/loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from joblib import Parallel, delayed 4 | from ..rdkit import smiles_to_fingerprint 5 | 6 | def fingerprint_training_dataset( 7 | smiles, labels, batch_size=256, train=True, 8 | fp_length=2048, fp_radius=2, fp_use_features=False, fp_use_chirality=True, 9 | sparse_labels=False, shuffle_buffer=1024, nproc=8, cache=True, precompute=False 10 | ): 11 | smiles_ds = fingerprint_dataset_from_smiles(smiles, fp_length, fp_radius, fp_use_features, fp_use_chirality, nproc, precompute) 12 | labels_ds = labels_dataset(labels, sparse_labels) 13 | ds = tf.data.Dataset.zip((smiles_ds, labels_ds)) 14 | ds = ds.shuffle(shuffle_buffer).batch(batch_size) 15 | if train: 16 | ds = ds.repeat() 17 | if cache: 18 | ds = ds.cache() 19 | ds = ds.prefetch(buffer_size=batch_size*3) 20 | return ds 21 | 22 | def fingerprint_dataset_from_smiles(smiles, length, radius, useFeatures, useChirality, nproc=8, precompute=False): 23 | def smiles_tensor_to_fp(smi, length, radius, useFeatures, useChirality): 24 | smi = smi.numpy().decode('utf-8') 25 | length = int(length.numpy()) 26 | radius = int(radius.numpy()) 27 | useFeatures = bool(useFeatures.numpy()) 28 | useChirality = bool(useChirality.numpy()) 29 | fp_bit = smiles_to_fingerprint(smi, length, radius, useFeatures, useChirality) 30 | return np.array(fp_bit) 31 | def parse_smiles(smi): 32 | output = tf.py_function( 33 | smiles_tensor_to_fp, 34 | inp=[smi, length, radius, useFeatures, useChirality], 35 | Tout=tf.float32 36 | ) 37 | output.set_shape((length,)) 38 | return output 39 | if not precompute: 40 | ds = tf.data.Dataset.from_tensor_slices(smiles) 41 | ds = ds.map(map_func=parse_smiles, num_parallel_calls=nproc) 42 | else: 43 | fps = Parallel(n_jobs=nproc, verbose=1)( 44 | delayed(smiles_to_fingerprint)(smi, length, radius, useFeatures, useChirality) for smi in smiles 45 | ) 46 | fps = np.array(fps) 47 | ds = tf.data.Dataset.from_tensor_slices(fps) 48 | return ds 49 | 50 | def labels_dataset(labels, sparse=False): 51 | if not sparse: 52 | return tf.data.Dataset.from_tensor_slices(labels) 53 | coo = labels.tocoo() 54 | indices = np.array([coo.row, coo.col]).T 55 | labels = tf.SparseTensor(indices, coo.data, coo.shape) 56 | labels_ds = tf.data.Dataset.from_tensor_slices(labels) 57 | labels_ds = labels_ds.map(map_func=tf.sparse.to_dense) 58 | return labels_ds 
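Note: when precompute=True, fingerprint_training_dataset above skips the tf.py_function map and stacks precomputed Morgan fingerprints into the tf.data pipeline. A self-contained sketch of that path (assumes rdkit and tensorflow are installed; labels are handled densely for brevity, and the SMILES and labels are toy values, not taken from the repository data):

import numpy as np
import tensorflow as tf
from rdkit import Chem
from rdkit.Chem import AllChem

def morgan_fp(smi, length=2048, radius=2):
    # same fingerprint type as temprel.rdkit.smiles_to_fingerprint, without the feature/chirality options
    mol = Chem.MolFromSmiles(smi)
    return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=length), dtype=np.float32)

smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
labels = np.array([0, 1, 0])

fps = np.stack([morgan_fp(s) for s in smiles])        # precompute step
ds = tf.data.Dataset.from_tensor_slices((fps, labels))
ds = ds.shuffle(8).batch(2)

for x, y in ds:
    print(x.shape, y.numpy())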
-------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/evaluate/accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def accuracy_by_popularity(model, test_fps, test_labels, train_labels, batch_size=512): 4 | class_counts = np.bincount(train_labels) 5 | cls_ind = {} 6 | for n in range(0, 11): 7 | cls_ind[n] = np.isin(test_labels, np.argwhere(class_counts==n).reshape(-1)) 8 | cls_ind['med'] = np.isin(test_labels, np.argwhere((class_counts>10) & (class_counts<=50)).reshape(-1)) 9 | cls_ind['high'] = np.isin(test_labels, np.argwhere(class_counts>50).reshape(-1)) 10 | performance = {} 11 | performance['all'] = model.evaluate(test_fps, test_labels, batch_size=batch_size) 12 | for n in range(0, 11): 13 | if len(test_fps[cls_ind[n]]) == 0: 14 | continue 15 | performance[n] = model.evaluate(test_fps[cls_ind[n]], test_labels[cls_ind[n]], batch_size=batch_size) 16 | performance['med'] = model.evaluate(test_fps[cls_ind['med']], test_labels[cls_ind['med']], batch_size=batch_size) 17 | performance['high'] = model.evaluate(test_fps[cls_ind['high']], test_labels[cls_ind['high']], batch_size=batch_size) 18 | return performance 19 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/evaluate/diversity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..rdkit import smiles_list_to_fingerprints, precursors_from_templates 3 | 4 | def tanimoto(fp1, fp2): 5 | a = fp1.sum() 6 | b = fp2.sum() 7 | c = float((fp1&fp2).sum()) 8 | return c/(a+b-c) 9 | 10 | def pairwise_tanimoto(arr1, arr2, metric=tanimoto): 11 | if arr1.size == 0: 12 | return np.array([[]]) 13 | return cdist(arr1, arr2, metric=metric) 14 | 15 | def diversity_from_smiles_list(smiles_list, fp_length=2048, fp_radius=2, nproc=1): 16 | fps = smiles_list_to_fingerprints(smiles_list, fp_length=fp_length, fp_radius=fp_radius, nproc=nproc) 17 | similarity = pairwise_tanimoto(fps, fps) 18 | diversity = 1 - similarity 19 | np.fill_diagonal(diversity, np.nan) 20 | return diversity 21 | 22 | def diversity(model, test_smiles, templates, topk=100, fp_length=2048, fp_radius=2, nproc=1): 23 | div = [] 24 | for smi in test_smiles: 25 | fp = smiles_to_fingerprint(smi, length=fp_length, radius=fp_radius) 26 | pred = model.predict(fp.reshape(1, -1)).reshape(-1) 27 | ind = np.argsort(-pred)[:topk] 28 | precursors = precursors_from_templates(smi, templates[ind], nproc=nproc) 29 | div.append(diversity_from_smiles_list(precursors, nproc=nproc)) 30 | return div -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/evaluate/reciprocal_rank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def reciprocal_rank(model, test_fps, test_labels): 4 | pred = model.predict(test_fps) 5 | ind = np.argsort(-pred) 6 | ranks = 1 + np.argwhere(test_labels.reshape(-1, 1) == ind)[:, 1] 7 | return 1/ranks 8 | 9 | def reciprocal_rank_by_popularity(model, test_fps, test_labels, train_labels, batch_size=512): 10 | class_counts = np.bincount(train_labels) 11 | cls_ind = {} 12 | for n in range(0, 11): 13 | cls_ind[n] = np.isin(test_labels, np.argwhere(class_counts==n).reshape(-1)) 14 | cls_ind['med'] = np.isin(test_labels, 
np.argwhere((class_counts>10) & (class_counts<=50)).reshape(-1)) 15 | cls_ind['high'] = np.isin(test_labels, np.argwhere(class_counts>50).reshape(-1)) 16 | rr = {} 17 | rr['all'] = reciprocal_rank(model, test_fps, test_labels).tolist() 18 | 19 | for n in range(0, 11): 20 | rr[n] = reciprocal_rank(model, test_fps[cls_ind[n]], test_labels[cls_ind[n]]).tolist() 21 | rr['med'] = reciprocal_rank(model, test_fps[cls_ind['med']], test_labels[cls_ind['med']]).tolist() 22 | rr['high'] = reciprocal_rank(model, test_fps[cls_ind['high']], test_labels[cls_ind['high']]).tolist() 23 | return rr 24 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/evaluate/roc.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_curve, auc, roc_auc_score 2 | from sklearn.preprocessing import label_binarize 3 | 4 | def roc_curve(model, test_fps, test_appl_labels, class_idx=0): 5 | y_score = model.predict(test_fps) 6 | fpr, tpr, _ = roc_curve(test_appl_labels[:, class_idx].A.reshape(-1), y_score[:, class_idx]) 7 | roc_auc = auc(fpr, tpr) 8 | return fpr, tpr, roc_auc 9 | 10 | def roc_auc(model, test_fps, test_appl_labels, multi_class='ovr', average='weighted'): 11 | y_score = model.predict(test_fps) 12 | mask = test_appl_labels.A.sum(axis=0)>=2 13 | return roc_auc_score(test_appl_labels[:, mask], y_score[:, mask], multi_class=multi_class, average=average) 14 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/evaluate/template_popularity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def template_popularity(model, test_fps, train_labels, test_appl_labels, topk=100, batch_size=512): 4 | class_counts = np.bincount(train_labels) 5 | pred = model.predict(test_fps) 6 | ind = np.argsort(-pred)[:, :topk] 7 | appl_ind = [i[np.isin(i, np.argwhere(appl_label.A))] for i, appl_label in zip(ind, test_appl_labels)] 8 | appl_pop = np.array([class_counts[i] for i in appl_ind]) 9 | max_length = max([len(a) for a in appl_pop]) 10 | appl_pop = np.vstack([np.pad(pop.astype(float), pad_width=(0, max_length-len(pop)), mode='constant', constant_values=np.NaN) for pop in appl_pop]) 11 | return appl_pop -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/evaluate/topk_appl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def appl_recall_and_precision(model, test_fps, test_appl_labels, k=10): 4 | row, col = test_appl_labels.nonzero() 5 | scores = model.predict(test_fps) 6 | indices = np.argsort(-scores, axis=1) 7 | appl_recall = [] 8 | appl_precision = [] 9 | for n in range(len(indices)): 10 | n_appl = np.isin(indices[n][:k], col[row==n]).sum() 11 | total_appl = col[row==n].size 12 | total_reccommended = k 13 | appl_recall.append(n_appl/total_appl) 14 | appl_precision.append(n_appl/total_reccommended) 15 | return appl_recall, appl_precision 16 | 17 | def topk_appl_recall_and_precision(model, test_fps, test_appl_labels, k=[1, 5, 10, 25, 50, 100, 250, 1000]): 18 | recall = {} 19 | precision = {} 20 | for kval in k: 21 | r, p = appl_recall_and_precision(model, test_fps, test_appl_labels, k=kval) 22 | recall[kval] = r 23 | precision[kval] = p 24 | return recall, precision 25 | 
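Note: appl_recall_and_precision above measures, for each product, how many of its truly applicable templates appear among the top-k model scores. A minimal worked example with hard-coded scores standing in for a model (all numbers invented for illustration):

import numpy as np
from scipy import sparse

# 2 products x 5 templates; appl[i, j] = 1 means template j applies to product i.
appl = sparse.csr_matrix(np.array([[1, 0, 1, 0, 0],
                                   [0, 1, 0, 0, 1]]))
scores = np.array([[0.9, 0.1, 0.2, 0.6, 0.0],   # stand-in for model.predict(test_fps)
                   [0.1, 0.8, 0.3, 0.2, 0.7]])
k = 2
row, col = appl.nonzero()
topk = np.argsort(-scores, axis=1)[:, :k]
for n in range(scores.shape[0]):
    true_templates = col[row == n]
    hits = np.isin(topk[n], true_templates).sum()
    print(f'product {n}: recall@{k} = {hits / true_templates.size:.2f}, '
          f'precision@{k} = {hits / k:.2f}')
# product 0: recall@2 = 0.50, precision@2 = 0.50
# product 1: recall@2 = 1.00, precision@2 = 1.00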
-------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import relevance, applicability -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/models/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import backend as K 3 | from tensorflow.keras.layers import Layer, Dense 4 | 5 | class Highway(Layer): 6 | def __init__(self, activation='relu', transform_activation='sigmoid', **kwargs): 7 | self.activation = activation 8 | self.transform_activation = transform_activation 9 | super(Highway, self).__init__(**kwargs) 10 | 11 | def build(self, input_shape): 12 | self.dense = Dense(units=input_shape[-1], activation=self.activation, bias_initializer='zeros') 13 | self.dense_gate = Dense(units=input_shape[-1], activation=self.transform_activation, bias_initializer='zeros') 14 | self.input_dim = input_shape[-1] 15 | super(Highway, self).build(input_shape) 16 | 17 | def call(self, x): 18 | transform = self.dense(x) 19 | transform_gate = self.dense_gate(x) 20 | carry_gate = K.ones_like(transform_gate) - transform_gate 21 | output = transform*transform_gate + x*carry_gate 22 | return output -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/models/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def sparse_categorical_crossentropy_from_logits(labels, logits): 4 | return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True) -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/models/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from functools import partial 3 | 4 | def top_k(k=1): 5 | partial_fn = partial(tf.keras.metrics.sparse_top_k_categorical_accuracy, k=k) 6 | partial_fn.__name__ = 'top_{}'.format(k) 7 | return partial_fn -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/models/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .layers import Highway 3 | from .metrics import top_k 4 | from .losses import sparse_categorical_crossentropy_from_logits 5 | 6 | def build_model( 7 | input_shape, output_shape, num_hidden, hidden_size, num_highway, 8 | activation='relu', output_activation=None, dropout=0.0, clipnorm=None, 9 | optimizer=None, learning_rate=0.001, 10 | compile_model=True, loss=None, metrics=None 11 | ): 12 | model = tf.keras.models.Sequential() 13 | model.add(tf.keras.layers.Input(input_shape)) 14 | for _ in range(num_hidden): 15 | model.add(tf.keras.layers.Dense(hidden_size, activation=activation)) 16 | if dropout: 17 | model.add(tf.keras.layers.Dropout(dropout)) 18 | for _ in range(num_highway): 19 | model.add(Highway()) 20 | if dropout: 21 | model.add(tf.keras.layers.Dropout(dropout)) 22 | model.add(tf.keras.layers.Dense(output_shape, activation=output_activation)) 23 | if optimizer is None or 
optimizer == 'adam': 24 | optimizer = tf.keras.optimizers.Adam(learning_rate) 25 | if clipnorm is not None: 26 | optimizer.clipnorm = clipnorm 27 | if compile_model: 28 | model.compile( 29 | optimizer=optimizer, 30 | loss=loss, 31 | metrics=metrics 32 | ) 33 | return model 34 | 35 | def relevance(**kwargs): 36 | loss = sparse_categorical_crossentropy_from_logits 37 | metrics = [ 38 | top_k(k=1), 39 | top_k(k=10), 40 | top_k(k=50), 41 | top_k(k=100), 42 | ] 43 | options = { 44 | 'loss': loss, 45 | 'metrics': metrics 46 | } 47 | options.update(kwargs) 48 | return build_model(**options) 49 | 50 | def applicability(**kwargs): 51 | loss = tf.keras.losses.categorical_crossentropy 52 | metrics = [ 53 | tf.keras.metrics.Recall(), 54 | tf.keras.metrics.Precision() 55 | ] 56 | options = { 57 | 'loss': loss, 58 | 'metrics': metrics, 59 | 'output_activation': 'sigmoid' 60 | } 61 | options.update(kwargs) 62 | return build_model(**options) 63 | -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/rdkit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from rdkit.Chem import AllChem 4 | from joblib import Parallel, delayed 5 | 6 | def unmap(smiles, canonicalize=True): 7 | mol = Chem.MolFromSmiles(smiles) 8 | [atom.SetAtomMapNum(0) for atom in mol.GetAtoms()] 9 | return Chem.MolToSmiles(mol, isomericSmiles=canonicalize) 10 | 11 | def smiles_to_fingerprint(smi, length=2048, radius=2, useFeatures=False, useChirality=True): 12 | mol = Chem.MolFromSmiles(smi) 13 | if not mol: 14 | raise ValueError('Cannot parse {}'.format(smi)) 15 | fp_bit = AllChem.GetMorganFingerprintAsBitVect( 16 | mol=mol, radius=radius, nBits = length, 17 | useFeatures=useFeatures, useChirality=useChirality 18 | ) 19 | return np.array(fp_bit) 20 | 21 | def smiles_list_to_fingerprints(smiles_list, fp_length=2048, fp_radius=2, nproc=1): 22 | if nproc > 1: 23 | fps = Parallel(n_jobs=nproc, verbose=1)( 24 | delayed(smiles_to_fingerprint)(smi, length=fp_length, radius=fp_radius) for smi in smiles_list 25 | ) 26 | fps = np.array(fps) 27 | else: 28 | fps = np.array([smiles_to_fingerprint(smi, length=fp_length, radius=fp_radius) for smi in smiles_list]) 29 | return fps 30 | 31 | def remove_spectating_reactants(reaction_smiles): 32 | reactants, spectators, products = reaction_smiles.split('>') 33 | rmols = [Chem.MolFromSmiles(r) for r in reactants.split('.')] 34 | pmol = Chem.MolFromSmiles(products) 35 | product_map_numbers = set([atom.GetAtomMapNum() for atom in pmol.GetAtoms()]) 36 | react_strings = [] 37 | if spectators: 38 | spectators = spectators.split('.') 39 | else: 40 | spectators = [] 41 | for rmol in rmols: 42 | map_numbers = set([atom.GetAtomMapNum() for atom in rmol.GetAtoms()]) 43 | intersection = map_numbers.intersection(product_map_numbers) 44 | if intersection: 45 | react_strings.append(Chem.MolToSmiles(rmol)) 46 | else: 47 | spectators.append(Chem.MolToSmiles(rmol)) 48 | return '.'.join(react_strings) + '>' + '.'.join(spectators) + '>' + products 49 | 50 | def precursors_from_template(mol, template): 51 | precursors = set() 52 | results = template.RunReactants([mol]) 53 | for res in results: 54 | res_smiles = '.'.join(sorted([Chem.MolToSmiles(m, isomericSmiles=True) for m in res])) 55 | if Chem.MolFromSmiles(res_smiles): 56 | precursors.add(res_smiles) 57 | return list(precursors) 58 | 59 | def precursors_from_templates(target_smiles, templates, nproc=1): 
60 | mol = Chem.MolFromSmiles(target_smiles) 61 | precursor_set_list = Parallel(n_jobs=nproc, verbose=1)( 62 | delayed(precursors_from_template)(mol, template) for template in templates 63 | ) 64 | return list(set().union(*precursor_set_list)) 65 | 66 | def templates_from_smarts_list(smarts_list, nproc=1): 67 | templates = Parallel(n_jobs=nproc, verbose=1)( 68 | delayed(AllChem.ReactionFromSmarts)(smarts) for smarts in smarts_list 69 | ) -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/templates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/data/temprel-fortunato/template-relevance-master/temprel/templates/__init__.py -------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/templates/extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import hashlib 5 | import numpy as np 6 | import pandas as pd 7 | from joblib import Parallel, delayed 8 | from .validate import validate_template 9 | from ..rdkit import smiles_to_fingerprint, remove_spectating_reactants 10 | from ..rdkit import unmap as rdkit_unmap 11 | from rdchiral.template_extractor import extract_from_reaction 12 | 13 | def create_hash(pd_row): 14 | return hashlib.md5(pd_row.to_json().encode()).hexdigest() 15 | 16 | def unmap(smarts): 17 | return re.sub(r':[0-9]+]', ']', smarts) 18 | 19 | def count_products(smiles): 20 | return 1 + smiles.count('.') 21 | 22 | def extract_template(reaction): 23 | try: 24 | return extract_from_reaction(reaction) 25 | except KeyboardInterrupt: 26 | print('Interrupted') 27 | raise KeyboardInterrupt 28 | except Exception as e: 29 | return { 30 | 'reaction_id': reaction['_id'], 31 | 'error': str(e) 32 | } 33 | 34 | def templates_from_reactions(df, nproc=8, out_json=None): 35 | if out_json is None: 36 | if not os.path.exists('data/processed'): 37 | os.makedirs('data/processed') 38 | out_json = 'data/processed/templates.df.json.gz' 39 | assert type(df) is pd.DataFrame 40 | assert 'reaction_smiles' in df.columns 41 | df['reaction_smiles'] = Parallel(n_jobs=nproc, verbose=1)( 42 | delayed(remove_spectating_reactants)(rsmi) for rsmi in df['reaction_smiles'] 43 | ) 44 | rxn_split = df['reaction_smiles'].str.split('>', expand=True) 45 | df[['reactants', 'spectators', 'products']] = rxn_split.rename( 46 | columns={ 47 | 0: 'reactants', 48 | 1: 'spectator', 49 | 2: 'products' 50 | } 51 | ) 52 | if '_id' not in df.columns: 53 | df['_id'] = df.apply(create_hash, axis=1) 54 | num_products = df['products'].apply(count_products) 55 | df = df[num_products==1] 56 | templates = Parallel(n_jobs=nproc, verbose=1)( 57 | delayed(extract_template)(reaction) for reaction in df.to_dict(orient='records') 58 | ) 59 | templates = pd.DataFrame(filter(lambda x: x, templates)) 60 | keep_cols = ['dimer_only', 'intra_only', 'necessary_reagent', 'reaction_id', 'reaction_smarts'] 61 | df = df.merge(templates[keep_cols], left_on='_id', right_on='reaction_id') 62 | df = df.dropna(subset=['reaction_smarts']) 63 | valid = Parallel(n_jobs=nproc, verbose=1)( 64 | delayed(validate_template)(template) for template in df.to_dict(orient='records') 65 | ) 66 | df = df[valid] 67 | df['unmapped_template'] = df['reaction_smarts'].apply(unmap) 68 | 
unique_templates = pd.DataFrame( 69 | df['unmapped_template'].unique(), columns=['unmapped_template'] 70 | ).reset_index() 71 | df = df.merge(unique_templates, on='unmapped_template') 72 | df = df.drop_duplicates(subset=['products', 'index']) 73 | df['count'] = df.groupby('index')['index'].transform('count') 74 | if out_json[-2] == 'gz': 75 | df.to_json(out_json, compression='gzip') 76 | else: 77 | df.to_json(out_json) 78 | return df 79 | 80 | def create_split(df, method, class_col='index'): 81 | df = df.sample(frac=1).reset_index(drop=True) 82 | if method == 'random': 83 | i80 = int(len(df)*0.8) 84 | i90 = int(len(df)*0.9) 85 | df.loc[:i80, 'split'] = 'train' 86 | df.loc[i80:i90, 'split'] = 'valid' 87 | df.loc[i90:, 'split'] = 'test' 88 | elif method == 'stratified': 89 | labels = df[class_col].values 90 | class_counts = np.bincount(df['index'].values) 91 | classes_with_1_example = np.argwhere(class_counts==1).reshape(-1) 92 | np.random.shuffle(classes_with_1_example) 93 | classes_with_2_example = np.argwhere(class_counts==2).reshape(-1) 94 | classes_with_few_example = np.argwhere((class_counts>2)&(class_counts<=10)).reshape(-1) 95 | classes_with_many_example = np.argwhere(class_counts>10).reshape(-1) 96 | train_ind = [] 97 | val_ind = [] 98 | test_ind = [] 99 | ind = np.argwhere(np.isin(labels, classes_with_1_example)).reshape(-1) 100 | i80 = int(len(ind)*0.8) 101 | i90 = int(len(ind)*0.9) 102 | train_ind.extend(ind[:i80]) 103 | val_ind.extend(ind[i80:i90]) 104 | test_ind.extend(ind[i90:]) 105 | for cls in classes_with_2_example: 106 | ind = np.argwhere(np.isin(labels, cls)).reshape(-1) 107 | train_ind.append(ind[0]) 108 | test_ind.append(ind[1]) 109 | for cls in classes_with_few_example: 110 | ind = np.argwhere(np.isin(labels, cls)).reshape(-1) 111 | np.random.shuffle(ind) 112 | test_ind.append(ind[0]) 113 | val_ind.append(ind[1]) 114 | train_ind.extend(ind[2:]) 115 | for cls in classes_with_many_example: 116 | ind = np.argwhere(np.isin(labels, cls)).reshape(-1) 117 | np.random.shuffle(ind) 118 | i80 = int(len(ind)*0.8) 119 | i90 = int(len(ind)*0.9) 120 | train_ind.extend(ind[:i80]) 121 | val_ind.extend(ind[i80:i90]) 122 | test_ind.extend(ind[i90:]) 123 | df.loc[train_ind, 'split'] = 'train' 124 | df.loc[val_ind, 'split'] = 'valid' 125 | df.loc[test_ind, 'split'] = 'test' 126 | df = df.sample(frac=1) 127 | return df 128 | 129 | def process_for_training(templates_df, output_prefix, split_col='split', calc_split=None, smiles_col='products', class_col='index'): 130 | if calc_split in ['random', 'stratified']: 131 | templates_df = create_split(templates_df, calc_split, class_col) 132 | 133 | if split_col not in templates_df.columns: 134 | raise ValueError( 135 | 'split column "{}" not in DataFrame.' 136 | ) 137 | 138 | if smiles_col not in templates_df.columns: 139 | ValueError( 140 | 'smiles column "{}" not in DataFrame.' 141 | ) 142 | 143 | if class_col not in templates_df.columns: 144 | ValueError( 145 | 'class column "{}" not in DataFrame.' 
146 | ) 147 | 148 | if output_prefix is None: 149 | output_prefix = '' 150 | 151 | train_df = templates_df[templates_df[split_col] == 'train'] 152 | valid_df = templates_df[templates_df[split_col] == 'valid'] 153 | test_df = templates_df[templates_df[split_col] == 'test'] 154 | 155 | train_smiles = train_df[smiles_col].values.astype(str) 156 | valid_smiles = valid_df[smiles_col].values.astype(str) 157 | test_smiles = test_df[smiles_col].values.astype(str) 158 | 159 | train_classes = train_df[class_col].values.astype(int) 160 | valid_classes = valid_df[class_col].values.astype(int) 161 | test_classes = test_df[class_col].values.astype(int) 162 | 163 | np.save(output_prefix+'train.input.smiles.npy', train_smiles) 164 | np.save(output_prefix+'valid.input.smiles.npy', valid_smiles) 165 | np.save(output_prefix+'test.input.smiles.npy', test_smiles) 166 | 167 | np.save(output_prefix+'train.labels.classes.npy', train_classes) 168 | np.save(output_prefix+'valid.labels.classes.npy', valid_classes) 169 | np.save(output_prefix+'test.labels.classes.npy', test_classes) 170 | 171 | def process_for_askcos(templates_df, template_set_name, output_prefix, nproc=8): 172 | templates_df['reaction_id'] = range(len(templates_df)) 173 | template_references = templates_df.groupby('index')['reaction_id'].apply(list) 174 | templates_df['references'] = templates_df['index'].map(template_references) 175 | templates_df['template_set'] = template_set_name 176 | template_columns_to_keep = [ 177 | 'index', 'reaction_smarts', 'necessary_reagent', 178 | 'intra_only', 'dimer_only', 'count', 'template_set', 179 | 'references' 180 | ] 181 | templates_to_save = templates_df.sort_values('index').drop_duplicates('index')[template_columns_to_keep] 182 | reactants_concat = templates_df['reactants'].values 183 | products = templates_df['products'].values 184 | reactants = [] 185 | for react in reactants_concat: 186 | reactants.extend(react.split('.')) 187 | reactants = Parallel(n_jobs=nproc, verbose=1)( 188 | delayed(rdkit_unmap)(smiles) for smiles in reactants 189 | ) 190 | products = Parallel(n_jobs=nproc, verbose=1)( 191 | delayed(rdkit_unmap)(smiles) for smiles in products 192 | ) 193 | react_df = pd.DataFrame(reactants, columns=['as_reactant']) 194 | prod_df = pd.DataFrame(products, columns=['as_product']) 195 | react_count = react_df.groupby('as_reactant')['as_reactant'].count() 196 | prod_count = prod_df.groupby('as_product')['as_product'].count() 197 | react_count = react_count.to_frame() 198 | react_count.index.name = 'smiles' 199 | prod_count = prod_count.to_frame() 200 | prod_count.index.name = 'smiles' 201 | historian_df = react_count.merge( 202 | prod_count, how='outer', left_index=True, right_index=True 203 | ).fillna(0).astype(int).reset_index() 204 | historian_df['template_set'] = template_set_name 205 | 206 | templates_to_save['_id'] = templates_to_save.apply(create_hash, axis=1) 207 | templates_to_save.to_json( 208 | output_prefix+'retro.templates.{}.json.gz'.format(template_set_name), 209 | orient='records', 210 | compression='gzip' 211 | ) 212 | 213 | reactions_df = templates_df.drop(columns=template_columns_to_keep+['unmapped_template']) 214 | reactions_df['template_set'] = template_set_name 215 | reactions_df.to_json( 216 | output_prefix+'reactions.{}.json.gz'.format(template_set_name), 217 | orient='records', 218 | compression='gzip' 219 | ) 220 | 221 | historian_df.to_json( 222 | output_prefix+'historian.{}.json.gz'.format(template_set_name), 223 | orient='records', 224 | compression='gzip' 225 | ) 
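Note: a minimal sketch of how the helpers in this module could be chained to produce the processed arrays used elsewhere in the repository (the uspto_sm_ prefix follows the data/processed listing; assumes the temprel package plus rdkit/rdchiral are installed and that the input frame holds atom-mapped reaction_smiles, e.g. as written by temprel.data.download.get_uspto_50k):

import pandas as pd
from temprel.templates.extract import templates_from_reactions, process_for_training

df = pd.read_json('data/raw/reactions.json.gz')          # must contain a 'reaction_smiles' column

templates_df = templates_from_reactions(df, nproc=8)     # extract, validate and index templates
process_for_training(
    templates_df,
    output_prefix='data/processed/uspto_sm_',            # writes {train,valid,test}.input.smiles.npy and .labels.classes.npy
    calc_split='stratified',
)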
-------------------------------------------------------------------------------- /data/temprel-fortunato/template-relevance-master/temprel/templates/validate.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem import AllChem 3 | 4 | def validate_template(template): 5 | rxn = AllChem.ReactionFromSmarts(template['reaction_smarts']) 6 | if not rxn: 7 | return False 8 | prod_mol = Chem.MolFromSmiles(template['products']) 9 | if not prod_mol: 10 | return False 11 | react_mol = Chem.MolFromSmiles(template['reactants']) 12 | if not react_mol: 13 | return False 14 | for atom in react_mol.GetAtoms(): 15 | atom.SetAtomMapNum(0) 16 | try: 17 | rxn_res = rxn.RunReactants([prod_mol]) 18 | except ValueError: 19 | return False 20 | if not rxn_res: 21 | return False 22 | for res in rxn_res: 23 | react_smi = '.'.join([Chem.MolToSmiles(react) for react in res]) 24 | res_mol = Chem.MolFromSmiles(react_smi) 25 | if not res_mol: 26 | continue 27 | for atom in res_mol.GetAtoms(): 28 | atom.SetAtomMapNum(0) 29 | react_smi = Chem.MolToSmiles(res_mol) 30 | if react_smi in Chem.MolToSmiles(react_mol): 31 | return True 32 | return False 33 | 34 | def mapping_validity(reaction_smiles): 35 | reactants, spectators, products = reaction_smiles.split('>') 36 | reactant_mol = Chem.MolFromSmiles(reactants) 37 | product_mol = Chem.MolFromSmiles(products) 38 | 39 | reactant_mapping = {} 40 | for atom in reactant_mol.GetAtoms(): 41 | map_number = atom.GetAtomMapNum() 42 | if not map_number: continue 43 | if map_number in reactant_mapping: 44 | return 'DUPLICATE_REACTANT_MAPPING' 45 | reactant_mapping[map_number] = atom.GetIdx() 46 | 47 | product_mapping = {} 48 | for atom in product_mol.GetAtoms(): 49 | map_number = atom.GetAtomMapNum() 50 | if not map_number: continue 51 | if map_number in product_mapping: 52 | return 'DUPLICATE_PRODUCT_MAPPING' 53 | product_mapping[map_number] = atom.GetIdx() 54 | 55 | if len(reactant_mapping) < len(product_mapping): 56 | return 'UNMAPPED_REACTANT_ATOM(S)' 57 | 58 | for map_number in product_mapping.keys(): 59 | if map_number not in reactant_mapping: 60 | return 'UNMAPPED_PRODUCT_ATOM' 61 | reactant_atom = reactant_mol.GetAtomWithIdx(reactant_mapping[map_number]) 62 | product_atom = product_mol.GetAtomWithIdx(product_mapping[map_number]) 63 | 64 | if reactant_atom.GetSymbol() != product_atom.GetSymbol(): 65 | return 'ALCHEMY' 66 | 67 | return 'VALID' -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: mhnreact_env 2 | channels: 3 | - bioconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | # change to your cuda version 8 | - cudatoolkit=10.2 9 | - torch==1.6 10 | - torchvision==0.7 11 | - pandas=1.0.5 12 | - pip=20.1.1=py_1 13 | - python=3.7 14 | - rdkit=2021.03.1 #2020.03.4 15 | # optionally 16 | - ipython 17 | - jupyterlab 18 | - pip: 19 | - numpy==1.19 20 | - scikit-learn==0.23.1 21 | - scipy==1.4 22 | - hydra-core 23 | - tqdm 24 | - rdchiral==1.1.0 25 | -------------------------------------------------------------------------------- /mhnreact/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/mhnreact/.gitkeep -------------------------------------------------------------------------------- /mhnreact/__init__.py: 
-------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /mhnreact/data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | File contains functions that help prepare and download USPTO-related datasets 9 | """ 10 | 11 | import os 12 | import gzip 13 | import pickle 14 | import requests 15 | import subprocess 16 | import pandas as pd 17 | import numpy as np 18 | from scipy import sparse 19 | import json 20 | 21 | def download_temprel_repo(save_path='data/temprel-fortunato', chunk_size=128): 22 | "downloads the template-relevance master branch" 23 | url = "https://gitlab.com/mefortunato/template-relevance/-/archive/master/template-relevance-master.zip" 24 | r = requests.get(url, stream=True) 25 | with open(save_path, 'wb') as fd: 26 | for chunk in r.iter_content(chunk_size=chunk_size): 27 | fd.write(chunk) 28 | 29 | def unzip(path): 30 | "unzips a file given a path" 31 | import zipfile 32 | with zipfile.ZipFile(path, 'r') as zip_ref: 33 | zip_ref.extractall(path.replace('.zip','')) 34 | 35 | 36 | def download_file(url, output_path=None): 37 | """ 38 | # code from fortunato 39 | # could also import from temprel.data.download import get_uspto_50k but slightly altered ;) 40 | 41 | """ 42 | if not output_path: 43 | output_path = url.split('/')[-1] 44 | with requests.get(url, stream=True) as r: 45 | r.raise_for_status() 46 | with open(output_path, 'wb') as f: 47 | for chunk in r.iter_content(chunk_size=8192): 48 | if chunk: 49 | f.write(chunk) 50 | 51 | def get_uspto_480k(): 52 | if not os.path.exists('data'): 53 | os.mkdir('data') 54 | if not os.path.exists('data/raw'): 55 | os.mkdir('data/raw') 56 | os.chdir('data/raw') 57 | download_file( 58 | 'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/train.txt.tar.gz', 59 | 'train.txt.tar.gz' 60 | ) 61 | subprocess.run(['tar', 'zxf', 'train.txt.tar.gz']) 62 | download_file( 63 | 'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/valid.txt.tar.gz', 64 | 'valid.txt.tar.gz' 65 | ) 66 | subprocess.run(['tar', 'zxf', 'valid.txt.tar.gz']) 67 | download_file( 68 | 'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/test.txt.tar.gz', 69 | 'test.txt.tar.gz' 70 | ) 71 | subprocess.run(['tar', 'zxf', 'test.txt.tar.gz']) 72 | 73 | with open('train.txt') as f: 74 | train = [ 75 | { 76 | 'reaction_smiles': line.strip(), 77 | 'split': 'train' 78 | } 79 | for line in f.readlines() 80 | ] 81 | with open('valid.txt') as f: 82 | valid = [ 83 | { 84 | 'reaction_smiles': line.strip(), 85 | 'split': 'valid' 86 | } 87 | for line in f.readlines() 88 | ] 89 | with open('test.txt') as f: 90 | test = [ 91 | { 92 | 'reaction_smiles': line.strip(), 93 | 'split': 'test' 94 | } 95 | for line in f.readlines() 96 | ] 97 | 98 | df = pd.concat([ 99 | pd.DataFrame(train), 100 | pd.DataFrame(valid), 101 | pd.DataFrame(test) 102 | ]).reset_index() 103 | df.to_json('uspto_lg_reactions.json.gz', compression='gzip') 104 | os.chdir('..') 105 | os.chdir('..') 106 | return df 107 | 108 | def get_uspto_50k(): 109 | ''' 110 | get SI from: 111 | Nadine Schneider; Daniel M. Lowe; Roger A. Sayle; Gregory A. Landrum. J. Chem. Inf. 
Model.201555139-53 112 | ''' 113 | if not os.path.exists('data'): 114 | os.mkdir('data') 115 | if not os.path.exists('data/raw'): 116 | os.mkdir('data/raw') 117 | os.chdir('data/raw') 118 | subprocess.run(['wget', 'https://pubs.acs.org/doi/suppl/10.1021/ci5006614/suppl_file/ci5006614_si_002.zip']) 119 | subprocess.run(['unzip', '-o', 'ci5006614_si_002.zip']) 120 | data = [] 121 | with gzip.open('ChemReactionClassification/data/training_test_set_patent_data.pkl.gz') as f: 122 | while True: 123 | try: 124 | data.append(pickle.load(f)) 125 | except EOFError: 126 | break 127 | reaction_smiles = [d[0] for d in data] 128 | reaction_reference = [d[1] for d in data] 129 | reaction_class = [d[2] for d in data] 130 | df = pd.DataFrame() 131 | df['reaction_smiles'] = reaction_smiles 132 | df['reaction_reference'] = reaction_reference 133 | df['reaction_class'] = reaction_class 134 | df.to_json('uspto_sm_reactions.json.gz', compression='gzip') 135 | os.chdir('..') 136 | os.chdir('..') 137 | return df 138 | 139 | def get_uspto_golden(): 140 | """ get uspto golden and convert it to smiles dataframe from 141 | Lin, Arkadii; Dyubankova, Natalia; Madzhidov, Timur; Nugmanov, Ramil; 142 | Rakhimbekova, Assima; Ibragimova, Zarina; Akhmetshin, Tagir; Gimadiev, 143 | Timur; Suleymanov, Rail; Verhoeven, Jonas; Wegner, Jörg Kurt; 144 | Ceulemans, Hugo; Varnek, Alexandre (2020): 145 | Atom-to-Atom Mapping: A Benchmarking Study of Popular Mapping Algorithms and Consensus Strategies. 146 | ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.13012679.v1 147 | """ 148 | if os.path.exists('data/raw/uspto_golden.json.gz'): 149 | print('loading precomputed') 150 | return pd.read_json('data/raw/uspto_golden.json.gz', compression='gzip') 151 | if not os.path.exists('data'): 152 | os.mkdir('data') 153 | if not os.path.exists('data/raw'): 154 | os.mkdir('data/raw') 155 | os.chdir('data/raw') 156 | subprocess.run(['wget', 'https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning/raw/master/data/golden_dataset.zip']) 157 | subprocess.run(['unzip', '-o', 'golden_dataset.zip']) #return golden_dataset.rdf 158 | 159 | from CGRtools.files import RDFRead 160 | import CGRtools 161 | from rdkit.Chem import AllChem 162 | def cgr2rxnsmiles(cgr_rx): 163 | smiles_rx = '.'.join([AllChem.MolToSmiles(CGRtools.to_rdkit_molecule(m)) for m in cgr_rx.reactants]) 164 | smiles_rx += '>>'+'.'.join([AllChem.MolToSmiles(CGRtools.to_rdkit_molecule(m)) for m in cgr_rx.products]) 165 | return smiles_rx 166 | 167 | data = {} 168 | input_file = 'golden_dataset.rdf' 169 | do_basic_standardization=True 170 | print('reading and converting the rdf-file') 171 | with RDFRead(input_file) as f: 172 | while True: 173 | try: 174 | r = next(f) 175 | key = r.meta['Reaction_ID'] 176 | if do_basic_standardization: 177 | r.thiele() 178 | r.standardize() 179 | data[key] = cgr2rxnsmiles(r) 180 | except StopIteration: 181 | break 182 | 183 | print('saving as a dataframe to data/uspto_golden.json.gz') 184 | df = pd.DataFrame([data],index=['reaction_smiles']).T 185 | df['reaction_reference'] = df.index 186 | df.index = range(len(df)) #reindex 187 | df.to_json('uspto_golden.json.gz', compression='gzip') 188 | 189 | os.chdir('..') 190 | os.chdir('..') 191 | return df 192 | 193 | def load_USPTO_fortu(path='data/processed', which='uspto_sm_', is_appl_matrix=False): 194 | """ 195 | loads the fortunato preprocessed data as 196 | dict X containing X['train'], X['valid'], and X['test'] 197 | as well as the labels containing the corresponding splits 198 | returns 
X, y 199 | """ 200 | 201 | X = {} 202 | y = {} 203 | 204 | for split in ['train','valid', 'test']: 205 | tmp = np.load(f'{path}/{which}{split}.input.smiles.npy', allow_pickle=True) 206 | X[split] = [] 207 | for ii in range(len(tmp)): 208 | X[split].append( tmp[ii].split('.')) 209 | 210 | if is_appl_matrix: 211 | y[split] = sparse.load_npz(f'{path}/{which}{split}.appl_matrix.npz') 212 | else: 213 | y[split] = np.load(f'{path}/{which}{split}.labels.classes.npy', allow_pickle=True) 214 | print(split, y[split].shape[0], 'samples (', y[split].max() if not is_appl_matrix else y[split].shape[1],'max label)') 215 | return X, y 216 | 217 | #TODO one should load in this file pd.read_json('uspto_R_retro.templates.uspto_R_.json.gz') 218 | # this only holds the templates.. the other holds everything 219 | def load_templates_sm(path = 'data/processed/uspto_sm_templates.df.json.gz', get_complete_df=False): 220 | "returns a dict mapping from class index to mapped reaction_smarts from the templates_df" 221 | df = pd.read_json(path) 222 | if get_complete_df: return df 223 | template_dict = {} 224 | for row in range(len(df)): 225 | template_dict[df.iloc[row]['index']] = df.iloc[row].reaction_smarts 226 | return template_dict 227 | 228 | def load_templates_lg(path = 'data/processed/uspto_lg_templates.df.json.gz', get_complete_df=False): 229 | return load_templates_sm(path=path, get_complete_df=get_complete_df) 230 | 231 | def load_USPTO_sm(): 232 | "loads the default dataset" 233 | return load_USPTO_fortu(which='uspto_sm_') 234 | 235 | def load_USPTO_lg(): 236 | "loads the default dataset" 237 | return load_USPTO_fortu(which='uspto_lg_') 238 | 239 | def load_USPTO_sm_pretraining(): 240 | "loads the default application matrix label and dataset" 241 | return load_USPTO_fortu(which='uspto_sm_', is_appl_matrix=True) 242 | def load_USPTO_lg_pretraining(): 243 | "loads the default application matrix label and dataset" 244 | return load_USPTO_fortu(which='uspto_lg_', is_appl_matrix=True) 245 | 246 | def load_USPTO_df_sm(): 247 | "loads the USPTO small Sm dataset dataframe" 248 | return pd.read_json('data/raw/uspto_sm_reactions.json.gz') 249 | 250 | def load_USPTO_df_lg(): 251 | "loads the USPTO large Lg dataset dataframe" 252 | return pd.read_json('data/raw/uspto_sm_reactions.json.gz') 253 | 254 | def load_USPTO_golden(): 255 | "loads the golden USPTO dataset" 256 | return load_USPTO_fortu(which=f'uspto_golden_', is_appl_matrix=False) 257 | 258 | def load_USPTO(which = 'sm', is_appl_matrix=False): 259 | return load_USPTO_fortu(which=f'uspto_{which}_', is_appl_matrix=is_appl_matrix) 260 | 261 | def load_templates(which = 'sm',fdir='data/processed', get_complete_df=False): 262 | return load_templates_sm(path=f'{fdir}/uspto_{which}_templates.df.json.gz', get_complete_df=get_complete_df) 263 | 264 | def load_data(dataset, path): 265 | splits = ['train', 'valid', 'test'] 266 | split2smiles = {} 267 | split2label = {} 268 | split2reactants = {} 269 | split2appl = {} 270 | split2prod_idx_reactants = {} 271 | 272 | for split in splits: 273 | label_fn = os.path.join(path, f'{dataset}_{split}.labels.classes.npy') 274 | split2label[split] = np.load(label_fn, allow_pickle=True) 275 | 276 | smiles_fn = os.path.join(path, f'{dataset}_{split}.input.smiles.npy') 277 | split2smiles[split] = np.load(smiles_fn, allow_pickle=True) 278 | 279 | reactants_fn = os.path.join(path, f'uspto_R_{split}.reactants.canonical.npy') 280 | split2reactants[split] = np.load(reactants_fn, allow_pickle=True) 281 | 282 | 283 | split2appl[split] = 
np.load(os.path.join(path, f'{dataset}_{split}.applicability.npy')) 284 | 285 | pir_fn = os.path.join(path, f'{dataset}_{split}.prod.idx.reactants.p') 286 | if os.path.isfile(pir_fn): 287 | with open(pir_fn, 'rb') as f: 288 | split2prod_idx_reactants[split] = pickle.load(f) 289 | 290 | 291 | if len(split2prod_idx_reactants) == 0: 292 | split2prod_idx_reactants = None 293 | 294 | with open(os.path.join(path, f'{dataset}_templates.json'), 'r') as f: 295 | label2template = json.load(f) 296 | label2template = {int(k): v for k,v in label2template.items()} 297 | 298 | return split2smiles, split2label, split2reactants, split2appl, split2prod_idx_reactants, label2template 299 | 300 | 301 | def load_dataset_from_csv(csv_path='', split_col='split', input_col='prod_smiles', ssretroeval=False, reactants_col='reactants_can', ret_df=False, **kwargs): 302 | """loads the dataset from a CSV file containing a split-column, and input-column which can be defined, 303 | as well as a 'reaction_smarts' column containing the extracted template, a 'label' column (the index of the template) 304 | :returns 305 | 306 | """ 307 | print('loading X, y from csv') 308 | df = pd.read_csv(csv_path) 309 | X = {} 310 | y = {} 311 | 312 | for spli in set(df[split_col]): 313 | #X[spli] = list(df[df[split_col]==spli]['prod_smiles'].apply(lambda k: [k])) 314 | X[spli] = list(df[df[split_col]==spli][input_col].apply(lambda k: [k])) 315 | y[spli] = (df[df[split_col]==spli]['label']).values 316 | print(spli, len(X[spli]), 'samples') 317 | 318 | # template to dict 319 | tmp = df[['reaction_smarts','label']].drop_duplicates(subset=['reaction_smarts','label']).sort_values('label') 320 | tmp.index= tmp.label 321 | template_list = tmp['reaction_smarts'].to_dict() 322 | print(len(template_list),'templates') 323 | 324 | if ssretroeval: 325 | # setup for ttest 326 | test_reactants_can = list(df[df[split_col]=='test'][reactants_col]) 327 | 328 | only_in_test = set(y['test']) - set(y['train']).union(set(y['valid'])) 329 | print('obfuscating', len(only_in_test), 'templates because they are only in test') 330 | for ii in only_in_test: 331 | template_list[ii] = 'CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCCC>>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCC' #obfuscate them 332 | if ret_df: 333 | return X, y, template_list, test_reactants_can, df 334 | return X, y, template_list, test_reactants_can 335 | 336 | if ret_df: 337 | return X, y, template_list, None, df 338 | return X, y, template_list, None -------------------------------------------------------------------------------- /mhnreact/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | File contains functions that help prepare and download USPTO-related datasets 9 | """ 10 | 11 | # Cell 12 | from .model import ModelConfig, MHN 13 | import torch -------------------------------------------------------------------------------- /mhnreact/inspect.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | File contains functions that 9 | """ 
10 | 11 | from . import model 12 | import torch 13 | import os 14 | 15 | MODEL_PATH = 'data/model/' 16 | 17 | def smarts2svg(smarts, useSmiles=True, highlightByReactant=True, save_to=''): 18 | """ 19 | draws smiles of smarts to an SVG and displays it in the Notebook, 20 | or optinally can be saved to a file `save_to` 21 | adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5 22 | """ 23 | # adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5 24 | from rdkit import RDConfig 25 | from rdkit import Chem 26 | from rdkit.Chem import Draw, AllChem 27 | from rdkit.Chem.Draw import rdMolDraw2D 28 | from rdkit import Geometry 29 | import matplotlib.pyplot as plt 30 | import matplotlib.cm as cm 31 | import matplotlib 32 | from IPython.display import SVG, display 33 | 34 | rxn = AllChem.ReactionFromSmarts(smarts,useSmiles=useSmiles) 35 | d = Draw.MolDraw2DSVG(900, 100) 36 | 37 | # rxn = AllChem.ReactionFromSmarts('[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][NH2:6]>CC(O)C.[Pt]>[CH3:1][C:2](=[O:3])[NH:6][CH3:5].[OH2:4]',useSmiles=True) 38 | colors=[(0.3,0.7,0.9),(0.9,0.7,0.9),(0.6,0.9,0.3),(0.9,0.9,0.1)] 39 | try: 40 | d.DrawReaction(rxn,highlightByReactant=highlightByReactant) 41 | d.FinishDrawing() 42 | 43 | txt = d.GetDrawingText() 44 | # self.assertTrue(txt.find("") != -1) 46 | 47 | svg = d.GetDrawingText() 48 | svg2 = svg.replace('svg:','') 49 | svg3 = SVG(svg2) 50 | display(svg3) 51 | 52 | if save_to!='': 53 | with open(save_to, 'w') as f_handle: 54 | f_handle.write(svg3.data) 55 | except: 56 | print('Error drawing') 57 | 58 | return svg2 59 | 60 | def list_models(model_path=MODEL_PATH): 61 | """returns a list of loadable models""" 62 | return dict(enumerate(list(filter(lambda k: str(k)[-3:]=='.pt', os.listdir(model_path))))) 63 | 64 | def load_clf(model_fn='', model_path=MODEL_PATH, device='cpu', model_type='mhn'): 65 | """ returns the model with loaded weights given a filename""" 66 | import json 67 | config_fn = '_'.join(model_fn.split('_')[-2:]).split('.pt')[0] 68 | conf_dict = json.load( open( f"{model_path}{config_fn}_config.json" ) ) 69 | train_conf_dict = json.load( open( f"{model_path}{config_fn}_config.json" ) ) 70 | 71 | # specify the config the saved model had 72 | conf = model.ModelConfig(**conf_dict) 73 | conf.device = device 74 | print(conf.__dict__) 75 | 76 | if model_type == 'staticQK': 77 | clf = model.StaticQK(conf) 78 | elif model_type == 'mhn': 79 | clf = model.MHN(conf) 80 | elif model_type == 'segler': 81 | clf = model.SeglerBaseline(conf) 82 | elif model_type == 'fortunato': 83 | clf = model.SeglerBaseline(conf) 84 | else: 85 | raise NotImplementedError('model_type',model_type,'not found') 86 | 87 | # load the model 88 | PATH = model_path+model_fn 89 | params = torch.load(PATH) 90 | clf.load_state_dict(params, strict=False) 91 | if 'templates+noise' in params.keys(): 92 | print('loading templates+noise') 93 | clf.templates = params['templates+noise'] 94 | #clf.templates.to(clf.config.device) 95 | return clf -------------------------------------------------------------------------------- /mhnreact/plotutils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | Plot utils 9 | """ 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import matplotlib.pyplot as plt 14 | from matplotlib import pyplot as plt 15 | 16 | 
plt.style.use('default') 17 | 18 | 19 | def normal_approx_interval(p_hat, n, z=1.96): 20 | """ approximating the distribution of error about a binomially-distributed observation, {\hat {p)), with a normal distribution 21 | z = 1.96 --> alpha =0.05 22 | z = 1 --> std 23 | https://www.wikiwand.com/en/Binomial_proportion_confidence_interval""" 24 | return z*((p_hat*(1-p_hat))/n)**(1/2) 25 | 26 | 27 | our_colors = { 28 | "lightblue": ( 0/255, 132/255, 187/255), 29 | "red": (217/255, 92/255, 76/255), 30 | "blue": ( 0/255, 132/255, 187/255), 31 | "green": ( 91/255, 167/255, 85/255), 32 | "yellow": (241/255, 188/255, 63/255), 33 | "cyan": ( 79/255, 176/255, 191/255), 34 | "grey": (125/255, 130/255, 140/255), 35 | "lightgreen":(191/255, 206/255, 82/255), 36 | "violett": (174/255, 97/255, 157/255), 37 | } 38 | 39 | 40 | def plot_std(p_hats, n_samples,z=1.96, color=our_colors['red'], alpha=0.2, xs=None): 41 | p_hats = np.array(p_hats) 42 | stds = np.array([normal_approx_interval(p_hats[ii], n_samples[ii], z=z) for ii in range(len(p_hats))]) 43 | xs = range(len(p_hats)) if xs is None else xs 44 | plt.fill_between(xs, p_hats-(stds), p_hats+stds, color=color, alpha=alpha) 45 | #plt.errorbar(range(13), asdf, [normal_approx_interval(asdf[ii], n_samples[ii], z=z) for ii in range(len(asdf))], 46 | # c=our_colors['red'], linestyle='None', marker='.', ecolor=our_colors['red']) 47 | 48 | 49 | def plot_loss(hist): 50 | plt.plot(hist['step'], hist['loss'] ) 51 | plt.plot(hist['steps_valid'], np.array(hist['loss_valid'])) 52 | plt.legend(['train','validation']) 53 | plt.xlabel('update-step') 54 | plt.ylabel('loss (categorical-crossentropy-loss)') 55 | 56 | 57 | def plot_topk(hist, sets=['train', 'valid', 'test'], with_last = 2): 58 | ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100] 59 | baseline_val_res = {1:0.4061, 10:0.6827, 50: 0.7883, 100:0.8400} 60 | plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--') 61 | for i in range(1,with_last): 62 | for s in sets: 63 | plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks],'.--', alpha=1/i) 64 | plt.xlabel('top-k') 65 | plt.ylabel('Accuracy') 66 | plt.legend(sets) 67 | plt.title('Hopfield-NN') 68 | plt.ylim([-0.02,1]) 69 | 70 | 71 | def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)', 72 | draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False, schwaller_hist=None, fortunato_hist=None): #1.96 for 95%CI 73 | markers = ['.']*4#['1','2','3','4']#['8','P','p','*'] 74 | lw = 2 75 | ms = 8 76 | k = 100 77 | ntes = range(13) 78 | if dataset=='Sm': 79 | basel_values = [0. 
, 0.38424785, 0.66807858, 0.7916149 , 0.9051132 , 80 | 0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717, 81 | 0.97215858, 0.9896713 , 0.99917817] #old basel_values = [0.0, 0.3882, 0.674, 0.7925, 0.9023, 0.9272, 0.874, 0.947, 0.9185, 0.959, 0.9717, 0.9927, 1.0] 82 | pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376, 83 | 0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724, 84 | 0.95907246, 0.9839079 , 0.98683919]# old [0.094, 0.711, 0.8584, 0.952, 0.9683, 0.9717, 0.988, 1.0, 1.0, 0.984, 0.9717, 1.0, 1.0] 85 | staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142, 0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636] 86 | if group_by_template_fp: 87 | staticQK = [0.2651, 0.2617, 0.261 , 0.2181, 0.2622, 0.2393, 0.2157, 0.2184, 0.2 , 0.225 , 0.2039, 0.4568, 0.5293] 88 | if dataset=='Lg': 89 | pretr_values = [0.03410448, 0.65397054, 0.7254572 , 0.78969294, 0.81329924, 90 | 0.8651173 , 0.86775655, 0.8593128 , 0.88184124, 0.87764794, 91 | 0.89734215, 0.93328846, 0.99531597] 92 | basel_values = [0. , 0.62478044, 0.68784314, 0.75089511, 0.77044644, 93 | 0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176, 94 | 0.8662873 , 0.92308414, 1.00042408] 95 | #staticQK = [0.03638, 0.0339 , 0.03732, 0.03506, 0.03717, 0.0331 , 0.03003, 0.03613, 0.0304 , 0.02109, 0.0297 , 0.02632, 0.02217] # on 90k templates 96 | staticQK = [0.006416,0.00686, 0.00616, 0.00825, 0.005085,0.006718,0.01041, 0.0015335,0.006668,0.004673,0.001706,0.02551,0.04074] 97 | if dataset=='Golden': 98 | staticQK = [0]*13 99 | pretr_values = [0]*13 100 | basel_values = [0]*13 101 | 102 | if schwaller_hist: 103 | midx = np.argmin(schwaller_hist['loss_valid']) 104 | basel_values = ([schwaller_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']]) 105 | if fortunato_hist: 106 | midx = np.argmin(fortunato_hist['loss_valid']) 107 | pretr_values = ([fortunato_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']]) 108 | 109 | #hand_val = [0.0 , 0.4, 0.68, 0.79, 0.89, 0.91, 0.86, 0.9,0.88, 0.9, 0.93] 110 | 111 | 112 | if include_bar: 113 | if dataset=='Sm': 114 | if n_samples is None: 115 | n_samples = [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648] 116 | if group_by_template_fp: 117 | n_samples = [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203] 118 | if dataset=='Lg': 119 | if n_samples is None: 120 | n_samples = [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958] #new 121 | if group_by_template_fp: 122 | n_samples = [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559] 123 | #[5169, 15904, 2814, 1853, 1238, 966, 766, 609, 1316, 664, 640, 30699, 21471] 124 | #[13424,17246, 7681, 4332, 2844,2129,1698,1269, 1336,1067, 833, 22491, 11202] #grouped fp 125 | plt.bar(range(11+2), np.array(n_samples)/sum(n_samples[:-1]), alpha=0.4, color=our_colors['grey']) 126 | 127 | xti = [*[str(i) for i in range(11)], '>10', '>49'] 128 | asdf = [] 129 | for nte in xti: 130 | try: 131 | asdf.append( hist[f't{k}_acc_nte_{nte}'][-last_cpt]) 132 | except: 133 | asdf.append(None) 134 | 135 | plt.plot(range(13), asdf,f'{markers[3]}--', markersize=ms,c=our_colors['red'], linewidth=lw,alpha=1) 136 | plt.plot(ntes, pretr_values,f'{markers[1]}--', c=our_colors['green'], 137 | linewidth=lw, alpha=1,markersize=ms) #old [0.08, 0.7, 0.85, 0.9, 0.91, 0.95, 0.98, 0.97,0.98, 1, 1] 138 | plt.plot(ntes, basel_values,f'{markers[0]}--',linewidth=lw, 139 | c=our_colors['blue'], 
markersize=ms,alpha=1) 140 | plt.plot(range(len(staticQK)), staticQK, f'{markers[2]}--',markersize=ms,c=our_colors['yellow'],linewidth=lw, alpha=1) 141 | 142 | plt.title(f'USPTO-{dataset}') 143 | plt.xlabel('number of training examples') 144 | plt.ylabel('top-100 test-accuracy') 145 | plt.legend([model_legend, 'Fortunato et al.','FNN baseline',"FPM baseline", #static${\\xi X}: \\dfrac{|{\\xi} \\cap {X}|}{|{X}|}$ 146 | 'test sample proportion']) 147 | 148 | if draw_std: 149 | alpha=0.2 150 | plot_std(asdf, n_samples, z=z, color=our_colors['red'], alpha=alpha) 151 | plot_std(pretr_values, n_samples, z=z, color=our_colors['green'], alpha=alpha) 152 | plot_std(basel_values, n_samples, z=z, color=our_colors['blue'], alpha=alpha) 153 | plot_std(staticQK, n_samples, z=z, color=our_colors['yellow'], alpha=alpha) 154 | 155 | 156 | plt.xticks(range(13),xti); 157 | plt.yticks(np.arange(0,1.05,0.1)) 158 | plt.grid('on', alpha=0.3) -------------------------------------------------------------------------------- /mhnreact/retroeval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl, Philipp Renz 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | Evaluation functions for single-step-retrosynthesis 9 | """ 10 | import sys 11 | 12 | import rdchiral 13 | from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants 14 | import hashlib 15 | from rdkit import Chem 16 | 17 | import torch 18 | import numpy as np 19 | import pandas as pd 20 | from collections import defaultdict 21 | from copy import deepcopy 22 | from glob import glob 23 | import os 24 | import pickle 25 | 26 | from multiprocessing import Pool 27 | import hashlib 28 | import pickle 29 | import logging 30 | 31 | #import timeout_decorator 32 | 33 | 34 | def _cont_hash(fn): 35 | with open(fn, 'rb') as f: 36 | return hashlib.md5(f.read()).hexdigest() 37 | 38 | def load_templates_only(path, cache_dir='/tmp'): 39 | arg_hash_base = 'load_templates_only' + path 40 | arg_hash = hashlib.md5(arg_hash_base.encode()).hexdigest() 41 | matches = glob(os.path.join(cache_dir, arg_hash+'*')) 42 | 43 | if len(matches) > 1: 44 | raise RuntimeError('Too many matches') 45 | elif len(matches) == 1: 46 | fn = matches[0] 47 | content_hash = _cont_hash(path) 48 | content_hash_file = os.path.basename(fn).split('_')[1].split('.')[0] 49 | if content_hash_file == content_hash: 50 | with open(fn, 'rb') as f: 51 | return pickle.load(f) 52 | 53 | df = pd.read_json(path) 54 | template_dict = {} 55 | for row in range(len(df)): 56 | template_dict[df.iloc[row]['index']] = df.iloc[row].reaction_smarts 57 | 58 | # cache the file 59 | content_hash = _cont_hash(path) 60 | fn = os.path.join(cache_dir, f"{arg_hash}_{content_hash}.p") 61 | with open(fn, 'wb') as f: 62 | pickle.dump(template_dict, f) 63 | 64 | def load_templates_v2(path, get_complete_df=False): 65 | if get_complete_df: 66 | df = pd.read_json(path) 67 | return df 68 | 69 | return load_templates_only(path) 70 | 71 | def canonicalize_reactants(smiles, can_steps=2): 72 | if can_steps==0: 73 | return smiles 74 | 75 | mol = Chem.MolFromSmiles(smiles) 76 | for a in mol.GetAtoms(): 77 | a.ClearProp('molAtomMapNumber') 78 | 79 | smiles = Chem.MolToSmiles(mol, True) 80 | if can_steps==1: 81 | return smiles 82 | 83 | smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), True) 84 | if can_steps==2: 85 | return smiles 86 | 87 | raise 
ValueError("Invalid can_steps") 88 | 89 | 90 | 91 | def load_test_set(fn): 92 | df = pd.read_csv(fn, index_col=0) 93 | test = df[df.dataset=='test'] 94 | 95 | test_product_smarts = list(test.prod_smiles) # we make predictions for these 96 | for s in test_product_smarts: 97 | assert len(s.split('.')) == 1 98 | assert '>' not in s 99 | 100 | test_reactants = [] # we want to predict these 101 | for rs in list(test.rxn_smiles): 102 | rs = rs.split('>>') 103 | assert len(rs) == 2 104 | reactants_ori, products = rs 105 | reactants = reactants_ori.split('.') 106 | products = products.split('.') 107 | assert len(reactants) >= 1 108 | assert len(products) == 1 109 | 110 | test_reactants.append(reactants_ori) 111 | 112 | return test_product_smarts, test_reactants 113 | 114 | 115 | #@timeout_decorator.timeout(1, use_signals=False) 116 | def time_out_rdchiralRun(temp, prod_rct, combine_enantiomers=False): 117 | rxn = rdchiralReaction(temp) 118 | return rdchiralRun(rxn, prod_rct, combine_enantiomers=combine_enantiomers) 119 | 120 | def _run_templates_rdchiral(prod_appl): 121 | prod, applicable_templates = prod_appl 122 | prod_rct = rdchiralReactants(prod) # preprocess reactants with rdchiral 123 | 124 | results = {} 125 | for idx, temp in applicable_templates: 126 | temp = str(temp) 127 | try: 128 | results[(idx, temp)] = time_out_rdchiralRun(temp, prod_rct, combine_enantiomers=False) 129 | except: 130 | pass 131 | 132 | return results 133 | 134 | def _run_templates_rdchiral_original(prod_appl): 135 | prod, applicable_templates = prod_appl 136 | prod_rct = rdchiralReactants(prod) # preprocess reactants with rdchiral 137 | 138 | results = {} 139 | rxn_cache = {} 140 | for idx, temp in applicable_templates: 141 | temp = str(temp) 142 | if temp in rxn_cache: 143 | rxn = rxn_cache[(temp)] 144 | else: 145 | try: 146 | rxn = rdchiralReaction(temp) 147 | rxn_cache[temp] = rxn 148 | except: 149 | rxn_cache[temp] = None 150 | msg = temp+' error converting to rdchiralReaction' 151 | logging.debug(msg) 152 | try: 153 | res = rdchiralRun(rxn, prod_rct, combine_enantiomers=False) 154 | results[(idx, temp)] = res 155 | except: 156 | pass 157 | 158 | return results 159 | 160 | def run_templates(test_product_smarts, templates, appl, njobs=32, cache_dir='/tmp'): 161 | appl_dict = defaultdict(list) 162 | for i,j in zip(*appl): 163 | appl_dict[i].append(j) 164 | 165 | prod_appl_list = [] 166 | for prod_idx, prod in enumerate(test_product_smarts): 167 | applicable_templates = [(idx, templates[idx]) for idx in appl_dict[prod_idx]] 168 | prod_appl_list.append((prod, applicable_templates)) 169 | 170 | arg_hash = hashlib.md5(pickle.dumps(prod_appl_list)).hexdigest() 171 | cache_file = os.path.join(cache_dir, arg_hash+'.p') 172 | 173 | if os.path.isfile(cache_file): 174 | with open(cache_file, 'rb') as f: 175 | print('loading results from file',f) 176 | all_results = pickle.load(f) 177 | 178 | #find /tmp -type f \( ! 
-user root \) -atime +3 -delete 179 | # to delete the tmp files that havent been accessed 3 days 180 | 181 | else: 182 | #with Pool(njobs) as pool: 183 | # all_results = pool.map(_run_templates_rdchiral, prod_appl_list) 184 | 185 | from tqdm.contrib.concurrent import process_map 186 | all_results = process_map(_run_templates_rdchiral, prod_appl_list, max_workers=njobs, chunksize=1, mininterval=2) 187 | 188 | #with open(cache_file, 'wb') as f: 189 | # print('saving applicable_templates to cache', cache_file) 190 | # pickle.dump(all_results, f) 191 | 192 | 193 | 194 | prod_idx_reactants = [] 195 | prod_temp_reactants = [] 196 | 197 | for prod, idx_temp_reactants in zip(test_product_smarts, all_results): 198 | prod_idx_reactants.append({idx_temp[0]: r for idx_temp, r in idx_temp_reactants.items()}) 199 | prod_temp_reactants.append({idx_temp[1]: r for idx_temp, r in idx_temp_reactants.items()}) 200 | 201 | return prod_idx_reactants, prod_temp_reactants 202 | 203 | def sort_by_template(template_scores, prod_idx_reactants): 204 | sorted_results = [] 205 | for i, predictions in enumerate(prod_idx_reactants): 206 | score_row = template_scores[i] 207 | appl_idxs = np.array(list(predictions.keys())) 208 | if len(appl_idxs) == 0: 209 | sorted_results.append([]) 210 | continue 211 | scores = score_row[appl_idxs] 212 | sorted_idxs = appl_idxs[np.argsort(scores)][::-1] 213 | sorted_reactants = [predictions[idx] for idx in sorted_idxs] 214 | sorted_results.append(sorted_reactants) 215 | return sorted_results 216 | 217 | def no_dup_same_order(l): 218 | return list({r: 0 for r in l}.keys()) 219 | 220 | def flatten_per_product(sorted_results, remove_duplicates=True): 221 | flat_results = [sum((r for r in row), []) for row in sorted_results] 222 | if remove_duplicates: 223 | flat_results = [no_dup_same_order(row) for row in flat_results] 224 | return flat_results 225 | 226 | 227 | def topkaccuracy(test_reactants, predicted_reactants, ks=[1], ret_ranks=False): 228 | ks = [k if k is not None else 1e10 for k in ks] 229 | ranks = [] 230 | for true, pred in zip(test_reactants, predicted_reactants): 231 | try: 232 | rank = pred.index(true) + 1 233 | except ValueError: 234 | rank = 1e15 235 | ranks.append(rank) 236 | ranks = np.array(ranks) 237 | if ret_ranks: 238 | return ranks 239 | 240 | return [np.mean([ranks <= k]) for k in ks] -------------------------------------------------------------------------------- /mhnreact/retrosyn.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from mhnreact.molutils import canonicalize_template, canonicalize_smi, remove_attom_mapping 3 | from mhnreact.model import MHN 4 | 5 | import os 6 | import pickle 7 | from collections import Counter 8 | from urllib.request import urlretrieve 9 | from time import time 10 | 11 | from scipy.special import softmax 12 | from rdkit.Chem import rdChemReactions 13 | from rdkit import Chem 14 | from rdkit.Chem.Draw import MolToImage 15 | from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from rdkit import Chem 20 | from rdkit.Chem import AllChem 21 | 22 | import logging 23 | import warnings 24 | from joblib import Memory 25 | 26 | from treelib import Node, Tree 27 | 28 | import torch 29 | 30 | cachedir = 'data/cache/' 31 | memory = Memory(cachedir, verbose=0, bytes_limit=80e9) 32 | 33 | from mhnreact.molutils import getFingerprint 34 | from mhnreact.molutils import convert_smiles_to_fp 35 
| 36 | reaction_superclass_names = { 37 | 1 : 'Heteroatom alkylation and arylation', 38 | 2 : 'Acylation and related processes', 39 | 3 : 'C-C bond formation', 40 | 4 : 'Heterocycle formation', #TODO check 41 | 5 : 'Protections', 42 | 6 : 'Deprotections', 43 | 7 : 'Reductions', 44 | 8 : 'Oxidations', 45 | 9 : 'Functional group interconversoin (FGI)', 46 | 10 :'Functional group addition (FGA)' 47 | } 48 | 49 | @memory.cache(ignore=[]) 50 | def getTemplateApplicabilityMatrix(t, fp_size=8096, fp_type='pattern'): 51 | only_left_side_of_templates = list(map(lambda k: k.split('>>')[0], t.values())) 52 | return convert_smiles_to_fp(only_left_side_of_templates,is_smarts=True, which=fp_type, fp_size=fp_size) 53 | 54 | def FPF(smi, templates, fp_size=8096, fp_type='pattern'): 55 | """Fingerprint-Filter for applicability""" 56 | tfp = getTemplateApplicabilityMatrix(templates, fp_size=fp_size, fp_type=fp_type) 57 | if not isinstance(smi, list): 58 | smi = [smi] 59 | mfp = convert_smiles_to_fp(smi,which=fp_type, fp_size=fp_size) 60 | applicable = ((tfp&mfp).sum(1)==(tfp.sum(1))) 61 | return applicable 62 | 63 | @memory.cache 64 | def load_emol_bbs(DATA_DIR='data/cache/'): 65 | """Downloads emolecules building blocks""" 66 | os.makedirs(DATA_DIR, exist_ok=True) 67 | fn_bb_txt = os.path.join(DATA_DIR, 'emol_building_blocks.txt') 68 | fn_emol_csv = os.path.join(DATA_DIR, 'version.smi.gz') 69 | fn_emol_pickle = os.path.join(DATA_DIR, 'emol.pkl.gz') 70 | 71 | url_emol_csv = 'http://downloads.emolecules.com/free/2021-10-01/version.smi.gz' 72 | 73 | if os.path.isfile(fn_bb_txt): 74 | with open(fn_bb_txt, 'r') as fh: 75 | return fh.read().split() 76 | 77 | if not os.path.isfile(fn_emol_csv): 78 | print('Downloading emolecule data') 79 | urlretrieve(url_emol_csv, fn_emol_csv) 80 | 81 | print('Loading eMolecules data frame (may take a minute ;)') 82 | emol_df = pd.read_csv(fn_emol_csv, sep=' ', compression='gzip') 83 | 84 | bb_list = set(list(emol_df.isosmiles.values)) 85 | 86 | return bb_list 87 | 88 | def in_emol_or_trivial(smi): 89 | smi = canonicalize_smi(smi, remove_atom_mapping=True) 90 | if len(smi)<5: 91 | return True 92 | return smi in emol 93 | 94 | def not_in_emol_nor_trivial(smi): 95 | return not in_emol_or_trivial(smi) 96 | 97 | def link_mols_to_html(list_of_mols): 98 | final_str = '[' 99 | for mol in list_of_mols: 100 | url = f'{mol}.html' 101 | final_str+=f'{mol}, ' 102 | return final_str+']' 103 | 104 | def svg_to_html(svg): 105 | return svg.replace('\n','').replace('\n','') 106 | 107 | def compute_html(res, target_smiles): 108 | res_df = pd.DataFrame(res)#.canon_reaction.unique() 109 | res_df['non_buyable_reactants'] = res_df.reaction_canonical.apply(lambda k: list(filter(not_in_emol_nor_trivial, k.split('>>')[0].split('.'))) ) 110 | res_df['SVG'] = res_df.reaction.apply(lambda k: svg_to_html(smarts2svg(k))) 111 | res_df['non_buyable_reactants'] = res_df.non_buyable_reactants.apply(link_mols_to_html) 112 | res_df.to_html(f'data/cache/retrohtml/{target_smiles}.html', escape=False) 113 | return res_df 114 | 115 | #@memory.cache(ignore=['clf','viz']) 116 | def ssretro(target_smiles:str, clf, num_paths=1, try_max_temp=10, viz=False, use_FPF=False): 117 | """single-step-retrosynthesis""" 118 | if hasattr(clf, 'templates'): 119 | if clf.X is None: 120 | clf.X = clf.template_encoder(clf.templates) 121 | preds = clf.forward_smiles([target_smiles]) 122 | if use_FPF: 123 | appl = FPF(target_smiles, t) 124 | preds = preds * torch.tensor(appl) 125 | preds = clf.softmax(preds) 126 | 127 | idxs = 
preds.argsort().detach().numpy().flatten()[::-1] 128 | preds = preds.detach().numpy().flatten() 129 | 130 | try: 131 | prod_rct = rdchiralReactants(target_smiles) 132 | except: 133 | print('target_smiles', target_smiles, 'not computable') 134 | return [] 135 | reactions = [] 136 | 137 | i=0 138 | while len(reactions)>'+k[0] for k in list(mapped_res.values())] 153 | for r in rs: 154 | di = { 155 | 'template_used':t[idxs[i]], 156 | 'template_idx': idxs[i], 157 | 'template_rank': i+1, #get the actual rank, not the one without non-executable 158 | 'reaction': r, 159 | 'reaction_canonical': canonicalize_template(r), 160 | 'prob': preds[idxs[i]]*100, 161 | 'template_class': reaction_superclass_names[df[df.reaction_smarts==t[idxs[i]]]["class"].unique()[0]] 162 | } 163 | di['template_num_train_samples'] = (y['train']==di['template_idx']).sum() 164 | reactions.append(di) 165 | if viz: 166 | for r in rs: 167 | print('with template #',idxs[i], t[idxs[i]]) 168 | smarts2svg(r, useSmiles=True, highlightByReactant=True); 169 | return reactions 170 | 171 | def main(target_smiles, clf_name, max_steps=3, model_type='mhn'): 172 | clf = load_clf(clf_name, model_type=model_type) 173 | clf.eval() 174 | res = ssretro(target_smiles, clf=clf, num_paths=10, viz=False, use_FPF=True) 175 | res_df = pd.DataFrame(res) 176 | emol = load_emol_bbs() 177 | print('check non-buyable') 178 | res_df['non_buyable_reactants'] = res_df.reaction_canonical.apply(lambda k: list(filter(not_in_emol_nor_trivial, k.split('>>')[0].split('.'))) ) 179 | print('compute html-files') 180 | res_df = compute_html(res_df, 'mhn_main_'+target_smiles) 181 | 182 | import functools 183 | import operator 184 | non_buyable_reactant_set = ['_dummy'] 185 | 186 | while (max_steps!=0) and (len(non_buyable_reactant_set)>0): 187 | max_steps-=1 188 | non_buyable_reactant_set = set(functools.reduce(operator.iconcat, res_df['non_buyable_reactants'].values.tolist(), [])) 189 | for r in non_buyable_reactant_set: 190 | res_2 = ssretro(r, clf=clf) 191 | if len(res_2) > 0: 192 | compute_html(res_2, r) 193 | 194 | res_df = pd.DataFrame(res_2) 195 | res_df['non_buyable_reactants'] = res_df.reaction_canonical.apply(lambda k: list(filter(not_in_emol_nor_trivial, k.split('>>')[0].split('.'))) ) 196 | 197 | 198 | def recu(target_smi, clf, tree, n_iter=1, use_FPF=False, check_availability=False): 199 | """recurrent part of multi-step-retrosynthesis""" 200 | if n_iter: 201 | res = ssretro(target_smi, clf, num_paths=5, try_max_temp=100, use_FPF=use_FPF) 202 | if len(res)==0: 203 | return tree 204 | reactants = res[0]['reaction_canonical'].split('>>')[1].split('.') 205 | if len(reactants)>=2: 206 | pass 207 | 208 | for r in reactants: 209 | try: 210 | if check_availability: 211 | avail = not_in_emol_nor_trivial(r) 212 | tree.create_node(r+'_available' if avail else r, r, parent=target_smi) 213 | else: 214 | tree.create_node(r, r, parent=target_smi) 215 | except: 216 | reactants = res[1]['reaction_canonical'].split('>>')[1].split('.') 217 | for r in reactants: 218 | try: 219 | tree.create_node(r+'_available' if (check_availability and not_in_emol_nor_trivial(r)) else r, r, parent=target_smi) 220 | except: 221 | tree.create_node(r+'_failed',r+'_failed', parent=target_smi) 222 | tree = recu(r, clf, tree, n_iter=(n_iter-1), use_FPF=use_FPF, check_availability=check_availability) 223 | return tree 224 | else: 225 | return tree 226 | 227 | def msretro(target_smi, clf, n_iter=3, use_FPF=False, check_availability=False): 228 | """Multi-step-retro-synthesis function""" 229 | tree = Tree() 230 | tree.create_node(target_smi, target_smi) 231 | 232 | tree
= recu(target_smi, clf, tree, n_iter=n_iter, use_FPF=False) 233 | return tree 234 | 235 | if __name__ == '__main__': 236 | msretro() -------------------------------------------------------------------------------- /mhnreact/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | General utility functions 9 | """ 10 | 11 | import argparse 12 | from collections import defaultdict 13 | import numpy as np 14 | import pandas as pd 15 | import math 16 | import torch 17 | 18 | # used and fastest version 19 | def top_k_accuracy(y_true, y_pred, k=5, ret_arocc=False, ret_mrocc=False, verbose=False, count_equal_as_correct=False, eps_noise=0): 20 | """ partly from http://stephantul.github.io/python/pytorch/2020/09/18/fast_topk/ 21 | count_equal counts equal values as beein a correct choice e.g. all preds = 0 --> T1acc = 1 22 | ret_mrocc ... also return median rank of correct choice 23 | eps_noise ... if >0 ads noise*eps to y_pred .. recommended e.g. 1e-10 24 | """ 25 | if eps_noise>0: 26 | if torch.is_tensor(y_pred): 27 | y_pred = y_pred + torch.rand(y_pred.shape)*eps_noise 28 | else: 29 | y_pred = y_pred + np.random.rand(*y_pred.shape)*eps_noise 30 | 31 | if count_equal_as_correct: 32 | greater = (y_pred > y_pred[range(len(y_pred)), y_true][:,None]).sum(1) # how many are bigger 33 | else: 34 | greater = (y_pred >= y_pred[range(len(y_pred)), y_true][:,None]).sum(1) # how many are bigger or equal 35 | if torch.is_tensor(y_pred): 36 | greater = greater.long() 37 | if isinstance(k, int): k = [k] # pack it into a list 38 | tkaccs = [] 39 | for ki in k: 40 | if count_equal_as_correct: 41 | tkacc = (greater<=(ki-1)) 42 | else: 43 | tkacc = (greater<=(ki)) 44 | if torch.is_tensor(y_pred): 45 | tkacc = tkacc.float().mean().detach().cpu().numpy() 46 | else: 47 | tkacc = tkacc.mean() 48 | tkaccs.append(tkacc) 49 | if verbose: print('Top', ki, 'acc:\t', str(tkacc)[:6]) 50 | 51 | if ret_arocc: 52 | arocc = greater.float().mean()+1 53 | if torch.is_tensor(arocc): 54 | arocc = arocc.detach().cpu().numpy() 55 | return (tkaccs[0], arocc) if len(tkaccs) == 1 else (tkaccs, arocc) 56 | if ret_mrocc: 57 | mrocc = greater.median()+1 58 | if torch.is_tensor(mrocc): 59 | mrocc = mrocc.float().detach().cpu().numpy() 60 | return (tkaccs[0], mrocc) if len(tkaccs) == 1 else (tkaccs, mrocc) 61 | 62 | 63 | return tkaccs[0] if len(tkaccs) == 1 else tkaccs 64 | 65 | 66 | def seed_everything(seed=70135): 67 | """ does what it says ;) - from https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335""" 68 | import numpy as np 69 | import random 70 | import os 71 | import torch 72 | 73 | random.seed(seed) 74 | os.environ['PYTHONHASHSEED'] = str(seed) 75 | np.random.seed(seed) 76 | torch.manual_seed(seed) 77 | torch.cuda.manual_seed(seed) 78 | torch.backends.cudnn.deterministic = True 79 | 80 | def get_best_gpu(): 81 | '''Get the gpu with most RAM on the machine. From P. 
Neves''' 82 | import torch 83 | if torch.cuda.is_available(): 84 | gpus_ram = [] 85 | for ind in range(torch.cuda.device_count()): 86 | gpus_ram.append(torch.cuda.get_device_properties(ind).total_memory/1e9) 87 | return f"cuda:{gpus_ram.index(max(gpus_ram))}" 88 | else: 89 | raise ValueError("No gpus were detected in this machine.") 90 | 91 | 92 | def sort_by_template_and_flatten(template_scores, prod_idx_reactants, agglo_fun=sum): 93 | flat_results = [] 94 | for ii in range(len(template_scores)): 95 | idx_prod_reactants = defaultdict(list) 96 | for k,v in prod_idx_reactants[ii].items(): 97 | for iv in v: 98 | idx_prod_reactants[iv].append(template_scores[ii,k]) 99 | d2 = {k: agglo_fun(v) for k, v in idx_prod_reactants.items()} 100 | if len(d2)==0: 101 | flat_results.append([]) 102 | else: 103 | flat_results.append(pd.DataFrame.from_dict(d2, orient='index').sort_values(0, ascending=False).index.values.tolist()) 104 | return flat_results 105 | 106 | 107 | def str2bool(v): 108 | """adapted from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse""" 109 | if isinstance(v, bool): 110 | return v 111 | if v.lower() in ('yes', 'true', 't', 'y', '1', '',' '): 112 | return True 113 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 114 | return False 115 | else: 116 | raise argparse.ArgumentTypeError('Boolean value expected.') 117 | 118 | 119 | @np.vectorize 120 | def lgamma(x): 121 | return math.lgamma(x) 122 | 123 | def multinom_gk(array, axis=0): 124 | """Multinomial lgamma pooling over a given axis""" 125 | res = lgamma(np.sum(array,axis=axis)+2) - np.sum(lgamma(array+1),axis=axis) 126 | return res -------------------------------------------------------------------------------- /mhnreact/view.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Author: Philipp Seidl 4 | ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning 5 | Johannes Kepler University Linz 6 | Contact: seidl@ml.jku.at 7 | 8 | Loading log-files from training 9 | """ 10 | 11 | from pathlib import Path 12 | import os 13 | import datetime 14 | import pandas as pd 15 | import numpy as np 16 | import pandas as pd 17 | import matplotlib.pyplot as plt 18 | 19 | def load_experiments(EXP_DIR = Path('data/experiments/')): 20 | dfs = [] 21 | for fn in os.listdir(EXP_DIR): 22 | print(fn, end='\r') 23 | if fn.split('.')[-1]=='tsv': 24 | df = pd.read_csv(EXP_DIR/fn, sep='\t', index_col=0) 25 | try: 26 | with open(df['fn_hist'][0]) as f: 27 | hist = eval(f.readlines()[0] ) 28 | df['hist'] = [hist] 29 | df['fn'] = fn 30 | except: 31 | print('err') 32 | #print(df['fn_hist']) 33 | dfs.append( df ) 34 | df = pd.concat(dfs,ignore_index=True) 35 | return df 36 | 37 | def get_x(k, kw, operation='max', index=None): 38 | operation = getattr(np,operation) 39 | try: 40 | if index is not None: 41 | return k[kw][index] 42 | 43 | return operation(k[kw]) 44 | except: 45 | return 0 46 | 47 | def get_min_val_loss_idx(k): 48 | return get_x(k, 'loss_valid', 'argmin') #changed from argmax to argmin!! 
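# --- Illustrative usage (editor's sketch, not part of the original module) ---
# A minimal example of how the helpers above can be combined to read off
# early-stopping metrics from a single run. It assumes the training-history
# layout written by mhnreact.train and loaded via load_experiments(), i.e.
# per-epoch lists under 'loss_valid' and 't100_acc_nte_<k>' (the same keys
# used by get_stats_from_hist below); the variable names are hypothetical.
#
#   df = load_experiments()                      # one row per experiment
#   hist = df['hist'].iloc[0]                    # history dict of the first run
#   best_epoch = get_min_val_loss_idx(hist)      # epoch with lowest validation loss
#   zero_shot = get_x(hist, 't100_acc_nte_0', index=best_epoch)  # top-100 acc, unseen templates
#   many_shot = get_x(hist, 't100_acc_nte_>49')  # default operation: max over all epochs
# ------------------------------------------------------------------------------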
49 | 50 | def get_tauc(hist): 51 | idx = get_min_val_loss_idx(hist) 52 | # takes max TODO take idx 53 | return np.mean([get_x(hist, f't100_acc_nte_{nt}') for nt in [*range(11),'>10']]) 54 | 55 | def get_stats_from_hist(df): 56 | df['0shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_0')) 57 | df['1shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_1')) 58 | df['>49shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_>49')) 59 | df['min_loss_valid'] = df['hist'].apply(lambda k: get_x(k, 'loss_valid', 'min')) 60 | return df -------------------------------------------------------------------------------- /notebooks/02_prepro_uspto_50k.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%cd .." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# Prepare USPTO-50k for single-step retrosynthesis" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/html": [ 27 | "
\n", 28 | "\n", 41 | "\n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | "
classidprod_smilesrxn_smilesprod_smiles_popkeep
105211US20100317582A1C[C@@H](NC1CCCC(c2ccccn2)C1)c1cccc2ccccc12O=[C:1]1[CH2:2][CH2:3][CH2:4][CH:5]([c:6]2[cH:...1True
254921US05932582Cc1cc(OCC(N)=O)ccc1NC(=O)OC(C)(C)CBr[CH2:1][C:2]([NH2:3])=[O:4].[CH3:5][c:6]1[cH...1True
469977US05266570Nc1cnc(NC2CCN(CC34CC(c5ccccc53)c3ccccc34)CC2)nc1NO=[N+:1]([O-])[c:2]1[c:3]([NH2:4])[n:5][c:6]([...1True
155806US20080181866A1O=C(O)C1CS[C@H](C2CCCNC2)N1C(=O)OCc1ccccc1CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][CH2:4][CH:5...1True
69617US07229987B2CN(C)CCCOc1ccc(N)cc1O=[N+:1]([O-])[c:2]1[cH:3][cH:4][c:5]([O:6][CH...1True
\n", 101 | "
" 102 | ], 103 | "text/plain": [ 104 | " class id \\\n", 105 | "10521 1 US20100317582A1 \n", 106 | "25492 1 US05932582 \n", 107 | "46997 7 US05266570 \n", 108 | "15580 6 US20080181866A1 \n", 109 | "6961 7 US07229987B2 \n", 110 | "\n", 111 | " prod_smiles \\\n", 112 | "10521 C[C@@H](NC1CCCC(c2ccccn2)C1)c1cccc2ccccc12 \n", 113 | "25492 Cc1cc(OCC(N)=O)ccc1NC(=O)OC(C)(C)C \n", 114 | "46997 Nc1cnc(NC2CCN(CC34CC(c5ccccc53)c3ccccc34)CC2)nc1N \n", 115 | "15580 O=C(O)C1CS[C@H](C2CCCNC2)N1C(=O)OCc1ccccc1 \n", 116 | "6961 CN(C)CCCOc1ccc(N)cc1 \n", 117 | "\n", 118 | " rxn_smiles prod_smiles_pop \\\n", 119 | "10521 O=[C:1]1[CH2:2][CH2:3][CH2:4][CH:5]([c:6]2[cH:... 1 \n", 120 | "25492 Br[CH2:1][C:2]([NH2:3])=[O:4].[CH3:5][c:6]1[cH... 1 \n", 121 | "46997 O=[N+:1]([O-])[c:2]1[c:3]([NH2:4])[n:5][c:6]([... 1 \n", 122 | "15580 CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][CH2:4][CH:5... 1 \n", 123 | "6961 O=[N+:1]([O-])[c:2]1[cH:3][cH:4][c:5]([O:6][CH... 1 \n", 124 | "\n", 125 | " keep \n", 126 | "10521 True \n", 127 | "25492 True \n", 128 | "46997 True \n", 129 | "15580 True \n", 130 | "6961 True " 131 | ] 132 | }, 133 | "execution_count": 2, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "# load the dataset\n", 140 | "import pandas as pd\n", 141 | "# or can be downloaded here: https://github.com/connorcoley/retrosim/raw/master/retrosim/data/data_processed.csv\n", 142 | "df = pd.read_csv('https://github.com/connorcoley/retrosim/raw/master/retrosim/data/data_processed.csv', index_col=0)\n", 143 | "df = df.sample(len(df), random_state=42) # shuffle the dataset\n", 144 | "df.head()" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 3, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "import numpy as np\n", 154 | "def split_data_df(data, val_frac=0.1, test_frac=0.1, shuffle=False, seed=None):\n", 155 | " \"\"\"edited from https://github.com/connorcoley/retrosim/blob/master/retrosim/data/get_data.py\"\"\"\n", 156 | " # Define shuffling\n", 157 | " if shuffle:\n", 158 | " if seed is None:\n", 159 | " np.random.seed(int(time.time()))\n", 160 | " else:\n", 161 | " np.random.seed(seed)\n", 162 | " def shuffle_func(x):\n", 163 | " np.random.shuffle(x)\n", 164 | " else:\n", 165 | " def shuffle_func(x):\n", 166 | " pass\n", 167 | "\n", 168 | " # Go through each class\n", 169 | " classes = sorted(np.unique(data['class']))\n", 170 | " for class_ in classes:\n", 171 | " indeces = data.loc[data['class'] == class_].index\n", 172 | " N = len(indeces)\n", 173 | " print('{} rows with class value {}'.format(N, class_))\n", 174 | "\n", 175 | " shuffle_func(indeces)\n", 176 | " train_end = int((1.0 - val_frac - test_frac) * N)\n", 177 | " val_end = int((1.0 - test_frac) * N)\n", 178 | "\n", 179 | " for i in indeces[:train_end]:\n", 180 | " data.at[i, 'dataset'] = 'train'\n", 181 | " for i in indeces[train_end:val_end]:\n", 182 | " data.at[i, 'dataset'] = 'valid'\n", 183 | " for i in indeces[val_end:]:\n", 184 | " data.at[i, 'dataset'] = 'test'\n", 185 | " print(data['dataset'].value_counts())" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 4, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "15151 rows with class value 1\n", 198 | "11896 rows with class value 2\n", 199 | "5662 rows with class value 3\n", 200 | "909 rows with class value 4\n", 201 | "672 rows with class value 5\n", 202 | "8237 rows with class value 6\n", 203 | "4614 rows 
with class value 7\n", 204 | "811 rows with class value 8\n", 205 | "1834 rows with class value 9\n", 206 | "230 rows with class value 10\n", 207 | "train 40008\n", 208 | "test 5007\n", 209 | "valid 5001\n", 210 | "Name: dataset, dtype: int64\n" 211 | ] 212 | }, 213 | { 214 | "data": { 215 | "text/html": [ 216 | "
\n", 217 | "\n", 230 | "\n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | "
classidprod_smilesrxn_smilesprod_smiles_popkeepdataset
105211US20100317582A1C[C@@H](NC1CCCC(c2ccccn2)C1)c1cccc2ccccc12O=[C:1]1[CH2:2][CH2:3][CH2:4][CH:5]([c:6]2[cH:...1Truetrain
254921US05932582Cc1cc(OCC(N)=O)ccc1NC(=O)OC(C)(C)CBr[CH2:1][C:2]([NH2:3])=[O:4].[CH3:5][c:6]1[cH...1Truetrain
469977US05266570Nc1cnc(NC2CCN(CC34CC(c5ccccc53)c3ccccc34)CC2)nc1NO=[N+:1]([O-])[c:2]1[c:3]([NH2:4])[n:5][c:6]([...1Truetrain
155806US20080181866A1O=C(O)C1CS[C@H](C2CCCNC2)N1C(=O)OCc1ccccc1CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][CH2:4][CH:5...1Truetrain
69617US07229987B2CN(C)CCCOc1ccc(N)cc1O=[N+:1]([O-])[c:2]1[cH:3][cH:4][c:5]([O:6][CH...1Truetrain
\n", 296 | "
" 297 | ], 298 | "text/plain": [ 299 | " class id \\\n", 300 | "10521 1 US20100317582A1 \n", 301 | "25492 1 US05932582 \n", 302 | "46997 7 US05266570 \n", 303 | "15580 6 US20080181866A1 \n", 304 | "6961 7 US07229987B2 \n", 305 | "\n", 306 | " prod_smiles \\\n", 307 | "10521 C[C@@H](NC1CCCC(c2ccccn2)C1)c1cccc2ccccc12 \n", 308 | "25492 Cc1cc(OCC(N)=O)ccc1NC(=O)OC(C)(C)C \n", 309 | "46997 Nc1cnc(NC2CCN(CC34CC(c5ccccc53)c3ccccc34)CC2)nc1N \n", 310 | "15580 O=C(O)C1CS[C@H](C2CCCNC2)N1C(=O)OCc1ccccc1 \n", 311 | "6961 CN(C)CCCOc1ccc(N)cc1 \n", 312 | "\n", 313 | " rxn_smiles prod_smiles_pop \\\n", 314 | "10521 O=[C:1]1[CH2:2][CH2:3][CH2:4][CH:5]([c:6]2[cH:... 1 \n", 315 | "25492 Br[CH2:1][C:2]([NH2:3])=[O:4].[CH3:5][c:6]1[cH... 1 \n", 316 | "46997 O=[N+:1]([O-])[c:2]1[c:3]([NH2:4])[n:5][c:6]([... 1 \n", 317 | "15580 CC(C)(C)OC(=O)[N:1]1[CH2:2][CH2:3][CH2:4][CH:5... 1 \n", 318 | "6961 O=[N+:1]([O-])[c:2]1[cH:3][cH:4][c:5]([O:6][CH... 1 \n", 319 | "\n", 320 | " keep dataset \n", 321 | "10521 True train \n", 322 | "25492 True train \n", 323 | "46997 True train \n", 324 | "15580 True train \n", 325 | "6961 True train " 326 | ] 327 | }, 328 | "execution_count": 4, 329 | "metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "# split the data just like in retrosim (don't know the seed though) # shuffle throws error\n", 335 | "#from retrosim.data.get_data import split_data_df\n", 336 | "split_data_df(df, shuffle=False, seed=42) # 80/10/10 within each class\n", 337 | "df.head()" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 5, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "import hashlib\n", 347 | "def create_hash(pd_row):\n", 348 | " return hashlib.md5(pd_row.to_json().encode()).hexdigest()\n", 349 | "\n", 350 | "if '_id' not in df.columns:\n", 351 | " df['_id'] = df.apply(create_hash, axis=1)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 6, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "df.rename(columns={'rxn_smiles':'reaction_smiles'}, inplace=True)\n", 361 | "df.rename(columns={'dataset':'split'}, inplace=True)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 7, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "reactants, spectators, products = list(zip(*[s.split('>') for s in df['reaction_smiles']]))\n", 371 | "df['reactants'] = reactants\n", 372 | "df['spectators'] = spectators\n", 373 | "df['products'] = products" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 8, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "# extract templates\n", 383 | "\n", 384 | "from multiprocessing import Pool\n", 385 | "from rdchiral.template_extractor import extract_from_reaction\n", 386 | "\n", 387 | "reaction_dicts = [row.to_dict() for i, row in df.iterrows()]\n", 388 | "with Pool(32) as pool:\n", 389 | " res = pool.map(extract_from_reaction, reaction_dicts)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 9, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "assert list(df._id) == [r['reaction_id'] for r in res]\n", 399 | "reaction_smarts = [r['reaction_smarts'] for r in res]\n", 400 | "df['reaction_smarts'] = reaction_smarts" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 10, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "# canonicalize reactant (optionally 
product_can_from_reaction)\n", 410 | "from mhnreact.retroeval import canonicalize_reactants\n", 411 | "df['reactants_can'] = [canonicalize_reactants(r, can_steps=2) for r in df['reactants']]" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 11, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "def filter_by_dict(df, fil):\n", 421 | " for col, value in fil.items():\n", 422 | " if not isinstance(value, list):\n", 423 | " value = [value]\n", 424 | " df = df[df[col].isin(value)]\n", 425 | " return df" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 12, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "import re\n", 435 | "\n", 436 | "mre = ':\\d+(?=])'\n", 437 | "unmapped = [re.sub(mre,'',r) for r in df['reaction_smarts']]\n", 438 | "df['unmapped_template'] = unmapped\n", 439 | "\n", 440 | "unmapped2idx = {}\n", 441 | "labels = []\n", 442 | "for split in ['train', 'valid', 'test']:\n", 443 | " sub = filter_by_dict(df, {'split': split})\n", 444 | " for u in sub['unmapped_template']:\n", 445 | " if u not in unmapped2idx:\n", 446 | " label = len(unmapped2idx)\n", 447 | " unmapped2idx[u] = label\n", 448 | " \n", 449 | "df['label'] = [unmapped2idx[u] for u in df['unmapped_template']]" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 13, 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "data": { 459 | "text/plain": [ 460 | "class 1\n", 461 | "id US20100317582A1\n", 462 | "prod_smiles C[C@@H](NC1CCCC(c2ccccn2)C1)c1cccc2ccccc12\n", 463 | "reaction_smiles O=[C:1]1[CH2:2][CH2:3][CH2:4][CH:5]([c:6]2[cH:...\n", 464 | "prod_smiles_pop 1\n", 465 | "keep True\n", 466 | "split train\n", 467 | "_id 20f7e94253448bd4e0c5edb1c421eea0\n", 468 | "reactants O=[C:1]1[CH2:2][CH2:3][CH2:4][CH:5]([c:6]2[cH:...\n", 469 | "spectators \n", 470 | "products [CH:1]1([NH:15][C@H:14]([CH3:13])[c:16]2[cH:17...\n", 471 | "reaction_smarts [C:2]-[CH;D3;+0:1](-[C:3])-[NH;D2;+0:5]-[C:4]>...\n", 472 | "reactants_can C[C@@H](N)c1cccc2ccccc12.O=C1CCCC(c2ccccn2)C1\n", 473 | "unmapped_template [C]-[CH;D3;+0](-[C])-[NH;D2;+0]-[C]>>O=[C;H0;D...\n", 474 | "label 0\n", 475 | "Name: 10521, dtype: object" 476 | ] 477 | }, 478 | "execution_count": 13, 479 | "metadata": {}, 480 | "output_type": "execute_result" 481 | } 482 | ], 483 | "source": [ 484 | "df.iloc[0]" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 14, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "# all the relevant data is now in here\n", 494 | "df_rel = df[['id','class','prod_smiles','reactants_can','split', 'reaction_smarts', 'label']]" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 15, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "df_rel.to_csv('./data/USPTO_50k_MHN_prepro_recre.csv.gz') # note it's not the same file-size due to e.g. 
time-split is missing" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": {}, 510 | "outputs": [], 511 | "source": [ 512 | "#df.to_csv('./data/USPTO_50k_MHN_prepro_allcol.csv') # all columns" 513 | ] 514 | } 515 | ], 516 | "metadata": { 517 | "kernelspec": { 518 | "display_name": "Python 3", 519 | "language": "python", 520 | "name": "python3" 521 | }, 522 | "language_info": { 523 | "codemirror_mode": { 524 | "name": "ipython", 525 | "version": 3 526 | }, 527 | "file_extension": ".py", 528 | "mimetype": "text/x-python", 529 | "name": "python", 530 | "nbconvert_exporter": "python", 531 | "pygments_lexer": "ipython3", 532 | "version": "3.7.6" 533 | } 534 | }, 535 | "nbformat": 4, 536 | "nbformat_minor": 4 537 | } 538 | -------------------------------------------------------------------------------- /notebooks/11_training_template_relevance_prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%cd .." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# template relevance prediction training\n", 17 | "for saving the model to ```./data/model/``` add ```--save_model True```\n", 18 | "\n", 19 | "for further details call ```python -m mhnreact.train -h```" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "### train DNN for template relevance prediction" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "seeded with 0\n", 39 | "train 29816 samples ( 9161 max label)\n", 40 | "valid 4482 samples ( 9157 max label)\n", 41 | "test 5959 samples ( 9145 max label)\n", 42 | "[610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]\n", 43 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 9162, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cuda:1', 'batch_size': 256, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.15, 'lr': 0.0005, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 2048, 'hopf_association_activation': 'None', 'hopf_beta': 0.05, 'norm_input': False, 'norm_asso': False, 'hopf_n_layers': 1, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'normalize': True, 'norm_affine': True, 'hopf_pooling_operation_head': 'mean'}\n", 44 | "117it [00:02, 47.77it/s]\n", 45 | " 0 -- train_loss: 6.897, loss_valid: 5.957, val_t1acc: 0.240, val_t100acc: 0.567\n", 46 | "117it [00:02, 48.01it/s]\n", 47 | " 1 -- train_loss: 4.298, loss_valid: 4.945, val_t1acc: 0.317, val_t100acc: 0.703\n", 48 | "117it [00:02, 48.17it/s]\n", 49 | " 2 -- train_loss: 2.836, loss_valid: 4.391, val_t1acc: 0.380, val_t100acc: 0.766\n", 50 | "117it [00:02, 47.86it/s]\n", 51 | " 3 -- train_loss: 1.671, loss_valid: 4.200, val_t1acc: 0.407, val_t100acc: 0.786\n", 52 | "117it [00:02, 47.91it/s]\n", 53 | " 4 -- train_loss: 0.818, loss_valid: 4.217, val_t1acc: 0.419, val_t100acc: 0.788\n", 54 | "117it [00:02, 47.89it/s]\n", 55 | " 5 -- train_loss: 0.429, loss_valid: 4.302, val_t1acc: 0.420, val_t100acc: 0.792\n", 56 | "117it [00:02, 47.85it/s]\n", 57 | " 6 -- train_loss: 0.277, 
loss_valid: 4.421, val_t1acc: 0.420, val_t100acc: 0.792\n", 58 | "117it [00:02, 48.07it/s]\n", 59 | " 7 -- train_loss: 0.196, loss_valid: 4.528, val_t1acc: 0.420, val_t100acc: 0.792\n", 60 | "117it [00:02, 48.12it/s]\n", 61 | " 8 -- train_loss: 0.153, loss_valid: 4.582, val_t1acc: 0.424, val_t100acc: 0.791\n", 62 | "117it [00:02, 48.32it/s]\n", 63 | " 9 -- train_loss: 0.123, loss_valid: 4.656, val_t1acc: 0.426, val_t100acc: 0.792\n", 64 | "saving predictions to ./data/preds/USPTO_sm_test_segler_rerun_1632299812.npy\n", 65 | "model saved to data/model/USPTO_sm_segler_valloss4.656_rerun_1632299812.pt\n", 66 | "4.19974672794342\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "#sm_DNN_no_test\n", 72 | "!python -m mhnreact.train --device cuda:1 --dataset_type sm --exp_name rerun --model_type segler \\\n", 73 | "--pretrain_epochs 0 --epochs 10 --hopf_asso_dim 2048 --fp_type morgan --fp_size 4096 --dropout 0.15 \\\n", 74 | "--lr 0.0005 --mol_encoder_layers 1 --batch_size 256 --save_preds True --save_model True \\\n", 75 | "--seed 0 --fp_radius 2" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### pretraining on applicability" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "seeded with 0\n", 95 | "train 29816 samples ( 9161 max label)\n", 96 | "valid 4482 samples ( 9157 max label)\n", 97 | "test 5959 samples ( 9145 max label)\n", 98 | "[610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]\n", 99 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 9162, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cuda:1', 'batch_size': 256, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.15, 'lr': 0.0005, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 2048, 'hopf_association_activation': 'None', 'hopf_beta': 0.05, 'norm_input': False, 'norm_asso': False, 'hopf_n_layers': 1, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'normalize': True, 'norm_affine': True, 'hopf_pooling_operation_head': 'mean'}\n", 100 | "pretraining on applicability-matrix -- loading the matrix\n", 101 | "train 29816 samples ( 9162 max label)\n", 102 | "valid 4482 samples ( 9162 max label)\n", 103 | "test 5959 samples ( 9162 max label)\n", 104 | "pre-training (BCE-loss)\n", 105 | "117it [00:05, 22.39it/s]\n", 106 | " 0 -- train_loss: 0.094, loss_valid: 0.042, train_acc: 0.99024\n", 107 | "117it [00:05, 22.41it/s]\n", 108 | " 1 -- train_loss: 0.040, loss_valid: 0.040, train_acc: 0.99024\n", 109 | "117it [00:05, 22.40it/s]\n", 110 | " 2 -- train_loss: 0.039, loss_valid: 0.040, train_acc: 0.99024\n", 111 | "117it [00:05, 22.51it/s]\n", 112 | " 3 -- train_loss: 0.038, loss_valid: 0.039, train_acc: 0.99024\n", 113 | "117it [00:05, 22.52it/s]\n", 114 | " 4 -- train_loss: 0.038, loss_valid: 0.039, train_acc: 0.99024\n", 115 | "117it [00:05, 22.57it/s]\n", 116 | " 5 -- train_loss: 0.038, loss_valid: 0.039, train_acc: 0.99024\n", 117 | "117it [00:05, 22.44it/s]\n", 118 | " 6 -- train_loss: 0.037, loss_valid: 0.040, train_acc: 0.99015\n", 119 | "117it [00:05, 22.55it/s]\n", 120 | " 7 -- train_loss: 0.037, loss_valid: 0.040, train_acc: 0.99014\n", 121 | "117it [00:05, 22.55it/s]\n", 122 | " 8 -- train_loss: 0.036, loss_valid: 0.040, train_acc: 
0.99011\n", 123 | "117it [00:05, 22.53it/s]\n", 124 | " 9 -- train_loss: 0.036, loss_valid: 0.040, train_acc: 0.99004\n", 125 | "117it [00:05, 22.51it/s]\n", 126 | "10 -- train_loss: 0.035, loss_valid: 0.041, train_acc: 0.99001\n", 127 | "117it [00:05, 22.47it/s]\n", 128 | "11 -- train_loss: 0.035, loss_valid: 0.041, train_acc: 0.99003\n", 129 | "117it [00:05, 22.48it/s]\n", 130 | "12 -- train_loss: 0.034, loss_valid: 0.042, train_acc: 0.98983\n", 131 | "117it [00:05, 22.58it/s]\n", 132 | "13 -- train_loss: 0.034, loss_valid: 0.042, train_acc: 0.98976\n", 133 | "117it [00:05, 22.56it/s]\n", 134 | "14 -- train_loss: 0.033, loss_valid: 0.042, train_acc: 0.98975\n", 135 | "117it [00:05, 22.50it/s]\n", 136 | "15 -- train_loss: 0.033, loss_valid: 0.043, train_acc: 0.98972\n", 137 | "117it [00:05, 22.40it/s]\n", 138 | "16 -- train_loss: 0.032, loss_valid: 0.043, train_acc: 0.98960\n", 139 | "117it [00:05, 22.48it/s]\n", 140 | "17 -- train_loss: 0.032, loss_valid: 0.044, train_acc: 0.98949\n", 141 | "117it [00:05, 22.45it/s]\n", 142 | "18 -- train_loss: 0.031, loss_valid: 0.044, train_acc: 0.98946\n", 143 | "117it [00:05, 22.52it/s]\n", 144 | "19 -- train_loss: 0.031, loss_valid: 0.044, train_acc: 0.98944\n", 145 | "117it [00:05, 22.52it/s]\n", 146 | "20 -- train_loss: 0.030, loss_valid: 0.045, train_acc: 0.98935\n", 147 | "117it [00:05, 22.45it/s]\n", 148 | "21 -- train_loss: 0.030, loss_valid: 0.045, train_acc: 0.98934\n", 149 | "117it [00:05, 22.48it/s]\n", 150 | "22 -- train_loss: 0.029, loss_valid: 0.046, train_acc: 0.98921\n", 151 | "117it [00:05, 22.47it/s]\n", 152 | "23 -- train_loss: 0.029, loss_valid: 0.046, train_acc: 0.98904\n", 153 | "117it [00:05, 22.43it/s]\n", 154 | "24 -- train_loss: 0.029, loss_valid: 0.047, train_acc: 0.98922\n", 155 | "117it [00:02, 47.88it/s]\n", 156 | " 0 -- train_loss: 5.154, loss_valid: 4.276, val_t1acc: 0.369, val_t100acc: 0.756\n", 157 | "117it [00:02, 48.02it/s]\n", 158 | " 1 -- train_loss: 1.529, loss_valid: 3.946, val_t1acc: 0.418, val_t100acc: 0.805\n", 159 | "117it [00:02, 48.17it/s]\n", 160 | " 2 -- train_loss: 0.439, loss_valid: 4.059, val_t1acc: 0.417, val_t100acc: 0.804\n", 161 | "117it [00:02, 47.74it/s]\n", 162 | " 3 -- train_loss: 0.214, loss_valid: 4.160, val_t1acc: 0.425, val_t100acc: 0.804\n", 163 | "117it [00:02, 47.94it/s]\n", 164 | " 4 -- train_loss: 0.147, loss_valid: 4.227, val_t1acc: 0.425, val_t100acc: 0.803\n", 165 | "117it [00:02, 48.18it/s]\n", 166 | " 5 -- train_loss: 0.110, loss_valid: 4.272, val_t1acc: 0.422, val_t100acc: 0.805\n", 167 | "117it [00:02, 47.79it/s]\n", 168 | " 6 -- train_loss: 0.091, loss_valid: 4.350, val_t1acc: 0.420, val_t100acc: 0.803\n", 169 | "117it [00:02, 48.11it/s]\n", 170 | " 7 -- train_loss: 0.080, loss_valid: 4.360, val_t1acc: 0.425, val_t100acc: 0.803\n", 171 | "117it [00:02, 48.32it/s]\n", 172 | " 8 -- train_loss: 0.069, loss_valid: 4.425, val_t1acc: 0.422, val_t100acc: 0.803\n", 173 | "117it [00:02, 48.22it/s]\n", 174 | " 9 -- train_loss: 0.062, loss_valid: 4.404, val_t1acc: 0.423, val_t100acc: 0.804\n", 175 | "saving predictions to ./data/preds/USPTO_sm_test_fortunato_rerun_1632299982.npy\n", 176 | "model saved to data/model/USPTO_sm_fortunato_valloss4.404_rerun_1632299982.pt\n", 177 | "3.945611251725091\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "#sm_DNN_yes_test\n", 183 | "!python -m mhnreact.train --device cuda:1 --dataset_type sm --exp_name rerun \\\n", 184 | "--model_type fortunato --pretrain_epochs 25 --epochs 10 --hopf_asso_dim 2048 \\\n", 185 | "--fp_type morgan --fp_size 4096 
--dropout 0.15 --lr 0.0005 \\\n", 186 | "--mol_encoder_layers 1 --batch_size 256 --save_preds True --save_model True \\\n", 187 | "--exp_name=rerun --seed 0 --fp_radius 2" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "### train MHN for template relevance prediction" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 13, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "seeded with 0\n", 207 | "train 29816 samples ( 9161 max label)\n", 208 | "valid 4482 samples ( 9157 max label)\n", 209 | "test 5959 samples ( 9145 max label)\n", 210 | "[610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]\n", 211 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 9162, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cuda:1', 'batch_size': 1024, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.2, 'lr': 0.001, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 1024, 'hopf_association_activation': 'None', 'hopf_beta': 0.03, 'norm_input': False, 'norm_asso': True, 'hopf_n_layers': 1, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'hopf_pooling_operation_head': 'mean'}\n", 212 | "loading tfp from file ./data/cache/templ_emb_4096_rdk2_9162_3083282594899452900587240289666014809109594710917180313564340596704210428539624892529205810214822540907625940879430504284296794943595967027539004835710563.npy\n", 213 | "Num of templates with added rand-vect of size 409 due to >=thresh (2): 1367\n", 214 | "(9162, 409) (9162,) 1367\n", 215 | "(9162, 4096) 9162\n", 216 | "adding template_matrix to params\n", 217 | "30it [00:02, 13.86it/s]\n", 218 | " 0 -- train_loss: 7.499, loss_valid: 5.818, val_t1acc: 0.129, val_t100acc: 0.654\n", 219 | "30it [00:02, 14.42it/s]\n", 220 | " 1 -- train_loss: 3.891, loss_valid: 3.622, val_t1acc: 0.339, val_t100acc: 0.866\n", 221 | "30it [00:02, 14.32it/s]\n", 222 | " 2 -- train_loss: 2.181, loss_valid: 3.097, val_t1acc: 0.397, val_t100acc: 0.908\n", 223 | "30it [00:02, 14.17it/s]\n", 224 | " 3 -- train_loss: 1.457, loss_valid: 2.941, val_t1acc: 0.419, val_t100acc: 0.917\n", 225 | "30it [00:02, 14.24it/s]\n", 226 | " 4 -- train_loss: 1.049, loss_valid: 2.884, val_t1acc: 0.429, val_t100acc: 0.921\n", 227 | "30it [00:02, 13.99it/s]\n", 228 | " 5 -- train_loss: 0.814, loss_valid: 2.881, val_t1acc: 0.436, val_t100acc: 0.923\n", 229 | "30it [00:02, 13.67it/s]\n", 230 | " 6 -- train_loss: 0.639, loss_valid: 2.902, val_t1acc: 0.436, val_t100acc: 0.920\n", 231 | "30it [00:02, 14.10it/s]\n", 232 | " 7 -- train_loss: 0.540, loss_valid: 2.923, val_t1acc: 0.436, val_t100acc: 0.918\n", 233 | "30it [00:02, 13.89it/s]\n", 234 | " 8 -- train_loss: 0.462, loss_valid: 2.952, val_t1acc: 0.440, val_t100acc: 0.918\n", 235 | "30it [00:02, 14.09it/s]\n", 236 | " 9 -- train_loss: 0.410, loss_valid: 2.974, val_t1acc: 0.444, val_t100acc: 0.915\n", 237 | "30it [00:02, 13.53it/s]\n", 238 | "10 -- train_loss: 0.363, loss_valid: 3.014, val_t1acc: 0.439, val_t100acc: 0.912\n", 239 | "30it [00:02, 13.84it/s]\n", 240 | "11 -- train_loss: 0.328, loss_valid: 3.079, val_t1acc: 0.437, val_t100acc: 0.911\n", 241 | "30it [00:02, 14.14it/s]\n", 242 | "12 -- train_loss: 0.305, loss_valid: 3.068, val_t1acc: 0.433, val_t100acc: 0.913\n", 243 | "30it [00:02, 13.98it/s]\n", 244 | 
"13 -- train_loss: 0.274, loss_valid: 3.104, val_t1acc: 0.439, val_t100acc: 0.911\n", 245 | "30it [00:02, 13.59it/s]\n", 246 | "14 -- train_loss: 0.263, loss_valid: 3.135, val_t1acc: 0.434, val_t100acc: 0.908\n", 247 | "30it [00:02, 14.02it/s]\n", 248 | "15 -- train_loss: 0.245, loss_valid: 3.138, val_t1acc: 0.436, val_t100acc: 0.905\n", 249 | "30it [00:02, 13.98it/s]\n", 250 | "16 -- train_loss: 0.242, loss_valid: 3.152, val_t1acc: 0.436, val_t100acc: 0.905\n", 251 | "30it [00:02, 13.90it/s]\n", 252 | "17 -- train_loss: 0.228, loss_valid: 3.170, val_t1acc: 0.436, val_t100acc: 0.904\n", 253 | "30it [00:02, 14.09it/s]\n", 254 | "18 -- train_loss: 0.216, loss_valid: 3.204, val_t1acc: 0.432, val_t100acc: 0.905\n", 255 | "30it [00:02, 13.99it/s]\n", 256 | "19 -- train_loss: 0.207, loss_valid: 3.237, val_t1acc: 0.433, val_t100acc: 0.900\n", 257 | "30it [00:02, 14.03it/s]\n", 258 | "20 -- train_loss: 0.202, loss_valid: 3.243, val_t1acc: 0.436, val_t100acc: 0.900\n", 259 | "30it [00:02, 13.91it/s]\n", 260 | "21 -- train_loss: 0.187, loss_valid: 3.248, val_t1acc: 0.436, val_t100acc: 0.900\n", 261 | "30it [00:02, 14.09it/s]\n", 262 | "22 -- train_loss: 0.185, loss_valid: 3.260, val_t1acc: 0.433, val_t100acc: 0.899\n", 263 | "30it [00:02, 13.92it/s]\n", 264 | "23 -- train_loss: 0.186, loss_valid: 3.293, val_t1acc: 0.437, val_t100acc: 0.895\n", 265 | "30it [00:02, 13.57it/s]\n", 266 | "24 -- train_loss: 0.182, loss_valid: 3.301, val_t1acc: 0.430, val_t100acc: 0.898\n", 267 | "30it [00:02, 13.86it/s]\n", 268 | "25 -- train_loss: 0.180, loss_valid: 3.288, val_t1acc: 0.431, val_t100acc: 0.894\n", 269 | "30it [00:02, 14.01it/s]\n", 270 | "26 -- train_loss: 0.179, loss_valid: 3.303, val_t1acc: 0.437, val_t100acc: 0.894\n", 271 | "30it [00:02, 13.97it/s]\n", 272 | "27 -- train_loss: 0.167, loss_valid: 3.329, val_t1acc: 0.433, val_t100acc: 0.891\n", 273 | "30it [00:02, 13.89it/s]\n", 274 | "28 -- train_loss: 0.163, loss_valid: 3.332, val_t1acc: 0.434, val_t100acc: 0.891\n", 275 | "30it [00:02, 14.04it/s]\n", 276 | "29 -- train_loss: 0.156, loss_valid: 3.333, val_t1acc: 0.434, val_t100acc: 0.895\n", 277 | "saving predictions to ./data/preds/USPTO_sm_test_mhn_rerun_1632318117.npy\n", 278 | "model saved to data/model/USPTO_sm_mhn_valloss3.333_rerun_1632318117.pt\n", 279 | "2.881380891799927\n" 280 | ] 281 | } 282 | ], 283 | "source": [ 284 | "#sm_MHN_no_test\n", 285 | "!python -m mhnreact.train --batch_size=1024 --concat_rand_template_thresh=2 --dataset_type=sm \\\n", 286 | "--device=cuda:1 --dropout=0.2 --epochs=30 --exp_name=rerun --fp_size=4096 --fp_type=morgan --hopf_asso_dim=1024 \\\n", 287 | "--hopf_association_activation=None --hopf_beta=0.03 --temp_encoder_layers=1 --mol_encoder_layers=1 \\\n", 288 | "--norm_asso=True --norm_input=False --hopf_num_heads=1 --lr=0.001 --model_type=mhn --save_preds True --save_model True \\\n", 289 | "--exp_name=rerun --seed 0 --fp_radius 2" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "## evaluation and loading in a trained model\n", 297 | "see ```notebooks/20_evaluation.ipynb```" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "# view experiments" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 10, 310 | "metadata": { 311 | "collapsed": false, 312 | "jupyter": { 313 | "outputs_hidden": false 314 | }, 315 | "pycharm": { 316 | "name": "#%%\n" 317 | } 318 | }, 319 | "outputs": [], 320 | "source": [ 321 | "import pandas as pd\n", 322 | 
"import os\n", 323 | "fldr = './data/experiments/'\n", 324 | "dfs = []\n", 325 | "for fn in os.listdir(fldr):\n", 326 | " if fn.split('.')[-1]=='tsv':\n", 327 | " dfs.append(pd.read_csv(fldr+fn,sep='\\t'))\n", 328 | "df = pd.concat(dfs)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "clms = ['model_type', 'dataset_type', 'min_loss_valid', 'max_t1_acc_valid', 'max_t100_acc_valid']\n", 338 | "df[clms]" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [] 347 | } 348 | ], 349 | "metadata": { 350 | "kernelspec": { 351 | "display_name": "Python 3", 352 | "language": "python", 353 | "name": "python3" 354 | }, 355 | "language_info": { 356 | "codemirror_mode": { 357 | "name": "ipython", 358 | "version": 3 359 | }, 360 | "file_extension": ".py", 361 | "mimetype": "text/x-python", 362 | "name": "python", 363 | "nbconvert_exporter": "python", 364 | "pygments_lexer": "ipython3", 365 | "version": "3.7.6" 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 4 370 | } 371 | -------------------------------------------------------------------------------- /notebooks/12_training_single_step_retrosynthesis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%cd .." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# train on single-step retrosynthesis\n", 17 | "for saving the model to ```./data/model/``` add ```--save_model True```\n", 18 | "\n", 19 | "for further details call ```python -m mhnreact.train -h```" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## mhn model" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 17, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "seeded with 0\n", 39 | "loading X, y from csv\n", 40 | "train 40008 samples\n", 41 | "test 5007 samples\n", 42 | "valid 5001 samples\n", 43 | "11800 templates\n", 44 | "obfuscating 706 templates because they are only in test\n", 45 | "[759, 355, 241, 210, 138, 132, 90, 99, 69, 49, 44, 2821, 1852]\n", 46 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 11800, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cuda:0', 'batch_size': 128, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.2, 'lr': 0.0005, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 512, 'hopf_association_activation': 'Tanh', 'hopf_beta': 0.05, 'norm_input': True, 'norm_asso': True, 'hopf_n_layers': 1, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'hopf_pooling_operation_head': 'mean'}\n", 47 | "loading tfp from file ./data/cache/templ_emb_4096_rdk2_11800_5854610437078579843841453284346972473875662921463140419388261575361415213389916933693998864075020914850195914610357939423215965314516020256039769991581398.npy\n", 48 | "Num of templates with added rand-vect of size 409 due to >=thresh (1): 10216\n", 49 | "(11800, 409) (11800,) 10216\n", 50 | "(11800, 4096) 11800\n", 51 | "adding template_matrix to params\n", 52 | 
"313it [00:09, 34.14it/s]\n", 53 | " 0 -- train_loss: 4.189, loss_valid: 3.733, val_t1acc: 0.339, val_t100acc: 0.855\n", 54 | "313it [00:09, 34.01it/s]\n", 55 | " 1 -- train_loss: 2.819, loss_valid: 3.297, val_t1acc: 0.389, val_t100acc: 0.889\n", 56 | "313it [00:09, 34.02it/s]\n", 57 | " 2 -- train_loss: 2.152, loss_valid: 3.155, val_t1acc: 0.402, val_t100acc: 0.898\n", 58 | "313it [00:09, 33.92it/s]\n", 59 | " 3 -- train_loss: 1.735, loss_valid: 3.083, val_t1acc: 0.413, val_t100acc: 0.905\n", 60 | "313it [00:09, 33.92it/s]\n", 61 | " 4 -- train_loss: 1.473, loss_valid: 3.078, val_t1acc: 0.419, val_t100acc: 0.903\n", 62 | "313it [00:09, 33.99it/s]\n", 63 | " 5 -- train_loss: 1.290, loss_valid: 3.052, val_t1acc: 0.420, val_t100acc: 0.903\n", 64 | "313it [00:09, 34.12it/s]\n", 65 | " 6 -- train_loss: 1.142, loss_valid: 3.070, val_t1acc: 0.418, val_t100acc: 0.906\n", 66 | "313it [00:09, 34.25it/s]\n", 67 | " 7 -- train_loss: 1.045, loss_valid: 3.091, val_t1acc: 0.421, val_t100acc: 0.900\n", 68 | "313it [00:09, 34.17it/s]\n", 69 | " 8 -- train_loss: 0.961, loss_valid: 3.084, val_t1acc: 0.419, val_t100acc: 0.907\n", 70 | "313it [00:09, 34.35it/s]\n", 71 | " 9 -- train_loss: 0.891, loss_valid: 3.120, val_t1acc: 0.414, val_t100acc: 0.901\n", 72 | "3.0523353099822996\n", 73 | "testing on the real test set ;)\n", 74 | "execute all templates\n", 75 | "59082600 441874 0.007478919343427676\n", 76 | "len(X_fp[test]): 5007\n", 77 | "running the templates\n", 78 | "100%|███████████████████████████████████████| 5007/5007 [01:57<00:00, 42.49it/s]\n", 79 | "Single-step retrosynthesis-evaluation, results on ttest:\n", 80 | "t1acc\tt2acc\tt3acc\tt5acc\tt10acc\tt20acc\tt50acc\tt100acc\tt101acc\t\n", 81 | "48.21\t63.99\t71.44\t79.19\t86.40\t91.09\t93.75\t94.55\t94.71\t" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "!python -m mhnreact.train --model_type=mhn --device=best --fp_size=4096 --fp_type morgan --template_fp_type rdk --concat_rand_template_thresh 1 \\\n", 87 | "--exp_name rerun --dataset_type 50k --csv_path ./data/USPTO_50k_MHN_prepro.csv.gz --split_col split --ssretroeval True --seed 0 \\\n", 88 | "--hopf_association_activation Tanh --fp_radius 2" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "t1acc corresponds to top-1 exact match accuracy on the test set after applying the predicted templates" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## dnn model" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 23, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "seeded with 0\n", 115 | "loading X, y from csv\n", 116 | "train 40008 samples\n", 117 | "test 5007 samples\n", 118 | "valid 5001 samples\n", 119 | "11800 templates\n", 120 | "obfuscating 706 templates because they are only in test\n", 121 | "[759, 355, 241, 210, 138, 132, 90, 99, 69, 49, 44, 2821, 1852]\n", 122 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 11800, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cuda:0', 'batch_size': 512, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.2, 'lr': 0.001, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 4096, 'hopf_association_activation': 'None', 'hopf_beta': 0.05, 'norm_input': True, 'norm_asso': True, 'hopf_n_layers': 1, 
'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'SELU', 'hopf_pooling_operation_head': 'mean'}\n", 123 | "79it [00:05, 14.77it/s]\n", 124 | " 0 -- train_loss: 5.737, loss_valid: 4.977, val_t1acc: 0.342, val_t100acc: 0.702\n", 125 | "79it [00:05, 15.00it/s]\n", 126 | " 1 -- train_loss: 0.850, loss_valid: 5.933, val_t1acc: 0.343, val_t100acc: 0.740\n", 127 | "79it [00:05, 14.88it/s]\n", 128 | " 2 -- train_loss: 0.494, loss_valid: 6.581, val_t1acc: 0.349, val_t100acc: 0.742\n", 129 | "79it [00:05, 14.85it/s]\n", 130 | " 3 -- train_loss: 0.291, loss_valid: 6.869, val_t1acc: 0.360, val_t100acc: 0.742\n", 131 | "79it [00:05, 14.80it/s]\n", 132 | " 4 -- train_loss: 0.201, loss_valid: 6.947, val_t1acc: 0.361, val_t100acc: 0.740\n", 133 | "79it [00:05, 14.71it/s]\n", 134 | " 5 -- train_loss: 0.152, loss_valid: 7.239, val_t1acc: 0.368, val_t100acc: 0.738\n", 135 | "79it [00:05, 14.69it/s]\n", 136 | " 6 -- train_loss: 0.113, loss_valid: 7.076, val_t1acc: 0.362, val_t100acc: 0.740\n", 137 | "4.977378702163696\n", 138 | "testing on the real test set ;)\n", 139 | "execute all templates\n", 140 | "59082600 441874 0.007478919343427676\n", 141 | "len(X_fp[test]): 5007\n", 142 | "running the templates\n", 143 | "100%|███████████████████████████████████████| 5007/5007 [01:57<00:00, 42.68it/s]\n", 144 | "Single-step retrosynthesis-evaluation, results on ttest:\n", 145 | "t1acc\tt2acc\tt3acc\tt5acc\tt10acc\tt20acc\tt50acc\tt100acc\tt101acc\t\n", 146 | "44.68\t59.72\t67.25\t74.96\t82.70\t88.48\t92.83\t94.37\t94.71\t" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "!python -m mhnreact.train --model_type=segler --device=best --fp_size=4096 --fp_type morgan --fp_radius 2 --batch_size=512 --encoder_af SELU --epochs 7 --lr 1e-3 --hopf_asso_dim 4096 \\\n", 152 | "--exp_name rerun --dataset_type 50k --csv_path ./data/USPTO_50k_MHN_prepro.csv.gz --split_col split --ssretroeval True --seed 0" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "# final selected parameters for mhn" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 32, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "seeded with 0\n", 172 | "loading X, y from csv\n", 173 | "train 40008 samples\n", 174 | "valid 5001 samples\n", 175 | "test 5007 samples\n", 176 | "11800 templates\n", 177 | "obfuscating 706 templates because they are only in test\n", 178 | "adding val to train\n", 179 | "[725, 333, 241, 205, 142, 108, 97, 86, 92, 53, 54, 2871, 1908]\n", 180 | "{'fingerprint_type': 'maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp', 'template_fp_type': 'rdkc+pattern+morganc+layered+atompair+erg+topologicaltorsion+mhfp', 'num_templates': 11800, 'fp_size': 30000, 'fp_radius': 2, 'device': 'cuda:1', 'batch_size': 1024, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.4, 'lr': 0.0001, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 30000, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 1024, 'hopf_association_activation': 'None', 'hopf_beta': 0.03, 'norm_input': True, 'norm_asso': True, 'hopf_n_layers': 2, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'hopf_pooling_operation_head': 'mean'}\n", 181 | "loading tfp from file 
./data/cache/templ_emb_30000_rdkc+pattern+morganc+layered+atompair+erg+topologicaltorsion+mhfp2_11800_5854610437078579843841453284346972473875662921463140419388261575361415213389916933693998864075020914850195914610357939423215965314516020256039769991581398.npy\n", 182 | "Num of templates with added rand-vect of size 3000 due to >=thresh (1): 11094\n", 183 | "(11800, 3000) (11800,) 11094\n", 184 | "(11800, 30000) 11800\n", 185 | "adding template_matrix to params\n", 186 | "44it [00:46, 1.06s/it]\n", 187 | "44it [00:47, 1.09s/it]\n", 188 | "44it [00:46, 1.07s/it]\n", 189 | "44it [00:46, 1.05s/it]\n", 190 | "44it [00:46, 1.05s/it]\n", 191 | "44it [00:45, 1.04s/it]\n", 192 | "44it [00:45, 1.03s/it]\n", 193 | "44it [00:45, 1.04s/it]\n", 194 | "44it [00:45, 1.04s/it]\n", 195 | "44it [00:45, 1.03s/it]\n", 196 | " 0 -- train_loss: 0.938, loss_valid: 0.284, val_t1acc: 0.918, val_t100acc: 1.000\n", 197 | "44it [00:45, 1.03s/it]\n", 198 | "44it [00:45, 1.03s/it]\n", 199 | "44it [00:45, 1.02s/it]\n", 200 | "44it [00:45, 1.03s/it]\n", 201 | "44it [00:45, 1.03s/it]\n", 202 | "44it [00:45, 1.03s/it]\n", 203 | "44it [00:45, 1.04s/it]\n", 204 | "44it [00:45, 1.03s/it]\n", 205 | "44it [00:45, 1.04s/it]\n", 206 | "44it [00:45, 1.04s/it]\n", 207 | " 1 -- train_loss: 0.355, loss_valid: 0.033, val_t1acc: 0.995, val_t100acc: 1.000\n", 208 | "44it [00:45, 1.03s/it]\n", 209 | "44it [00:45, 1.03s/it]\n", 210 | "44it [00:45, 1.05s/it]\n", 211 | "44it [00:45, 1.04s/it]\n", 212 | "44it [00:45, 1.04s/it]\n", 213 | "44it [00:45, 1.04s/it]\n", 214 | "44it [00:45, 1.04s/it]\n", 215 | "44it [00:45, 1.04s/it]\n", 216 | "44it [00:45, 1.04s/it]\n", 217 | "44it [00:46, 1.05s/it]\n", 218 | " 2 -- train_loss: 0.173, loss_valid: 0.010, val_t1acc: 0.998, val_t100acc: 1.000\n", 219 | "44it [00:45, 1.04s/it]\n", 220 | "44it [00:45, 1.04s/it]\n", 221 | "44it [00:45, 1.04s/it]\n", 222 | "44it [00:46, 1.05s/it]\n", 223 | "44it [00:45, 1.04s/it]\n", 224 | "44it [00:45, 1.04s/it]\n", 225 | "44it [00:46, 1.05s/it]\n", 226 | "44it [00:45, 1.04s/it]\n", 227 | "44it [00:45, 1.04s/it]\n", 228 | "44it [00:45, 1.04s/it]\n", 229 | " 3 -- train_loss: 0.112, loss_valid: 0.008, val_t1acc: 0.998, val_t100acc: 1.000\n", 230 | "model saved to data/model/USPTO_50k_mhn_valloss0.008_retro_selected_1632315904.pt\n", 231 | "0.007613665238022804\n", 232 | "testing on the real test set ;)\n", 233 | "execute all templates\n", 234 | "59082600 441874 0.007478919343427676\n", 235 | "len(X_fp[test]): 5007\n", 236 | "running the templates\n", 237 | "100%|███████████████████████████████████████| 5007/5007 [01:56<00:00, 43.15it/s]\n", 238 | "Single-step retrosynthesis-evaluation, results on ttest:\n", 239 | "t1acc\tt2acc\tt3acc\tt5acc\tt10acc\tt20acc\tt50acc\tt100acc\tt101acc\t\n", 240 | "51.45\t67.25\t74.16\t80.75\t87.92\t91.91\t93.95\t94.63\t94.71\t" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "!python -m mhnreact.train --batch_size=1024 --concat_rand_template_thresh=1 --device=cuda:1 --dropout=0.4 \\\n", 246 | "--epochs=40 --fp_size=30000 --fp_type=maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp --hopf_asso_dim=1024 --hopf_association_activation=None --hopf_beta=0.03 \\\n", 247 | "--norm_asso=True --norm_input=True --hopf_num_heads=1 --lr=0.0001 --model_type=mhn --template_fp_type rdkc+pattern+morganc+layered+atompair+erg+topologicaltorsion+mhfp \\\n", 248 | "--exp_name=retro_selected --reactant_pooling lgamma --hopf_n_layers 2 --layer2weight 0.1 --template_fp_type2 rdk \\\n", 249 | "--dataset_type 50k --csv_path 
./data/USPTO_50k_MHN_prepro.csv.gz --split_col split --ssretroeval True --seed 0 --eval_every_n_epochs 10 --addval2train True --save_model True" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "# time-split" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 31, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "seeded with 0\n", 269 | "loading X, y from csv\n", 270 | "nan 0 samples\n", 271 | "test 7315 samples\n", 272 | "valid 3517 samples\n", 273 | "train 38546 samples\n", 274 | "11800 templates\n", 275 | "obfuscating 1276 templates because they are only in test\n", 276 | "adding val to train\n", 277 | "0it [00:00, ?it/s]\n", 278 | "100%|██████████| 7315/7315 [00:03<00:00, 1847.76it/s] \n", 279 | "100%|██████████| 3517/3517 [00:03<00:00, 1001.60it/s] \n", 280 | "100%|██████████| 42063/42063 [00:23<00:00, 1808.62it/s] | 28360/42063 [00:16<00:00, 32294.16it/s]� | 31989/42063 [00:18<00:00, 33801.71it/s]\n", 281 | "[1528, 472, 332, 214, 172, 169, 119, 114, 102, 86, 83, 3924, 2719]\n", 282 | "{'fingerprint_type': 'maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp', 'template_fp_type': 'rdkc+pattern+morganc+layered+atompair+erg+topologicaltorsion+mhfp', 'num_templates': 11800, 'fp_size': 30000, 'fp_radius': 2, 'device': 'cuda:1', 'batch_size': 1024, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.4, 'lr': 0.0001, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 30000, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 1024, 'hopf_association_activation': 'None', 'hopf_beta': 0.03, 'norm_input': True, 'norm_asso': True, 'hopf_n_layers': 4, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'hopf_pooling_operation_head': 'mean'}\n", 283 | "updating template-embedding; (just computing the template-fingerprint and using that)\n", 284 | "100%|██████████| 11800/11800 [00:06<00:00, 1897.37it/s] 11800 [00:01<00:00, 32251.77it/s]\n", 285 | "100%|██████████| 11800/11800 [00:07<00:00, 1659.80it/s] \n", 286 | "Num of templates with added rand-vect of size 3000 due to >=thresh (1): 10413\n", 287 | "(11800, 3000) (11800,) 10413\n", 288 | "(11800, 30000) 11800\n", 289 | "adding template_matrix to params\n", 290 | "42it [01:27, 2.07s/it]\n", 291 | "42it [01:30, 2.15s/it]\n", 292 | "42it [01:27, 2.09s/it]\n", 293 | "42it [01:26, 2.06s/it]\n", 294 | "42it [01:26, 2.07s/it]\n", 295 | "42it [01:25, 2.05s/it]\n", 296 | "42it [01:26, 2.06s/it]\n", 297 | "42it [01:25, 2.05s/it]\n", 298 | "42it [01:25, 2.04s/it]\n", 299 | "42it [01:25, 2.04s/it]\n", 300 | " 0 -- train_loss: 1.056, loss_valid: 0.323, val_t1acc: 0.904, val_t100acc: 1.000\n", 301 | "42it [01:26, 2.05s/it]\n", 302 | "42it [01:26, 2.06s/it]\n", 303 | "42it [01:26, 2.06s/it]\n", 304 | "42it [01:26, 2.06s/it]\n", 305 | "42it [01:26, 2.07s/it]\n", 306 | "42it [01:26, 2.07s/it]\n", 307 | "42it [01:26, 2.06s/it]\n", 308 | "42it [01:26, 2.07s/it]\n", 309 | "42it [01:27, 2.07s/it]\n", 310 | "42it [01:26, 2.07s/it]valid\n", 311 | " 1 -- train_loss: 0.445, loss_valid: 0.067, val_t1acc: 0.986, val_t100acc: 1.000\n", 312 | "42it [01:26, 2.05s/it]\n", 313 | "42it [01:26, 2.07s/it]\n", 314 | "42it [01:27, 2.08s/it]\n", 315 | "42it [01:27, 2.08s/it]\n", 316 | "42it [01:27, 2.07s/it]\n", 317 | "42it [01:27, 2.08s/it]\n", 318 | "42it [01:27, 2.08s/it]\n", 319 | "42it [01:27, 
2.08s/it]\n", 320 | "42it [01:27, 2.08s/it]\n", 321 | "42it [01:27, 2.08s/it]valid\n", 322 | " 2 -- train_loss: 0.243, loss_valid: 0.020, val_t1acc: 0.997, val_t100acc: 1.000\n", 323 | "42it [01:26, 2.06s/it]\n", 324 | "42it [01:27, 2.07s/it]\n", 325 | "42it [01:27, 2.08s/it]\n", 326 | "42it [01:27, 2.08s/it]\n", 327 | "42it [01:27, 2.08s/it]\n", 328 | "42it [01:27, 2.08s/it]\n", 329 | "42it [01:27, 2.08s/it]\n", 330 | "42it [01:27, 2.08s/it]\n", 331 | "42it [01:27, 2.09s/it]\n", 332 | "42it [01:27, 2.08s/it]valid\n", 333 | " 3 -- train_loss: 0.162, loss_valid: 0.008, val_t1acc: 0.999, val_t100acc: 1.000\n", 334 | "0.008088136557489634\n", 335 | "testing on the real test set ;)\n", 336 | "execute all templates\n", 337 | "86317000 683460 0.007918023100895536\n", 338 | "len(X_fp[test]): 7315\n", 339 | "running the templates\n", 340 | "100%|██████████| 7315/7315 [03:09<00:00, 38.63it/s]\n", 341 | "Single-step retrosynthesis-evaluation, results on ttest:\n", 342 | "t1acc\tt2acc\tt3acc\tt5acc\tt10acc\tt20acc\tt50acc\tt100acc\tt101acc\t\n", 343 | "42.97\t58.06\t66.12\t74.37\t82.73\t88.34\t92.11\t93.41\t93.56\t\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "!python -m mhnreact.train --batch_size=1024 --concat_rand_template_thresh=1 --device=cuda:1 --dropout=0.4 \\\n", 349 | "--epochs=40 --fp_size=30000 --fp_type=maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp --hopf_asso_dim=1024 --hopf_association_activation=None --hopf_beta=0.03 \\\n", 350 | "--norm_asso=True --norm_input=True --hopf_num_heads=1 --lr=0.0001 --model_type=mhn --template_fp_type rdkc+pattern+morganc+layered+atompair+erg+topologicaltorsion+mhfp \\\n", 351 | "--exp_name=retro_selected --reactant_pooling lgamma --hopf_n_layers 2 --layer2weight 0.1 --template_fp_type2 rdk \\\n", 352 | "--dataset_type 50k_time --csv_path ./data/USPTO_50k_MHN_prepro.csv.gz --split_col time_split_years --ssretroeval True --seed 0 --eval_every_n_epochs 10 --addval2train True --wandb True" 353 | ] 354 | } 355 | ], 356 | "metadata": { 357 | "kernelspec": { 358 | "display_name": "Python 3", 359 | "language": "python", 360 | "name": "python3" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 3 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython3", 372 | "version": "3.7.6" 373 | } 374 | }, 375 | "nbformat": 4, 376 | "nbformat_minor": 4 377 | } 378 | -------------------------------------------------------------------------------- /notebooks/30_retrieval_fast_scalable.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "/system/user/publicwork/seidl/projects/mhn-react\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%cd .." 
18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "# Fast and efficient retrieval using a trained model\n", 25 | "After training, the model can be used as an embedding model and templates can be retrieved efficiently, instead of always holding all templates in GPU memory" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 4, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "seeded with 0\n", 38 | "loading X, y from csv\n", 39 | "test 5007 samples\n", 40 | "valid 5001 samples\n", 41 | "train 40008 samples\n", 42 | "11800 templates\n", 43 | "[759, 355, 241, 210, 138, 132, 90, 99, 69, 49, 44, 2821, 1852]\n", 44 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 11800, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cuda:0', 'batch_size': 128, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.2, 'lr': 0.0005, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 512, 'hopf_association_activation': 'None', 'hopf_beta': 0.05, 'norm_input': True, 'norm_asso': True, 'hopf_n_layers': 1, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'hopf_pooling_operation_head': 'mean'}\n", 45 | "loading tfp from file ./data/cache/templ_emb_4096_rdk2_11800_319308756406610263009709791057247232169953402565012761665882794469049907419597432806619930615825219861593907665538161448220007278140048617814857686486870.npy\n", 46 | "313it [00:15, 20.61it/s]\n", 47 | " 0 -- train_loss: 4.206, loss_valid: 3.770, val_t1acc: 0.200, val_t100acc: 0.880\n", 48 | "313it [00:14, 21.71it/s]\n", 49 | " 1 -- train_loss: 2.972, loss_valid: 3.409, val_t1acc: 0.226, val_t100acc: 0.910\n", 50 | "313it [00:14, 21.74it/s]\n", 51 | " 2 -- train_loss: 2.407, loss_valid: 3.275, val_t1acc: 0.231, val_t100acc: 0.913\n", 52 | "313it [00:14, 22.27it/s]\n", 53 | " 3 -- train_loss: 2.057, loss_valid: 3.254, val_t1acc: 0.236, val_t100acc: 0.913\n", 54 | "313it [00:14, 21.85it/s]\n", 55 | " 4 -- train_loss: 1.823, loss_valid: 3.252, val_t1acc: 0.239, val_t100acc: 0.913\n", 56 | "313it [00:14, 21.84it/s]\n", 57 | " 5 -- train_loss: 1.687, loss_valid: 3.251, val_t1acc: 0.241, val_t100acc: 0.912\n", 58 | "313it [00:14, 21.55it/s]\n", 59 | " 6 -- train_loss: 1.552, loss_valid: 3.241, val_t1acc: 0.241, val_t100acc: 0.912\n", 60 | "313it [00:14, 21.64it/s]\n", 61 | " 7 -- train_loss: 1.475, loss_valid: 3.267, val_t1acc: 0.242, val_t100acc: 0.910\n", 62 | "313it [00:14, 21.63it/s]\n", 63 | " 8 -- train_loss: 1.448, loss_valid: 3.283, val_t1acc: 0.238, val_t100acc: 0.904\n", 64 | "313it [00:14, 21.70it/s]\n", 65 | " 9 -- train_loss: 1.367, loss_valid: 3.285, val_t1acc: 0.239, val_t100acc: 0.906\n", 66 | "model saved to data/model/USPTO_50k_mhn_valloss3.285_rerun_1693819455.pt\n", 67 | "3.241067260503769\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# train a model if you haven't already (these parameters will produce an ok-ish model fast ;) \n", 73 | "!python -m mhnreact.train --model_type=mhn --device=best --fp_size=4096 --fp_type morgan --template_fp_type rdk --concat_rand_template_thresh -1 \\\n", 74 | "--exp_name rerun --dataset_type 50k --csv_path ./data/USPTO_50k_MHN_prepro.csv.gz --split_col split --ssretroeval False --seed 0 \\\n", 75 | "--hopf_association_activation None --fp_radius 2 --save_model True\n", 76 | "# norm_asso 
should be True (but is so by default)\n", 77 | "# concat_rand_template_thresh -1 means we don't add noise to template classification to make them better distinguishable; \n", 78 | "# reduces top-1 performance but should be better for retrieval from a new set of templates" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 7, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "{'fingerprint_type': 'morgan', 'template_fp_type': 'rdk', 'num_templates': 11800, 'fp_size': 4096, 'fp_radius': 2, 'device': 'cpu', 'batch_size': 128, 'pooling_operation_state_embedding': 'mean', 'pooling_operation_head': 'max', 'dropout': 0.2, 'lr': 0.0005, 'optimizer': 'Adam', 'activation_function': 'ReLU', 'verbose': False, 'hopf_input_size': 4096, 'hopf_output_size': None, 'hopf_num_heads': 1, 'hopf_asso_dim': 512, 'hopf_association_activation': 'None', 'hopf_beta': 0.05, 'norm_input': True, 'norm_asso': True, 'hopf_n_layers': 0, 'mol_encoder_layers': 1, 'temp_encoder_layers': 1, 'encoder_af': 'ReLU', 'hopf_pooling_operation_head': 'mean'}\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "# now we can load in the model (this will take the first of your models in the model directory);\n", 96 | "# you can also specify your own model dir by passing the model_path argument\n", 97 | "from mhnreact.inspect import *\n", 98 | "clf = load_clf(model_fn=list_models()[0], model_type='mhn', device='cpu')" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 139, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "# get templates from USPTO-full (230k templates)\n", 108 | "# for USPTO_full_MHN_prepro_woiv run the notebook 03_prepro_uspto_full\n", 109 | "import pandas as pd\n", 110 | "df = pd.read_csv(\"/system/user/seidl/seidl/projects/projects/mhn-react/data/USPTO_full_MHN_prepro_woiv.csv.gz\")\n", 111 | "#df = pd.read_csv(\"./data/USPTO_50k_MHN_prepro.csv.gz\") # USPTO-sm\n", 112 | "\n", 113 | "tmp = df[['reaction_smarts','label']].drop_duplicates(subset=['reaction_smarts','label']).sort_values('label')\n", 114 | "# drop the ones from the test set\n", 115 | "\n", 116 | "tmp.index= tmp.label\n", 117 | "template_list = tmp['reaction_smarts'].to_dict()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 9, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "278546" 129 | ] 130 | }, 131 | "execution_count": 9, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "templates = list(template_list.values())\n", 138 | "len(templates)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# encode templates\n", 148 | "xd = clf.encode_templates(templates)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 11, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "(278546, 512)" 160 | ] 161 | }, 162 | "execution_count": 11, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "xd.shape" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 10, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# install autofaiss\n", 178 | "#!pip install autofaiss" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 143, 184 | "metadata": {}, 185 
| "outputs": [], 186 | "source": [ 187 | "# use faiss to build index\n", 188 | "import faiss\n", 189 | "index = faiss.IndexFlatIP(xd.shape[1])\n", 190 | "index.add(xd)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 17, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "# save to disk\n", 200 | "faiss.write_index(index, \"./data/templates.index\")\n", 201 | "# load from disk\n", 202 | "# index = faiss.read_index(\"./data/templates.index\")" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "# you can also use autofaiss to build index\n", 212 | "# cosine similarity to find similar templates\n", 213 | "# uses less memory and picks an appropriate index type automatically\n", 214 | "# more prone to false positives than exact search --> experiment: USPTO-sm 70% (vs 90% exact template match accuracy)\n", 215 | "#from autofaiss import build_index\n", 216 | "#index, index_infos = build_index(xd, save_on_disk=True, \n", 217 | "# index_path=\"./data/templates.index\", metric_type='ip', min_nearest_neighbors_to_retrieve=20, \n", 218 | "# use_gpu=False, make_direct_map=False) # ip = inner product (cosine sim if vectors are normalized)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 18, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "556M\t./data/templates.index\n" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "# check memory usage \n", 236 | "!du -sh ./data/templates.index" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 19, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "1.140924536" 248 | ] 249 | }, 250 | "execution_count": 19, 251 | "metadata": {}, 252 | "output_type": "execute_result" 253 | } 254 | ], 255 | "source": [ 256 | "# check memory of numpy array in GB\n", 257 | "import sys\n", 258 | "sys.getsizeof(xd)/1e9" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 144, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "# let's evaluate the test set this way (we trained on USPTO-sm but test on USPTO-full zero-shot)\n", 268 | "n_samples = 1000\n", 269 | "xq = clf.encode_smiles(df[df.split==\"test\"].prod_smiles[:n_samples].values.tolist())" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 158, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "y = df[df.split==\"test\"].label.values.tolist() # the template that should have been retrieved\n", 279 | "y = np.array(y)[:n_samples]" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 68, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "CPU times: user 7min 35s, sys: 578 ms, total: 7min 35s\n", 292 | "Wall time: 7.05 s\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "%%time\n", 298 | "# retrieve top k templates using the MHN-encoded molecule\n", 299 | "k=100\n", 300 | "_, I = index.search(xq, k)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 138, 306 | "metadata": {}, 307 | "outputs": [ 308 | { 309 | "name": "stdout", 310 | "output_type": "stream", 311 | "text": [ 312 | "top- 1 accuracy: 4.10%\n", 313 | "top- 2 accuracy: 7.70%\n", 314 | "top- 3 accuracy: 10.50%\n", 315 | "top- 5 accuracy: 13.40%\n", 316 | 
"top- 10 accuracy: 19.20%\n", 317 | "top- 20 accuracy: 26.00%\n", 318 | "top- 50 accuracy: 33.60%\n", 319 | "top- 100 accuracy: 38.90%\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "# top-k accuracy\n", 325 | "for k in [1,2,3,5,10,20,50,100]:\n", 326 | " print(f\"top-{k: 4d} accuracy: {(y[:,None]==I[:,:k]).any(axis=1).mean()*100: 6.2f}%\")" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 114, 332 | "metadata": {}, 333 | "outputs": [ 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "CPU times: user 40.2 s, sys: 4.45 s, total: 44.6 s\n", 339 | "Wall time: 23.5 s\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "%%time\n", 345 | "# retrieve using a dot-droduct via numpy\n", 346 | "# besides being slower, it also uses more memory\n", 347 | "import numpy as np\n", 348 | "I_np = np.argsort(np.dot(xq[:], xd.T), axis=1)[:,-k:][:,::-1]\n", 349 | "# to do this more efficently one can used argpartion beforehand ;)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 137, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stdout", 359 | "output_type": "stream", 360 | "text": [ 361 | "top- 1 accuracy: 4.10%\n", 362 | "top- 2 accuracy: 7.70%\n", 363 | "top- 3 accuracy: 10.50%\n", 364 | "top- 5 accuracy: 13.40%\n", 365 | "top- 10 accuracy: 19.20%\n", 366 | "top- 20 accuracy: 26.00%\n", 367 | "top- 50 accuracy: 33.60%\n", 368 | "top- 100 accuracy: 38.90%\n" 369 | ] 370 | } 371 | ], 372 | "source": [ 373 | "# top-k accuracy\n", 374 | "for k in [1,2,3,5,10,20,50,100]:\n", 375 | " print(f\"top-{k: 4d} accuracy: {(y[:n_samples,None]==I[:,:k]).any(axis=1).mean()*100: 6.2f}%\")" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 51, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "data": { 385 | "text/plain": [ 386 | "24.94507689235071" 387 | ] 388 | }, 389 | "execution_count": 51, 390 | "metadata": {}, 391 | "output_type": "execute_result" 392 | } 393 | ], 394 | "source": [ 395 | "#(y==I[:,0]).mean()*100 # top 1 accuracy\n", 396 | "# top1 acc: 25% for USPTO-sm\n", 397 | "# top100 acc: 90% for USPTO-sm" 398 | ] 399 | } 400 | ], 401 | "metadata": { 402 | "kernelspec": { 403 | "display_name": "Python 3", 404 | "language": "python", 405 | "name": "python3" 406 | }, 407 | "language_info": { 408 | "codemirror_mode": { 409 | "name": "ipython", 410 | "version": 3 411 | }, 412 | "file_extension": ".py", 413 | "mimetype": "text/x-python", 414 | "name": "python", 415 | "nbconvert_exporter": "python", 416 | "pygments_lexer": "ipython3", 417 | "version": "3.7.13" 418 | }, 419 | "orig_nbformat": 4 420 | }, 421 | "nbformat": 4, 422 | "nbformat_minor": 2 423 | } 424 | -------------------------------------------------------------------------------- /scripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/mhn-react/a3e0cc6fe3b8715f2e9b91352a0b66893b4aa560/scripts/.gitkeep -------------------------------------------------------------------------------- /scripts/make_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda create -n mhnreact_env python=3.8 3 | eval "$(conda shell.bash hook)" 4 | conda activate mhnreact_env 5 | conda install -c conda-forge rdkit 6 | pip install torch scipy ipykernel matplotlib sklearn swifter 7 | cd data/temprel-fortunato/template-relevance-master/ 8 | pip install -e . 
9 | pip install -e "git+https://github.com/connorcoley/rdchiral.git#egg=rdchiral" 10 | # conda install -c conda-forge -c ljn917 rdchiral_cpp # consider using fast C++ rdchiral -- doesn't work right now... -- try later ;) -------------------------------------------------------------------------------- /scripts/train_ssr_mhn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #call from parent folder 3 | eval "$(conda shell.bash hook)" 4 | conda activate mhnreact_env 5 | python -m mhnreact.train --batch_size=1024 --concat_rand_template_thresh=1 --device=best --dropout=0.4 \ 6 | --epochs=40 --fp_size=30000 --fp_type=maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp --hopf_asso_dim=1024 --hopf_association_activation=None --hopf_beta=0.03 --lr=0.0001 --model_type=mhn --template_fp_type rdkc+pattern+morganc+layered+atompair+erg+topologicaltorsion+mhfp \ 7 | --exp_name=retro_selected --reactant_pooling lgamma --hopf_n_layers 2 --layer2weight 0.05 --template_fp_type2 rdk \ 8 | --dataset_type 50k --csv_path ./data/USPTO_50k_MHN_prepro.csv.gz --split_col split --ssretroeval True --seed 0 --eval_every_n_epochs 10 --addval2train True -------------------------------------------------------------------------------- /scripts/train_tr_dnn_fortunato.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #call from parent folder 3 | eval "$(conda shell.bash hook)" 4 | conda activate mhnreact_env 5 | #sm_DNN_yes_test 6 | python -m mhnreact.train --device best --dataset_type sm --exp_name rerun \ 7 | --model_type fortunato --pretrain_epochs 25 --epochs 10 --hopf_asso_dim 2048 \ 8 | --fp_type morgan --fp_size 4096 --dropout 0.15 --lr 0.0005 \ 9 | --mol_encoder_layers 1 --batch_size 256 --save_preds True --save_model True \ 10 | --exp_name=rerun --seed 0 --fp_radius 2 -------------------------------------------------------------------------------- /scripts/train_tr_dnn_segler.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #call from parent folder 3 | eval "$(conda shell.bash hook)" 4 | conda activate mhnreact_env 5 | #sm_DNN_no_test 6 | python -m mhnreact.train --device best --dataset_type sm --exp_name rerun --model_type segler \ 7 | --pretrain_epochs 0 --epochs 10 --hopf_asso_dim 2048 --fp_type morgan --fp_size 4096 --dropout 0.15 \ 8 | --lr 0.0005 --mol_encoder_layers 1 --batch_size 256 --save_preds True --save_model True \ 9 | --seed 0 --fp_radius 2 -------------------------------------------------------------------------------- /scripts/train_tr_mhn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #call from parent folder 3 | eval "$(conda shell.bash hook)" 4 | conda activate mhnreact_env 5 | #sm_MHN_no_test 6 | python -m mhnreact.train --batch_size=1024 --concat_rand_template_thresh=2 --dataset_type=sm \ 7 | --device=best --dropout=0.2 --epochs=30 --exp_name=rerun --fp_size=4096 --fp_type=morgan --hopf_asso_dim=1024 \ 8 | --hopf_association_activation=None --hopf_beta=0.03 --temp_encoder_layers=1 --mol_encoder_layers=1 \ 9 | --norm_asso=True --norm_input=False --hopf_num_heads=1 --lr=0.001 --model_type=mhn --save_preds True --save_model True \ 10 | --exp_name=rerun --seed 0 --fp_radius 2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 
3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="mhnreact", 8 | version="1.0", 9 | author="Philipp Seidl and Philipp Renz", 10 | author_email="ph.seidl92@gmail.com", 11 | description="", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/ml-jku/mhn-react", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: BSD-2-Clause License", 19 | "Operating System :: linux-64", 20 | ], 21 | python_requires='>=3.7', 22 | ) 23 | -------------------------------------------------------------------------------- /tools/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.2-base-ubuntu18.04 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | build-essential \ 5 | git \ 6 | curl \ 7 | ca-certificates \ 8 | libjpeg-dev \ 9 | libpng-dev \ 10 | && rm -rf /var/lib/apt/lists/* 11 | 12 | RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 13 | chmod +x ~/miniconda.sh && \ 14 | ~/miniconda.sh -b -p /opt/conda && \ 15 | rm ~/miniconda.sh 16 | 17 | WORKDIR /mhn-react 18 | COPY . /workdir/mhn-react 19 | 20 | RUN /opt/conda/bin/conda env create -f /workdir/mhn-react/tools/docker/env.yml && \ 21 | /opt/conda/bin/conda clean -ya 22 | 23 | RUN /opt/conda/envs/mhnreact_env/bin/python -m pip install /workdir/mhn-react 24 | RUN /opt/conda/envs/mhnreact_env/bin/python -m pip install /workdir/mhn-react/data/temprel-fortunato/template-relevance-master/. 25 | 26 | ENV PATH /opt/conda/envs/mhnreact_env/bin:$PATH 27 | -------------------------------------------------------------------------------- /tools/docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker Image 2 | 3 | Build docker image using following command. 4 | [`DOCKER_BUILDKIT`](https://docs.docker.com/develop/develop-images/build_enhancements/) results in faster builds but is Linux only. 5 | 6 | ``` 7 | DOCKER_BUILDKIT=1 docker build -t mhnreact:latest -f Dockerfile ../.. 8 | ``` 9 | -------------------------------------------------------------------------------- /tools/docker/env.yml: -------------------------------------------------------------------------------- 1 | name: mhnreact_env 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - pip=20.1.1=py_1 7 | - python=3.7 8 | - rdkit=2021.03.3 9 | - pip: 10 | - jupyterlab 11 | - scikit-learn==0.23.1 12 | - scipy==1.4 13 | - torch==1.6.0 14 | - torchvision==0.7.0 15 | - tqdm 16 | - git+https://github.com/connorcoley/rdchiral.git@01cca8c7f5b0946187f47928738730040de06b16#egg=rdchiral 17 | --------------------------------------------------------------------------------