├── .dockerignore ├── .gitignore ├── .gitlab-ci.yml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── askcos ├── __init__.py ├── application │ ├── __init__.py │ └── run.py ├── global_config.py ├── interfaces │ ├── __init__.py │ ├── context_recommender.py │ ├── forward_enumerator.py │ ├── scorer.py │ └── template_transformer.py ├── prioritization │ ├── __init__.py │ ├── contexts │ │ ├── __init__.py │ │ ├── probability.py │ │ └── rank.py │ ├── default.py │ ├── precursors │ │ ├── __init__.py │ │ ├── heuristic.py │ │ ├── mincost.py │ │ ├── precursors_test.py │ │ ├── relevanceheuristic.py │ │ └── scscore.py │ ├── prioritizer.py │ └── templates │ │ ├── __init__.py │ │ ├── popularity.py │ │ ├── relevance.py │ │ ├── relevance_test.py │ │ └── test_data │ │ ├── relevance_01.pkl │ │ └── relevance_02.pkl ├── retrosynthetic │ ├── __init__.py │ ├── mcts │ │ ├── __init__.py │ │ ├── nodes.py │ │ ├── test_smiles.txt │ │ ├── tree_builder.py │ │ ├── tree_builder_test.py │ │ ├── utils.py │ │ └── v2 │ │ │ ├── __init__.py │ │ │ ├── tree_builder.py │ │ │ └── tree_builder_test.py │ ├── pathway_ranker │ │ ├── __init__.py │ │ ├── model.py │ │ ├── pathway_ranker.py │ │ ├── pathway_ranker_test.py │ │ ├── test_data │ │ │ └── test_trees.json │ │ └── utils.py │ ├── results.py │ ├── transformer.py │ └── transformer_test.py ├── synthetic │ ├── __init__.py │ ├── atom_mapper │ │ ├── __init__.py │ │ └── wln_mapper.py │ ├── context │ │ ├── __init__.py │ │ ├── nearestneighbor.py │ │ ├── neuralnetwork.py │ │ ├── neuralnetwork_test.py │ │ └── v2 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── db.py │ │ │ ├── evaluate.py │ │ │ ├── graph_util.py │ │ │ ├── graph_util_test.py │ │ │ ├── preprocess_reagent_group.py │ │ │ ├── preprocess_reagent_group_test.py │ │ │ ├── reaction_context_predictor.py │ │ │ ├── results_preprocess.py │ │ │ ├── search.py │ │ │ └── smiles_util.py │ ├── descriptors │ │ ├── descriptors.py │ │ ├── featurization.py │ │ ├── ffn.py │ │ ├── model.py │ │ ├── mpn.py │ │ └── 
nn_utils.py │ ├── enumeration │ │ ├── __init__.py │ │ ├── results.py │ │ ├── test_data │ │ │ ├── 100.pkl │ │ │ ├── 200.pkl │ │ │ ├── 300.pkl │ │ │ ├── 400.pkl │ │ │ ├── 500.pkl │ │ │ ├── 600.pkl │ │ │ ├── 700.pkl │ │ │ ├── 800.pkl │ │ │ └── 900.pkl │ │ ├── transformer.py │ │ └── transformer_test.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── evaluation_test.py │ │ ├── evaluator.py │ │ ├── fast_filter.py │ │ ├── rexgen_direct │ │ │ ├── __init__.py │ │ │ ├── core_wln_global │ │ │ │ ├── __init__.py │ │ │ │ ├── directcorefinder.py │ │ │ │ ├── ioutils_direct.py │ │ │ │ ├── models.py │ │ │ │ ├── mol_graph.py │ │ │ │ ├── mol_graph_rich.py │ │ │ │ ├── mol_graph_test.py │ │ │ │ ├── nn.py │ │ │ │ └── test_data │ │ │ │ │ ├── mg_smiles2graph_list.pkl │ │ │ │ │ └── mgr_smiles2graph_list.pkl │ │ │ ├── eval_by_smiles.py │ │ │ ├── predict.py │ │ │ ├── predict_test.py │ │ │ ├── rank_diff_wln │ │ │ │ ├── __init__.py │ │ │ │ ├── directcandranker.py │ │ │ │ ├── edit_mol_direct_useScores.py │ │ │ │ ├── models.py │ │ │ │ ├── mol_graph_direct_useScores.py │ │ │ │ └── nn.py │ │ │ └── test_data │ │ │ │ └── predict.pkl │ │ ├── rexgen_release │ │ │ ├── CandRanker │ │ │ │ ├── __init__.py │ │ │ │ ├── cand_ranker.py │ │ │ │ ├── edit_mol.py │ │ │ │ ├── models.py │ │ │ │ ├── mol_graph.py │ │ │ │ ├── mol_graph_test.py │ │ │ │ └── test_data │ │ │ │ │ └── CR_smiles2graph.pkl │ │ │ ├── CoreFinder │ │ │ │ ├── __init__.py │ │ │ │ ├── core_finder.py │ │ │ │ ├── ioutils.py │ │ │ │ ├── models.py │ │ │ │ ├── mol_graph.py │ │ │ │ ├── mol_graph_test.py │ │ │ │ └── test_data │ │ │ │ │ └── CF_smiles2graph_batch.pkl │ │ │ ├── __init__.py │ │ │ ├── predict.py │ │ │ ├── predict_test.py │ │ │ ├── test_data │ │ │ │ └── predict.pkl │ │ │ ├── uspto_samples.txt │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ └── nn.py │ │ ├── template_based.py │ │ ├── template_based_aux.py │ │ ├── template_free.py │ │ ├── test_data │ │ │ ├── evaluator.pkl │ │ │ └── template_free.pkl │ │ └── tree_evaluator.py │ ├── impurity │ │ ├── 
# artifacts from using JupyterLab
*ipynb_checkpoints* 12 | 13 | # ignore SSL key 14 | deploy/askcos.ssl.cert 15 | deploy/askcos.ssl.key 16 | 17 | # ignore pycharm folder 18 | .idea/ 19 | 20 | # directory for gitlab pages source files 21 | pages/ 22 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: docker:stable 2 | 3 | services: 4 | - docker:dind 5 | 6 | variables: 7 | DOCKER_TLS_CERTDIR: "/certs" 8 | 9 | before_script: 10 | - apk add git make 11 | - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY 12 | - docker pull $CI_REGISTRY_IMAGE:dev || true # pull dev image for build cache 13 | 14 | build-dev: # build dev branch for continuous deployment 15 | stage: build 16 | tags: 17 | - askcos 18 | script: 19 | - make push VERSION=$(git describe --tags --always) REGISTRY=$CI_REGISTRY_IMAGE TAG=dev DATA_VERSION=dev 20 | only: 21 | - dev 22 | 23 | build-latest: # build latest version of the master branch 24 | stage: build 25 | tags: 26 | - askcos 27 | script: 28 | - make push VERSION=$(git describe --tags --abbrev=0) REGISTRY=$CI_REGISTRY_IMAGE TAG=latest DATA_VERSION=latest 29 | only: 30 | - master 31 | 32 | build-release: # build all releases, as determined by tags 33 | stage: build 34 | tags: 35 | - askcos 36 | script: 37 | - make push VERSION=$CI_COMMIT_TAG REGISTRY=$CI_REGISTRY_IMAGE TAG=$CI_COMMIT_TAG DATA_VERSION=$CI_COMMIT_TAG 38 | only: 39 | - tags 40 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_VERSION=2020.03.6-gh2855-py37-conda 2 | ARG DATA_VERSION=dev 3 | 4 | FROM askcos/askcos-data:$DATA_VERSION as data 5 | FROM askcos/askcos-base:$BASE_VERSION 6 | 7 | USER root 8 | 9 | RUN conda install pytorch=1.4=cpu_py37h7e40bad_0 && \ 10 | find /opt/conda/ -follow -type f -name '*.a' -delete && \ 11 | 
find /opt/conda/ -follow -type f -name '*.js.map' -delete && \ 12 | /opt/conda/bin/conda clean -afy 13 | 14 | USER askcos 15 | 16 | COPY --chown=askcos:askcos --from=data /data /usr/local/askcos-core/askcos/data 17 | COPY --chown=askcos:askcos . /usr/local/askcos-core 18 | 19 | WORKDIR /home/askcos 20 | 21 | ENV PYTHONPATH=/usr/local/askcos-core${PYTHONPATH:+:${PYTHONPATH}} 22 | 23 | LABEL core.version={VERSION} \ 24 | core.git.hash={GIT_HASH} \ 25 | core.git.date={GIT_DATE} \ 26 | core.git.describe={GIT_DESCRIBE} 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Makefile for ASKCOS 4 | # 5 | ################################################################################ 6 | 7 | .PHONY: build debug push test setup_pages 8 | 9 | VERSION ?= latest 10 | GIT_HASH := $(shell git log -1 --format='format:%H') 11 | GIT_DATE := $(shell git log -1 --format='format:%cs') 12 | GIT_DESCRIBE := $(shell git describe --tags --always --dirty) 13 | 14 | REGISTRY ?= askcos/askcos-core 15 | TAG ?= $(VERSION) 16 | DATA_VERSION ?= $(VERSION) 17 | 18 | main build: 19 | @echo Building docker image: $(REGISTRY):$(TAG) 20 | @sed \ 21 | -e 's/{VERSION}/$(VERSION)/g' \ 22 | -e 's/{GIT_HASH}/$(GIT_HASH)/g' \ 23 | -e 's/{GIT_DATE}/$(GIT_DATE)/g' \ 24 | -e 's/{GIT_DESCRIBE}/$(GIT_DESCRIBE)/g' \ 25 | Dockerfile | docker build -t $(REGISTRY):$(TAG) \ 26 | --build-arg DATA_VERSION=$(DATA_VERSION) \ 27 | -f - . 
28 | 29 | build_ci: 30 | @echo Building docker image: $(REGISTRY):$(TAG) 31 | @sed \ 32 | -e 's/{VERSION}/$(VERSION)/g' \ 33 | -e 's/{GIT_HASH}/$(GIT_HASH)/g' \ 34 | -e 's/{GIT_DATE}/$(GIT_DATE)/g' \ 35 | -e 's/{GIT_DESCRIBE}/$(GIT_DESCRIBE)/g' \ 36 | Dockerfile | docker build -t $(REGISTRY):$(TAG) \ 37 | --cache-from $(REGISTRY):dev \ 38 | --build-arg BUILDKIT_INLINE_CACHE=1 \ 39 | --build-arg DATA_VERSION=$(DATA_VERSION) \ 40 | -f - . 41 | 42 | push: build_ci 43 | @docker push $(REGISTRY):$(TAG) 44 | 45 | debug: 46 | docker run -it --rm -w /usr/local/askcos-core $(VOLUMES) $(REGISTRY):$(TAG) /bin/bash 47 | 48 | test: 49 | docker run --rm -w /usr/local/askcos-core $(VOLUMES) $(REGISTRY):$(TAG) python -m unittest discover -v -p '*test.py' -s askcos 50 | docker run --rm -w /usr/local/askcos-core $(VOLUMES) $(REGISTRY):$(TAG) python -m unittest -v askcos.retrosynthetic.mcts.tree_builder_test 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # askcos-core 2 | [![askcos-base](https://img.shields.io/badge/-askcos--base-lightgray?style=flat-square)](https://github.com/ASKCOS/askcos-base) 3 | [![askcos-data](https://img.shields.io/badge/-askcos--data-lightgray?style=flat-square)](https://github.com/ASKCOS/askcos-data) 4 | [![askcos-core](https://img.shields.io/badge/-askcos--core-blue?style=flat-square)](https://github.com/ASKCOS/askcos-core) 5 | [![askcos-site](https://img.shields.io/badge/-askcos--site-lightgray?style=flat-square)](https://github.com/ASKCOS/askcos-site) 6 | [![askcos-deploy](https://img.shields.io/badge/-askcos--deploy-lightgray?style=flat-square)](https://github.com/ASKCOS/askcos-deploy) 7 | 8 | Python package for the prediction of feasible synthetic routes towards a desired compound and associated tasks related to synthesis planning. 
$ docker build -t askcos-core .
32 | 33 | Using the learned synthetic complexity metric (SCScore): 34 | ``` 35 | askcos/prioritization/precursors/scscore.py 36 | ``` 37 | 38 | Obtaining a single-step retrosynthetic suggestion with consideration of chirality: 39 | ``` 40 | askcos/retrosynthetic/transformer.py 41 | ``` 42 | 43 | Finding recommended reaction conditions based on a trained neural network model: 44 | ``` 45 | askcos/synthetic/context/neuralnetwork.py 46 | ``` 47 | 48 | Using the template-free forward predictor: 49 | ``` 50 | askcos/synthetic/evaluation/template_free.py 51 | ``` 52 | 53 | Using the coarse "fast filter" (binary classifier) for evaluating reaction plausibility: 54 | ``` 55 | askcos/synthetic/evaluation/fast_filter.py 56 | ``` 57 | 58 | Predicting regio-selectivity for a given atom-mapped reaction: 59 | ``` 60 | askcos/synthetic/selectivity/general_selectivity.py 61 | ``` 62 | 63 | Predicting reactivity descriptors for a given list of molecules: 64 | ``` 65 | askcos/synthetic/descriptors/descriptors.py 66 | ``` 67 | -------------------------------------------------------------------------------- /askcos/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/application/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/interfaces/context_recommender.py: -------------------------------------------------------------------------------- 1 | class ContextRecommender(object): 2 | """Interface for context recommender classes.""" 3 | def __init__(self): 4 | raise NotImplementedError 5 | 6 | def 
class ForwardEnumerator(object):
    """Abstract base for forward enumeration implementations.

    Concrete subclasses must provide an initializer and a ``get_outcomes``
    method that performs the actual enumeration.
    """

    def __init__(self):
        # Abstract interface: subclasses are required to override this.
        raise NotImplementedError

    def get_outcomes(self, smiles):
        """Enumerate the possible products for a given SMILES string."""
        raise NotImplementedError
class ProbabilityContextPrioritizer(Prioritizer):
    """Context prioritizer that orders outcomes by predicted probability."""

    def __init__(self):
        """Initializes ProbabilityContextPrioritizer."""
        pass

    def get_priority(self, outcomes):
        """Order outcomes from most to least probable.

        Args:
            outcomes (list of dict): Direct outcome of calling the single
                reaction evaluator; each entry has a 'target' dict carrying
                a 'prob' value.

        Returns:
            list of dict: Outcomes sorted by descending probability.
        """
        def _probability(outcome):
            return outcome['target']['prob']

        return sorted(outcomes, key=_probability, reverse=True)

    def load_model(self):
        """No-op: this prioritizer does not rely on a neural network."""
        pass
class RankContextPrioritizer(Prioritizer):
    """Context prioritizer that orders outcomes by their reported rank."""

    def __init__(self):
        """Initializes RankContextPrioritizer."""
        pass

    def get_priority(self, outcomes):
        """Order outcomes by descending rank.

        Args:
            outcomes (list of dict): Direct outcome of calling the single
                reaction evaluator; each entry has a 'target' dict carrying
                a 'rank' value.

        Returns:
            list of dict: Outcomes sorted by descending rank.
        """
        def _rank(outcome):
            return outcome['target']['rank']

        return sorted(outcomes, key=_rank, reverse=True)

    def load_model(self):
        """No-op: this prioritizer does not rely on a neural network."""
        pass
12 | """ 13 | 14 | def __init__(self): 15 | """Initializes DefaultPrioritizer.""" 16 | pass 17 | def get_priority(self, object_to_prioritize, **kwargs): 18 | """Returns priority of given object. 19 | 20 | If object is a tuple (for prioritization of a template), then return the 21 | first element (the list of templates). 22 | Otherwise returns 1. 23 | 24 | Args: 25 | object_to_prioritize (2-tuple or ??): Object to prioritize. 26 | **kwargs: Unused. 27 | 28 | Returns: 29 | list or float: Unsorted input list of templates or 1.0 30 | """ 31 | try: 32 | (templates, target) = object_to_prioritize 33 | return templates 34 | # if not a tuple: prioritization of a retro-precursor element. 35 | except TypeError: 36 | return 1.0 37 | 38 | def load_model(self): 39 | """Loads default model. 40 | 41 | DefaultPrioritizer does not use a neural network, so this does nothing. 42 | """ 43 | pass 44 | -------------------------------------------------------------------------------- /askcos/prioritization/precursors/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/prioritization/precursors/heuristic.py: -------------------------------------------------------------------------------- 1 | from askcos.prioritization.prioritizer import Prioritizer 2 | import rdkit.Chem as Chem 3 | from rdkit.Chem import AllChem 4 | import numpy as np 5 | from askcos.utilities.buyable.pricer import Pricer 6 | from askcos.utilities.io.logger import MyLogger 7 | heuristic_precursor_prioritizer_loc = 'heuristic_precursor_prioritizer' 8 | 9 | 10 | class HeuristicPrecursorPrioritizer(Prioritizer): 11 | """A precursor Prioritizer that uses a heuristic scoring function. 12 | 13 | Attributes: 14 | pricer (Pricer or None): Used to look up chemical prices. 
class HeuristicPrecursorPrioritizer(Prioritizer):
    """A precursor Prioritizer backed by a heuristic scoring function.

    Attributes:
        pricer (Pricer or None): Used to look up chemical prices.
    """

    def __init__(self):
        """Initializes HeuristicPrecursorPrioritizer."""
        self.pricer = None
        self._loaded = False

    def _score_smiles(self, smiles):
        """Heuristic score for a single molecule; buyables are nearly free."""
        ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
        if ppg:
            # Buyable: score scales (negatively) with price per gram.
            return -ppg / 5.0

        mol = Chem.MolFromSmiles(smiles)
        total_atoms = mol.GetNumHeavyAtoms()
        # Non-aromatic ring bonds only (aromatic ring bonds cancel out).
        ring_bonds = sum(b.IsInRing() - b.GetIsAromatic() for b in mol.GetBonds())
        chiral_centers = len(Chem.FindMolChiralCenters(mol))
        return (
            - 2.00 * np.power(total_atoms, 1.5)
            - 1.00 * np.power(ring_bonds, 1.5)
            - 2.00 * np.power(chiral_centers, 2.0)
        )

    def get_priority(self, retroPrecursor, **kwargs):
        """Get the priority of a precursor via the heuristic function.

        Args:
            retroPrecursor (RetroPrecursor): Precursor to calculate priority of.
            **kwargs: Unused.

        Returns:
            float: Priority score of the precursor.
        """
        if not self._loaded:
            self.load_model()

        necessary_reagent_atoms = retroPrecursor.necessary_reagent.count('[') / 2.
        per_molecule = [self._score_smiles(s) for s in retroPrecursor.smiles_list]
        return np.sum(per_molecule) - 4.00 * np.power(necessary_reagent_atoms, 2.0)

    def load_model(self):
        """Loads the Pricer used in the heuristic priority scoring."""
        self.pricer = Pricer()
        self.pricer.load()
        self._loaded = True
class TestSCScore(unittest.TestCase):
    """Regression tests for the SCScore precursor prioritizer."""

    def setUp(self):
        self.model = sc.SCScorePrecursorPrioritizer()

    def _check(self, result, expected):
        # Scores are model outputs; allow a small numerical tolerance.
        self.assertLess(abs(expected - result), 1e-4)

    def test_01_get_score_from_smiles_1024bool(self):
        self.model.load_model(model_tag='1024bool')
        self._check(self.model.get_score_from_smiles('CCCOCCC', noprice=True),
                    1.43226093752)

    def test_02_get_score_from_smiles_1024bool(self):
        self.model.load_model(model_tag='1024bool')
        self._check(self.model.get_score_from_smiles('CCCC', noprice=True),
                    1.32272217424)

    def test_03_get_priority_2048bool(self):
        self.model.load_model(model_tag='2048bool', FP_len=2048)
        self._check(self.model.get_priority('CCCOCCC', noprice=False), -0.01)

    def test_04_get_priority_2048bool(self):
        self.model.load_model(model_tag='2048bool', FP_len=2048)
        self._check(self.model.get_priority('CCCNc1ccccc1', noprice=False), -0.45)

    def test_05_get_priority_1024uint8(self):
        self.model.load_model(model_tag='1024uint8')
        self._check(self.model.get_priority('CCCOCCC', noprice=False), -0.01)

    def test_06_get_priority_1024uint8(self):
        self.model.load_model(model_tag='1024uint8')
        self._check(self.model.get_priority('CCCNc1ccccc1', noprice=False), -0.45)


class TestMincost(unittest.TestCase):
    """Regression tests for the MinCost precursor prioritizer."""

    def setUp(self):
        self.model = mc.MinCostPrecursorPrioritizer()
        self.model.load_model()

    def _check(self, result, expected):
        self.assertLess(abs(expected - result), 1e-4)

    def test_01_get_priority(self):
        smiles = 'CC(=O)N1C=C(C=C2N=C(N(N=CC3C=CC=CC=3)C(C)=O)N(C(C)=O)C2=O)C2=CC=CC=C21'
        self._check(self.model.get_priority(smiles), 4.01056028509)

    def test_02_get_priority(self):
        self._check(self.model.get_priority('CCCNc1ccccc1'), 0.0)


if __name__ == '__main__':
    res = unittest.main(verbosity=3, exit=False)
class RelevanceHeuristicPrecursorPrioritizer(Prioritizer):
    """A precursor Prioritizer combining a heuristic with template relevance.

    Attributes:
        pricer (Pricer or None): Used to look up chemical prices.
    """

    def __init__(self):
        """Initializes RelevanceHeuristicPrecursorPrioritizer."""
        self.pricer = None
        self._loaded = False

    def _molecule_score(self, smiles):
        """Heuristic score for one molecule; buyables are nearly free."""
        ppg = self.pricer.lookup_smiles(smiles, alreadyCanonical=True)
        if ppg:
            # Buyable: score scales (negatively) with price per gram.
            return -ppg / 1000.0

        mol = Chem.MolFromSmiles(smiles)
        total_atoms = mol.GetNumHeavyAtoms()
        # Non-aromatic ring bonds only (aromatic ring bonds cancel out).
        ring_bonds = sum(b.IsInRing() - b.GetIsAromatic() for b in mol.GetBonds())
        chiral_centers = len(Chem.FindMolChiralCenters(mol))
        return (
            - 2.00 * np.power(total_atoms, 1.5)
            - 1.00 * np.power(ring_bonds, 1.5)
            - 2.00 * np.power(chiral_centers, 2.0)
        )

    def _combined_score(self, smiles_list, necessary_reagent, template_score):
        """Sum molecule scores, penalize reagent atoms, scale by relevance."""
        necessary_reagent_atoms = necessary_reagent.count('[') / 2.
        total = np.sum([self._molecule_score(s) for s in smiles_list])
        total -= 4.00 * np.power(necessary_reagent_atoms, 2.0)
        return total / template_score

    def score_precursor(self, precursor):
        """Score a precursor using template relevance and a heuristic rule.

        Args:
            precursor (dict): Dictionary of the precursor to score.

        Returns:
            float: Combined relevance-heuristic score of the precursor.
        """
        return self._combined_score(
            precursor['smiles_split'],
            precursor['necessary_reagent'],
            precursor['template_score'],
        )

    def reorder_precursors(self, precursors):
        """Reorder precursors by their combined relevance-heuristic score.

        Args:
            precursors (list of dict): Precursors to reorder.

        Returns:
            list: Reordered precursor dicts with new 'score' and 'rank' keys.
        """
        scores = np.array([self.score_precursor(p) for p in precursors])
        order = np.argsort(-scores)  # descending score
        result = []
        for rank, idx in enumerate(order, start=1):
            entry = precursors[idx]
            entry['score'] = scores[idx]
            entry['rank'] = rank
            result.append(entry)
        return result

    def get_priority(self, retroPrecursor, **kwargs):
        """Get precursor priority based on the heuristic and relevance.

        Args:
            retroPrecursor (RetroPrecursor): Precursor to calculate priority of.
            **kwargs: Unused.

        Returns:
            float: Priority score of the precursor.
        """
        if not self._loaded:
            self.load_model()
        return self._combined_score(
            retroPrecursor.smiles_list,
            retroPrecursor.necessary_reagent,
            retroPrecursor.template_score,
        )

    def load_model(self):
        """Loads the Pricer used in the heuristic priority scoring."""
        self.pricer = Pricer()
        self.pricer.load()
        self._loaded = True
19 | """ 20 | raise NotImplementedError 21 | 22 | def set_max_templates(self, max): 23 | """Sets maximum number of templates to be included in output.""" 24 | self.template_count = max 25 | 26 | def set_max_cum_prob(self, max): 27 | """Sets maximum cumulative probability of output.""" 28 | self.max_cum_prob = max 29 | -------------------------------------------------------------------------------- /askcos/prioritization/templates/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/prioritization/templates/popularity.py: -------------------------------------------------------------------------------- 1 | from askcos.prioritization.prioritizer import Prioritizer 2 | from askcos.utilities.io.logger import MyLogger 3 | popularity_template_prioritizer_loc = 'popularity_template_prioritizer' 4 | 5 | 6 | class PopularityTemplatePrioritizer(Prioritizer): 7 | """A template Prioritizer ordering by popularity. 8 | 9 | Allows to prioritize a template based on the number of times it appears, 10 | or its reported popularity. 11 | 12 | Attributes: 13 | sorted (bool): Whether the templates have been sorted yet. 14 | reordered_templates (list of ??): Templates ordered by popularity. 15 | template_count (int): Number of templates to return the priority of. 16 | max_cum_prob (float): Maximum cumulative probability of returned 17 | tempates. Unused. 18 | """ 19 | 20 | def __init__(self, no_log=False): 21 | """Initializes PopularityTemplatePrioritizer. 22 | 23 | Args: 24 | no_log (bool, optional): Whether to not log. Unused. 25 | (default: {False}) 26 | """ 27 | # only do 'reorder' once 28 | self.sorted = False 29 | self.reordered_templates = None 30 | self.template_count = 1e9 31 | self.max_cum_prob = 1 32 | 33 | def get_priority(self, input_tuple, **kwargs): 34 | """Returns list of templates ordered by popularity. 
35 | 36 | Args: 37 | input_tuple (2-tuple of (list of ??, ??)): Templates to get the 38 | priority of. 39 | **kwargs: Additional optional parameters. Used for template_count. 40 | """ 41 | (templates, target) = input_tuple 42 | return self.reorder(templates)[:min(len(templates), kwargs.get('template_count', 1e10))] 43 | 44 | def reorder(self, templates): 45 | """Reorders templates by popularity. 46 | 47 | Re-orders the list of templates according to field 'count' in descending 48 | order. This means we will apply the most popular templates first. 49 | 50 | Args: 51 | templates (list of ??): Unordered templates to be reordered by 52 | popularity. 53 | 54 | Returns: 55 | list of ??: Templates sorted by popularity. 56 | """ 57 | if self.sorted: 58 | return self.reordered_templates 59 | else: 60 | templates[:] = [x for x in sorted( 61 | templates, key=lambda z: z['count'], reverse=True)] 62 | self.sorted = True 63 | for template in templates: 64 | template['score'] = 1 65 | self.reordered_templates = templates 66 | return templates 67 | 68 | def load_model(self): 69 | """Loads popularity model. 70 | 71 | PopularityTemplatePrioritizer does not use a neural network, so this 72 | does nothing. 73 | """ 74 | pass 75 | -------------------------------------------------------------------------------- /askcos/prioritization/templates/relevance.py: -------------------------------------------------------------------------------- 1 | import askcos.global_config as gc 2 | from askcos.prioritization.prioritizer import Prioritizer 3 | import rdkit.Chem as Chem 4 | from rdkit.Chem import AllChem 5 | import numpy as np 6 | from askcos.utilities.io.logger import MyLogger 7 | import tensorflow as tf 8 | from scipy.special import softmax 9 | import requests 10 | 11 | relevance_template_prioritizer_loc = 'relevance_template_prioritizer' 12 | 13 | 14 | class RelevanceTemplatePrioritizer(Prioritizer): 15 | """A template Prioritizer based on template relevance. 
16 | 17 | Attributes: 18 | fp_length (int): Fingerprint length. 19 | fp_radius (int): Fingerprint radius. 20 | """ 21 | 22 | def __init__(self, fp_length=2048, fp_radius=2): 23 | self.fp_length = fp_length 24 | self.fp_radius = fp_radius 25 | 26 | def load_model(self, model_path=gc.RELEVANCE_TEMPLATE_PRIORITIZATION['reaxys']['model_path'], **kwargs): 27 | """Loads a model to predict template priority. 28 | 29 | Args: 30 | model_path (str): Path to keras saved model to be loaded with 31 | tf.keras.models.load_model. **kwargs are passed to load_model 32 | to allow loading of custom_objects 33 | 34 | """ 35 | self.model = tf.keras.models.load_model(model_path, **kwargs) 36 | 37 | def smiles_to_fp(self, smiles): 38 | """Converts SMILES string to fingerprint for use with template relevance model. 39 | 40 | Args: 41 | smiles (str): SMILES string to convert to fingerprint 42 | 43 | Returns: 44 | np.ndarray of np.float32: Fingerprint for given SMILES string. 45 | 46 | """ 47 | mol = Chem.MolFromSmiles(smiles) 48 | if not mol: 49 | return np.zeros((self.fp_length,), dtype=np.float32) 50 | return np.array( 51 | AllChem.GetMorganFingerprintAsBitVect( 52 | mol, self.fp_radius, nBits=self.fp_length, useChirality=True 53 | ), dtype=np.float32 54 | ) 55 | 56 | def predict(self, smiles, max_num_templates=None, max_cum_prob=None): 57 | """Predicts template priority given a SMILES string. 58 | 59 | Args: 60 | smiles (str): SMILES string of input molecule 61 | 62 | Returns: 63 | (scores, indices): np.ndarrays of scores and indices for 64 | prioritized templates 65 | max_num_templates (int, optional): maximum number of templates to 66 | return {default = None} 67 | max_cum_prob (float, optional): maximum cumulative probability of 68 | template relvance scores. 
This is used to limit the number of 69 | templates that get returned {default = None} 70 | """ 71 | fp = self.smiles_to_fp(smiles).reshape(1, -1) 72 | scores = self.model.predict(fp).reshape(-1) 73 | scores = softmax(scores) 74 | indices = np.argsort(-scores) 75 | scores = scores[indices] 76 | 77 | if max_num_templates is not None: 78 | indices = indices[:max_num_templates] 79 | scores = scores[:max_num_templates] 80 | 81 | if max_cum_prob is not None: 82 | cum_scores = np.cumsum(scores) 83 | scores = scores[cum_scores <= max_cum_prob] 84 | indices = indices[cum_scores <= max_cum_prob] 85 | 86 | return scores, indices 87 | 88 | 89 | if __name__ == '__main__': 90 | model = RelevanceTemplatePrioritizer() 91 | model.load_model() 92 | smis = ['CCCOCCC', 'CCCNc1ccccc1'] 93 | for smi in smis: 94 | lst = model.predict(smi) 95 | print('{} -> {}'.format(smi, lst)) 96 | -------------------------------------------------------------------------------- /askcos/prioritization/templates/relevance_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | import askcos.prioritization.templates.relevance as rel 8 | 9 | 10 | class TestTemplateRelevance(unittest.TestCase): 11 | 12 | @classmethod 13 | def setUpClass(cls): 14 | """This method is run once before all tests in this class.""" 15 | cls.model = rel.RelevanceTemplatePrioritizer() 16 | cls.model.load_model() 17 | 18 | def test_01_get_topk_from_smi(self): 19 | """Test that the template relevance model returns the expected result for CCCOCCC""" 20 | scores, indices = self.model.predict('CCCOCCC', 100, 0.995) 21 | 22 | with open(os.path.join(os.path.dirname(__file__), 'test_data/relevance_01.pkl'), 'rb') as t: 23 | expected = pickle.load(t) 24 | 25 | self.assertEqual(len(expected[0]), len(scores)) 26 | self.assertEqual(len(expected[1]), len(indices)) 27 | self.assertTrue(np.allclose(expected[0], scores)) 28 | 
self.assertTrue(np.array_equal(expected[1], indices)) 29 | 30 | def test_02_get_topk_from_smi(self): 31 | """Test that the template relevance model returns the expected result for CCCNc1ccccc1""" 32 | scores, indices = self.model.predict('CCCNc1ccccc1', 100, 0.995) 33 | 34 | with open(os.path.join(os.path.dirname(__file__), 'test_data/relevance_02.pkl'), 'rb') as t: 35 | expected = pickle.load(t) 36 | 37 | self.assertEqual(len(expected[0]), len(scores)) 38 | self.assertEqual(len(expected[1]), len(indices)) 39 | self.assertTrue(np.allclose(expected[0], scores)) 40 | self.assertTrue(np.array_equal(expected[1], indices)) 41 | 42 | 43 | if __name__ == '__main__': 44 | res = unittest.main(verbosity=3, exit=False) 45 | -------------------------------------------------------------------------------- /askcos/prioritization/templates/test_data/relevance_01.pkl: -------------------------------------------------------------------------------- 1 | ((lp0 2 | F0.09677810966968536 3 | aF0.08329660445451736 4 | aF0.07952464371919632 5 | aF0.06772172451019287 6 | aF0.06128274276852608 7 | aF0.0597919262945652 8 | aF0.05465768277645111 9 | aF0.054280247539281845 10 | aF0.04802464693784714 11 | aF0.04177231341600418 12 | aF0.0395895317196846 13 | aF0.03497838228940964 14 | aF0.01637987792491913 15 | aF0.016112003475427628 16 | aF0.014277194626629353 17 | aF0.012704286724328995 18 | aF0.011427170597016811 19 | aF0.010880488902330399 20 | aF0.00848681852221489 21 | aF0.008081144653260708 22 | aF0.007859496399760246 23 | aF0.006901880726218224 24 | aF0.006405259016901255 25 | aF0.006056389771401882 26 | aF0.0059202020056545734 27 | aF0.005715108476579189 28 | aF0.005421617534011602 29 | aF0.005261822137981653 30 | aF0.005040853284299374 31 | aF0.004885892849415541 32 | aF0.004815707448869944 33 | aF0.0045751966536045074 34 | aF0.004032340366393328 35 | aF0.003815937787294388 36 | aF0.003402352100238204 37 | aF0.0033518215641379356 38 | aF0.00317523255944252 39 | aF0.003115095430985093 
40 | aF0.0028273549396544695 41 | aF0.0026916363276541233 42 | aF0.0026126625016331673 43 | aF0.002598047722131014 44 | aF0.0025925629306584597 45 | aF0.002455216832458973 46 | aF0.00240224227309227 47 | aF0.002354211173951626 48 | aF0.002309975912794471 49 | aF0.0021989874076098204 50 | aF0.0020490288734436035 51 | aF0.002043849090114236 52 | aF0.001881008385680616 53 | aF0.0014458741061389446 54 | aF0.0013986441772431135 55 | aF0.001383198774419725 56 | aF0.0013400414027273655 57 | aF0.0013277606340125203 58 | aF0.0012347960146144032 59 | aF0.0012329978635534644 60 | aF0.0012080983724445105 61 | aF0.0011732496786862612 62 | aF0.0011555271921679378 63 | aF0.001039374154061079 64 | aF0.001000710530206561 65 | aF0.0009830804774537683 66 | aF0.000977381831035018 67 | aF0.0009753872873261571 68 | aF0.0009257919155061245 69 | aF0.0009101273026317358 70 | aF0.0009028138010762632 71 | aF0.0008777166367508471 72 | aF0.000823114940430969 73 | aF0.0007540630758740008 74 | aF0.0007486902759410441 75 | aF0.000739811104722321 76 | aF0.0007287916378118098 77 | aF0.0007082631927914917 78 | aF0.0006440930301323533 79 | aF0.0006421218276955187 80 | aF0.0006109128589741886 81 | aF0.0006083407788537443 82 | aF0.0005644322372972965 83 | aF0.0005635866546072066 84 | aF0.0005442055407911539 85 | aF0.0005365016404539347 86 | aF0.0005219129379838705 87 | aF0.000518576882313937 88 | aF0.0005170923541299999 89 | aF0.000508149154484272 90 | aF0.0004891224089078605 91 | aF0.0004839212342631072 92 | aF0.00046193087473511696 93 | aF0.0004615433281287551 94 | aF0.00045213394332677126 95 | aF0.000448366510681808 96 | aF0.0004176677030045539 97 | aF0.00041570625035092235 98 | aF0.00040827973862178624 99 | aF0.0004060911596752703 100 | aF0.00040436291601508856 101 | aF0.0004032246069982648 102 | a(lp1 103 | I585 104 | aI43039 105 | aI394 106 | aI19845 107 | aI16079 108 | aI79828 109 | aI1505 110 | aI18864 111 | aI6143 112 | aI1473 113 | aI61743 114 | aI3008 115 | aI5363 116 | aI2243 117 | aI35097 
118 | aI114029 119 | aI3169 120 | aI98306 121 | aI93826 122 | aI34635 123 | aI18061 124 | aI2862 125 | aI16589 126 | aI88476 127 | aI5318 128 | aI7689 129 | aI4592 130 | aI16864 131 | aI3411 132 | aI20319 133 | aI8504 134 | aI42949 135 | aI3026 136 | aI11760 137 | aI1327 138 | aI8748 139 | aI12665 140 | aI20312 141 | aI75758 142 | aI68 143 | aI60682 144 | aI68821 145 | aI2290 146 | aI15631 147 | aI81679 148 | aI11847 149 | aI104480 150 | aI12044 151 | aI6660 152 | aI10755 153 | aI41444 154 | aI2759 155 | aI1543 156 | aI21177 157 | aI15806 158 | aI17242 159 | aI193 160 | aI106159 161 | aI86592 162 | aI39292 163 | aI29099 164 | aI6190 165 | aI18759 166 | aI25083 167 | aI15304 168 | aI3473 169 | aI20316 170 | aI16192 171 | aI14399 172 | aI24307 173 | aI57862 174 | aI14186 175 | aI10552 176 | aI1665 177 | aI3684 178 | aI23284 179 | aI13624 180 | aI785 181 | aI42981 182 | aI2680 183 | aI53514 184 | aI116382 185 | aI28795 186 | aI38238 187 | aI11668 188 | aI94180 189 | aI57291 190 | aI2066 191 | aI10335 192 | aI24192 193 | aI69932 194 | aI31777 195 | aI14504 196 | aI33053 197 | aI74867 198 | aI10632 199 | aI15578 200 | aI96329 201 | aI37917 202 | aI754 203 | atp2 204 | . 
-------------------------------------------------------------------------------- /askcos/prioritization/templates/test_data/relevance_02.pkl: -------------------------------------------------------------------------------- 1 | ((lp0 2 | F0.12483948469161987 3 | aF0.09675677120685577 4 | aF0.07252208888530731 5 | aF0.06345923244953156 6 | aF0.055957842618227005 7 | aF0.05251732096076012 8 | aF0.05059691518545151 9 | aF0.047397688031196594 10 | aF0.038129810243844986 11 | aF0.035854779183864594 12 | aF0.03571247309446335 13 | aF0.031145170331001282 14 | aF0.02715172991156578 15 | aF0.023004816845059395 16 | aF0.01872534118592739 17 | aF0.014951656572520733 18 | aF0.014644996263086796 19 | aF0.013297339901328087 20 | aF0.0124376080930233 21 | aF0.009947249665856361 22 | aF0.009258760139346123 23 | aF0.009112609550356865 24 | aF0.008426773361861706 25 | aF0.007360807154327631 26 | aF0.006890494842082262 27 | aF0.006440803408622742 28 | aF0.006276105064898729 29 | aF0.005383905488997698 30 | aF0.005243632011115551 31 | aF0.004729327280074358 32 | aF0.004109365399926901 33 | aF0.0035438123159110546 34 | aF0.0034864158369600773 35 | aF0.003446843009442091 36 | aF0.003419393440708518 37 | aF0.0032661266159266233 38 | aF0.0030272542499005795 39 | aF0.0028793849050998688 40 | aF0.002565617673099041 41 | aF0.002472592517733574 42 | aF0.001711010467261076 43 | aF0.0016758631682023406 44 | aF0.00162817956879735 45 | aF0.0014764101943001151 46 | aF0.001463010790757835 47 | aF0.0014109652256593108 48 | aF0.0014075757935643196 49 | aF0.0013878429308533669 50 | aF0.0012881041038781404 51 | aF0.0012560184113681316 52 | aF0.0012178672477602959 53 | aF0.0012085025664418936 54 | aF0.0011294559808447957 55 | aF0.0011197818676009774 56 | aF0.0011008285218849778 57 | aF0.0010941301006823778 58 | aF0.0010294823441654444 59 | aF0.0010062656365334988 60 | aF0.0010017863241955638 61 | aF0.0009930713567882776 62 | aF0.0009583045612089336 63 | aF0.0009503121254965663 64 | 
aF0.0009332725894637406 65 | aF0.0009263189858756959 66 | aF0.0009127791272476315 67 | aF0.0008720188052393496 68 | aF0.0008697713492438197 69 | aF0.0007891117129474878 70 | aF0.0007621290860697627 71 | aF0.000734313449356705 72 | aF0.0007241848506964743 73 | aF0.0006525802309624851 74 | aF0.0006003291346132755 75 | aF0.0005674161366187036 76 | aF0.0005656332359649241 77 | aF0.0005342494114302099 78 | aF0.0005266602383926511 79 | aF0.0005134622915647924 80 | aF0.0004936910117976367 81 | aF0.00048470028559677303 82 | aF0.00045273316209204495 83 | aF0.00045251648407429457 84 | aF0.00040294736390933394 85 | aF0.0003987659583799541 86 | aF0.00038217438850551844 87 | aF0.0003801097918767482 88 | aF0.00036254976294003427 89 | aF0.00035696636768989265 90 | aF0.00035655402461998165 91 | aF0.00035193320945836604 92 | aF0.0003482778265606612 93 | aF0.0003447774797677994 94 | aF0.000343221181537956 95 | aF0.00032383535290136933 96 | aF0.00032274023396894336 97 | aF0.00032139307586476207 98 | aF0.0003197123296558857 99 | aF0.00031562556978315115 100 | aF0.00029040040681138635 101 | aF0.00028849905356764793 102 | a(lp1 103 | I3571 104 | aI98952 105 | aI8636 106 | aI943 107 | aI305 108 | aI38 109 | aI595 110 | aI290 111 | aI20872 112 | aI818 113 | aI434 114 | aI2900 115 | aI28874 116 | aI25533 117 | aI623 118 | aI39045 119 | aI5733 120 | aI81940 121 | aI11429 122 | aI31575 123 | aI15026 124 | aI6977 125 | aI7331 126 | aI67312 127 | aI26749 128 | aI34951 129 | aI213 130 | aI23355 131 | aI71383 132 | aI67735 133 | aI74558 134 | aI3547 135 | aI1960 136 | aI3660 137 | aI10673 138 | aI7461 139 | aI829 140 | aI102262 141 | aI69404 142 | aI5907 143 | aI6313 144 | aI4066 145 | aI10005 146 | aI649 147 | aI99818 148 | aI55306 149 | aI61663 150 | aI83735 151 | aI4449 152 | aI4142 153 | aI13470 154 | aI193 155 | aI39521 156 | aI7507 157 | aI147 158 | aI239 159 | aI86620 160 | aI99819 161 | aI31292 162 | aI4279 163 | aI12887 164 | aI1183 165 | aI903 166 | aI54499 167 | aI6263 168 | aI65 169 
| aI761 170 | aI53218 171 | aI2070 172 | aI51339 173 | aI567 174 | aI77675 175 | aI5612 176 | aI592 177 | aI86435 178 | aI73874 179 | aI8613 180 | aI4619 181 | aI6010 182 | aI13432 183 | aI16420 184 | aI156 185 | aI125117 186 | aI59835 187 | aI40833 188 | aI47743 189 | aI78415 190 | aI23677 191 | aI31551 192 | aI28388 193 | aI3941 194 | aI30935 195 | aI39254 196 | aI496 197 | aI3593 198 | aI28342 199 | aI14442 200 | aI64450 201 | aI9128 202 | aI12782 203 | atp2 204 | . -------------------------------------------------------------------------------- /askcos/retrosynthetic/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/retrosynthetic/mcts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/retrosynthetic/mcts/__init__.py -------------------------------------------------------------------------------- /askcos/retrosynthetic/mcts/v2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/retrosynthetic/mcts/v2/__init__.py -------------------------------------------------------------------------------- /askcos/retrosynthetic/pathway_ranker/__init__.py: -------------------------------------------------------------------------------- 1 | from .pathway_ranker import PathwayRanker 2 | -------------------------------------------------------------------------------- /askcos/retrosynthetic/pathway_ranker/pathway_ranker_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import unittest 4 | 5 | import torch 6 | 7 | from askcos.retrosynthetic.pathway_ranker.pathway_ranker import 
PathwayRanker 8 | from askcos.retrosynthetic.pathway_ranker.utils import convert_askcos_trees 9 | 10 | 11 | class TestPathwayRanker(unittest.TestCase): 12 | """Contains functional tests for the PathwayRanker class.""" 13 | 14 | @classmethod 15 | def setUpClass(cls): 16 | with open(os.path.join(os.path.dirname(__file__), 'test_data', 'test_trees.json'), 'r') as f: 17 | cls.trees = json.load(f) 18 | 19 | def test_preprocess(self): 20 | """Test the preprocess method.""" 21 | output = convert_askcos_trees(self.trees) 22 | original_indices, remaining_trees = zip(*((i, tree) for i, tree in enumerate(output) if tree['depth'] > 1)) 23 | 24 | ranker = PathwayRanker() 25 | batch = ranker.preprocess(remaining_trees) 26 | 27 | self.assertEqual(original_indices, (2, 3, 4)) 28 | 29 | self.assertIn('pfp', batch) 30 | self.assertIsInstance(batch['pfp'], torch.Tensor) 31 | self.assertIn('rxnfp', batch) 32 | self.assertIsInstance(batch['rxnfp'], torch.Tensor) 33 | self.assertIn('node_order', batch) 34 | self.assertIsInstance(batch['node_order'], torch.Tensor) 35 | self.assertIn('adjacency_list', batch) 36 | self.assertIsInstance(batch['adjacency_list'], torch.Tensor) 37 | self.assertIn('edge_order', batch) 38 | self.assertIsInstance(batch['edge_order'], torch.Tensor) 39 | self.assertEqual(batch['num_nodes'], [2, 5, 4]) 40 | self.assertEqual(batch['num_trees'], [1, 1, 1]) 41 | self.assertEqual(batch['batch_size'], 3) 42 | 43 | def test_scorer(self): 44 | """Test the scorer method.""" 45 | ranker = PathwayRanker() 46 | ranker.load() 47 | 48 | output = ranker.scorer(self.trees, clustering=True) 49 | 50 | self.assertIn('scores', output) 51 | self.assertEqual(len(output['scores']), 5) 52 | self.assertEqual(output['scores'][0], -1) 53 | self.assertEqual(output['scores'][1], -1) 54 | 55 | self.assertIn('encoded_trees', output) 56 | self.assertEqual(len(output['encoded_trees']), 5) 57 | self.assertEqual(len(output['encoded_trees'][0]), 0) 58 | 
self.assertEqual(len(output['encoded_trees'][1]), 0) 59 | self.assertEqual(len(output['encoded_trees'][2]), 512) 60 | self.assertEqual(len(output['encoded_trees'][3]), 512) 61 | self.assertEqual(len(output['encoded_trees'][4]), 512) 62 | 63 | self.assertIn('clusters', output) 64 | self.assertEqual(output['clusters'], [-1, -1, 0, 1, 2]) 65 | 66 | 67 | if __name__ == '__main__': 68 | res = unittest.main(verbosity=3, exit=False) 69 | -------------------------------------------------------------------------------- /askcos/synthetic/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/synthetic/atom_mapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/atom_mapper/__init__.py -------------------------------------------------------------------------------- /askcos/synthetic/context/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/synthetic/context/neuralnetwork_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | import askcos.global_config as gc 8 | import askcos.synthetic.context.neuralnetwork as nn 9 | 10 | 11 | class TestNeuralNetwork(unittest.TestCase): 12 | 13 | def test_01_get_n_conditions(self): 14 | """Test that NeuralNetContextRecommender.get_n_conditions gives the expected result.""" 15 | cont = nn.NeuralNetContextRecommender() 16 | cont.load_nn_model( 17 | model_path=gc.NEURALNET_CONTEXT_REC['model_path'], 18 | info_path=gc.NEURALNET_CONTEXT_REC['info_path'], 19 | 
weights_path=gc.NEURALNET_CONTEXT_REC['weights_path'] 20 | ) 21 | result, scores = cont.get_n_conditions('CC1(C)OBOC1(C)C.Cc1ccc(Br)cc1>>Cc1cccc(B2OC(C)(C)C(C)(C)O2)c1', 10, 22 | with_smiles=False, return_scores=True) 23 | 24 | nan = float('nan') 25 | expected_result = [ 26 | [102.30387878417969, 'C1COCCO1', 'CCN(CC)CC', "Reaxys Name (1,1'-bis(diphenylphosphino)ferrocene)palladium(II) dichloride", nan, nan, None, False], 27 | [104.92787170410156, 'C1COCCO1', 'CCN(CC)CC', 'Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1', nan, nan, None, False], 28 | [99.1409912109375, 'Cc1ccccc1', 'CCN(CC)CC', 'Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1', nan, nan, None, False], 29 | [76.38555908203125, 'C1CCOC1', 'CCN(CC)CC', 'Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1', nan, nan, None, False], 30 | [95.92562103271484, 'Cc1ccccc1', 'CCN(CC)CC', "Reaxys Name (1,1'-bis(diphenylphosphino)ferrocene)palladium(II) dichloride", nan, nan, None, False], 31 | [75.68882751464844, 'C1CCOC1', 'CCN(CC)CC', "Reaxys Name (1,1'-bis(diphenylphosphino)ferrocene)palladium(II) dichloride", nan, nan, None, False], 32 | [93.39191436767578, 'C1COCCO1', '', "Reaxys Name (1,1'-bis(diphenylphosphino)ferrocene)palladium(II) dichloride", nan, nan, None, False], 33 | [97.8741226196289, 'C1COCCO1', 'CC(=O)[O-].[K+]', "Reaxys Name (1,1'-bis(diphenylphosphino)ferrocene)palladium(II) dichloride", nan, nan, None, False], 34 | [95.84452819824219, 'C1COCCO1', '[MgH2]', 'Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1', nan, nan, None, False], 35 | [67.86063385009766, 'C1CCOC1', '[MgH2]', 'Cl[Pd](Cl)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1', nan, nan, None, False], 36 | ] 37 | expected_scores = [0.19758703, 0.09385002, 0.0320574, 0.026747962, 0.024693565, 0.010140889, 0.0048135854, 0.004163743, 0.002136398, 0.0018915414] 38 | 39 | for e, r in zip(expected_result, result): 40 | 
self.assertEqual(len(e), len(r)) 41 | self.assertAlmostEqual(e[0], r[0], places=4) 42 | self.assertEqual(e[1:4], r[1:4]) 43 | self.assertEqual(e[7:], r[7:]) 44 | 45 | self.assertTrue(np.allclose(expected_scores, scores)) 46 | 47 | 48 | if __name__ == '__main__': 49 | res = unittest.main(verbosity=3, exit=False) 50 | -------------------------------------------------------------------------------- /askcos/synthetic/context/v2/README.md: -------------------------------------------------------------------------------- 1 | # Prerequisite 2 | - Tensorflow 2.x 3 | - numpy 4 | - rdkit >= 2019 5 | - python >= 3.6 6 | 7 | # Evaluate 8 | ``` 9 | # examples 10 | python -m askcos.synthetic.context.v2.reaction_context_predictor.reaction_context_predictor 11 | 12 | # one reaction SMILES per line in input.txt 13 | python -m askcos.synthetic.context.v2.reaction_context_predictor.evaluate input.txt output.txt 14 | ``` 15 | -------------------------------------------------------------------------------- /askcos/synthetic/context/v2/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /askcos/synthetic/context/v2/db.py: -------------------------------------------------------------------------------- 1 | import askcos.global_config as gc 2 | 3 | config_reaction_condition = gc.CONTEXT_V2 4 | -------------------------------------------------------------------------------- /askcos/synthetic/context/v2/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import json 5 | from . 
# These rules convert ions into their connected neutral molecular form.
# This is necessary for splitting reagents.
reagent_conv_rules = None
if reagent_conv_rules is None:
    with open(db.config_reaction_condition['reagent_conv_rules'], 'r') as f:
        _reagent_conv_rules = json.load(f)
    reagent_conv_rules = {}
    for k, v in _reagent_conv_rules.items():
        s = smiles_util.canonicalize_smiles(k)
        # BUG FIX: guard on the canonicalized SMILES `s`, not the raw key
        # `k`. JSON object keys are always strings (never None), so the old
        # `if k is not None` check always passed, and a SMILES that failed
        # canonicalization inserted a bogus `None` key into the table.
        if s is not None:
            reagent_conv_rules[s] = v
def split_neutral_fragment(reagents):
    """Partition reagent SMILES into neutral fragments and charged leftovers.

    Each reagent is split on '.'; fragments with zero formal charge are
    collected into a set. The charged fragments of one reagent are re-joined
    and kept together if the combination is neutral, otherwise the joined
    string goes into the charged list.

    inputs: list of str, smiles
    outputs: (set of neutral smiles, list of charged smiles)
    """
    neutral = set()
    charged = []
    for reagent in reagents:
        leftovers = []
        for fragment in reagent.split('.'):
            if int(get_smiles_charge(fragment)) == 0:
                neutral.add(fragment)
            else:
                leftovers.append(fragment)
        if leftovers:
            combined = '.'.join(leftovers)
            if int(get_smiles_charge(combined)) == 0:
                neutral.add(combined)
            else:
                charged.append(combined)
    return neutral, charged
Canonicalization using hardcoded rules 92 | ''' 93 | assert isinstance(reagents, list) 94 | for i in range(len(reagents)): 95 | reagents[i] = canonicalize_smiles_reagent_conv_rules(reagents[i]) 96 | 97 | # Rule 1, split neutral 98 | reagents_neutral, reagents_charged = split_neutral_fragment(reagents) 99 | 100 | # Rule 2, combine charged, reagents_charged --> reagents_neutral 101 | # q for smiles in reagents_charged 102 | charges = [get_smiles_charge(s) for s in reagents_charged] 103 | # sanity check 104 | # check 1, total charge 0 105 | total_charge = sum(charges) 106 | if total_charge != 0: 107 | print('reagents: ', reagents) 108 | print('reagents_neutral: ', reagents_neutral) 109 | print('reagents_charged: ', reagents_charged) 110 | raise ValueError('preprocess_reagents(): total charge is not zero, q='+str(total_charge)) 111 | if len(reagents_charged) > 0: 112 | reagents_neutral.add(smiles_util.canonicalize_smiles('.'.join(reagents_charged))) 113 | reagents_neutral = list(reagents_neutral) 114 | 115 | # Rule 3, Canonicalization, replace using reagent_conv_rules.json 116 | res = set() 117 | for i in reagents_neutral: 118 | tmp = canonicalize_smiles_reagent_conv_rules(i) 119 | tmp1, tmp2 = split_neutral_fragment([tmp]) 120 | if len(tmp2) != 0: 121 | sys.stderr.write('preprocess_reagents(): error: charged fragment, s='+str(reagents)+'\n') 122 | for s in tmp1: 123 | res.add(s) 124 | 125 | return list(res) 126 | -------------------------------------------------------------------------------- /askcos/synthetic/context/v2/preprocess_reagent_group_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from . 
class Test(unittest.TestCase):
    """Regression tests for preprocess_reagent_group.preprocess_reagents."""

    def _check(self, reagents, expected):
        # Run the preprocessing and compare order-insensitively.
        result = sorted(preprocess_reagent_group.preprocess_reagents(reagents))
        self.assertEqual(result, expected)

    def test_01(self):
        self._check(['CO.[Na+].[OH-]', 'CO.O'], ['CO', 'O', 'O[Na]'])

    def test_02(self):
        self._check(['[Li+].[AlH4-]'], ['[Li][AlH4]'])

    def test_03(self):
        self._check(['[Cl-].[Cl-].[Ce+2]'], ['Cl[Ce]Cl'])

    def test_04(self):
        self._check(['[OH-].[NH4+]'], ['N', 'O'])
def prepare_reagents2(encoder, reagents):
    '''
    One-hot encode a list of reagent records.

    encoder: mapping from canonical SMILES to class index, e.g. {None: 0, 'smiles1': 1}
    reagents: list of {'smiles': ..., 'conc': ...} dicts, or a single such dict;
              entries that are None, or whose smiles is None or 'missing', are skipped
    return: float32 array of shape (num_valid_reagents, len(encoder)) with one
            row per valid reagent and a single 1 at its encoder index
    '''
    if not isinstance(reagents, list):
        reagents = [reagents]

    valid_reagents = []
    for r in reagents:
        # Skip absent entries. Bug fix: the original compared the string with
        # `is 'missing'`, which tests object identity and only ever worked via
        # CPython string interning; use equality instead. Also tolerate None
        # list entries, which the docstring documents as possible input.
        if r is None or r['smiles'] is None or r['smiles'] == 'missing':
            continue
        valid_reagents.append(r['smiles'])

    # Canonicalize / split / recombine fragments with the hardcoded rules.
    valid_reagents = preprocess_reagent_group.preprocess_reagents(valid_reagents)

    res = np.zeros((len(valid_reagents), len(encoder)), dtype=np.float32)
    for i, smi in enumerate(valid_reagents):
        idx = encoder.get(smiles_util.canonicalize_smiles(smi), None)
        if idx is None:
            # Best effort: warn and leave an all-zero row rather than failing.
            sys.stderr.write('prepare_reagents2(): encoder missing smiles=' + smi + '\n')
        else:
            res[i, idx] = 1

    return res
def top_n_results(_res, n=10, remove_duplicates=True):
    '''
    Return the (up to) n highest-scoring entries of _res, best first.

    _res: list of (reagents_onehot, score) tuples
    n: maximum number of entries to return
    remove_duplicates: collapse duplicate one-hot vectors first, keeping the
        maximum score per vector
    '''
    if remove_duplicates:
        res = remove_duplicates_results(_res)
    else:
        # Bug fix: `res` was previously left undefined on this path, so any
        # call with remove_duplicates=False raised a NameError.
        res = _res
    scores = np.array([entry[1] for entry in res], dtype='float32')
    # argsort is ascending; flip it to get indices by descending score.
    order = np.flip(np.argsort(scores))
    return [res[i] for i in order[:n]]
tf.reshape(tf.convert_to_tensor(r, dtype=tf.float32), shape=(1,vacab_size)) 85 | y_pred = model(**model_inputs)['softmax_1'].numpy() 86 | y_pred_class_num = np.flip(np.argsort(y_pred, axis=-1), axis=-1)[0,0:beam_size] 87 | y_pred_class_score = y_pred[0, y_pred_class_num] 88 | for i in range(len(y_pred_class_num)): 89 | n = y_pred_class_num[i] 90 | s = y_pred_class_score[i]*score 91 | new_r = np.copy(r).flatten() 92 | if n == eos_id: 93 | # move finished 94 | res.append((new_r, s)) 95 | else: 96 | # set predicted reagents 97 | new_r[n] = 1 98 | new_input_reagent.append((new_r, s)) 99 | if keepall: 100 | # remove duplicates, keep max score 101 | input_reagent = remove_duplicates_results(new_input_reagent) 102 | else: 103 | # keep beam_size only 104 | input_reagent = top_n_results(new_input_reagent, n=beam_size, remove_duplicates=True) 105 | # early termination, assume score < 1 106 | if len(res) > beam_size: 107 | res_top = top_n_results(res, n=beam_size, remove_duplicates=True) 108 | res_min = min_score_results(res_top) 109 | s_max = max_score_results(input_reagent) 110 | if s_max < res_min: 111 | break 112 | # increase step count 113 | nstep += 1 114 | if returnall: 115 | res = top_n_results(res, n=len(res), remove_duplicates=True) 116 | else: 117 | res = top_n_results(res, n=beam_size, remove_duplicates=True) 118 | return res 119 | 120 | 121 | if __name__ == "__main__": 122 | a = [(np.array([1,1]), 0.5),(np.array([1,1]), 0.3),(np.array([1,0]), 0.5),(np.array([1,0]), 0.7)] 123 | r = remove_duplicates_results(a) 124 | assert '[(array([1., 1.]), 0.5), (array([1., 0.]), 0.7)]' == str(r) 125 | -------------------------------------------------------------------------------- /askcos/synthetic/context/v2/smiles_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | import os 5 | import sys 6 | 7 | import rdkit 8 | from rdkit import Chem 9 | from rdkit.Chem import AllChem 10 | 11 | import 
def canonicalize_smiles_rdkit(s):
    '''
    Canonicalize a SMILES string with RDKit.

    Sanitization is disabled to avoid the 'Br[Br-]Br' parsing problem.
    Returns the canonical SMILES, or None if RDKit fails to process s.
    '''
    try:
        s_can = Chem.MolToSmiles(Chem.MolFromSmiles(s, sanitize=False))  # avoid 'Br[Br-]Br' problem
    except Exception:
        # Bug fix: a bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt; catch Exception so real failures still log
        # but process-control exceptions propagate.
        sys.stderr.write('canonicalize_smiles_rdkit(): fail s=' + s + '\n')
        s_can = None
    return s_can
successfully.'.format(model_pt_path)) 39 | 40 | def preprocess(self, smiles): 41 | mol_graph = mol2graph(smiles) 42 | f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, b2br, bond_types = mol_graph.get_components() 43 | f_atoms, f_bonds, a2b, b2a, b2revb, b2br, bond_types = \ 44 | f_atoms.to(self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to(self.device), \ 45 | b2revb.to(self.device), b2br.to(self.device), bond_types.to(self.device) 46 | 47 | return f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, b2br, bond_types 48 | 49 | def inference(self, data): 50 | descs = self.model(data) 51 | 52 | return descs 53 | 54 | def postprocess(self, smiles, descs): 55 | 56 | descs = [x.data.cpu().numpy() for x in descs] 57 | 58 | partial_charge, partial_neu, partial_elec, NMR, bond_order, bond_distance = descs 59 | 60 | n_atoms, n_bonds = [], [] 61 | for s in smiles: 62 | m = Chem.MolFromSmiles(s) 63 | 64 | m = Chem.AddHs(m) 65 | 66 | n_atoms.append(len(m.GetAtoms())) 67 | n_bonds.append(len(m.GetBonds())) 68 | 69 | partial_charge = [x.tolist() for x in np.split(partial_charge.flatten(), np.cumsum(np.array(n_atoms)))][:-1] 70 | partial_neu = [x.tolist() for x in np.split(partial_neu.flatten(), np.cumsum(np.array(n_atoms)))][:-1] 71 | partial_elec = [x.tolist() for x in np.split(partial_elec.flatten(), np.cumsum(np.array(n_atoms)))][:-1] 72 | NMR = [x.tolist() for x in np.split(NMR.flatten(), np.cumsum(np.array(n_atoms)))][:-1] 73 | 74 | bond_order = [x.tolist() for x in np.split(bond_order.flatten(), np.cumsum(np.array(n_bonds)))][:-1] 75 | bond_distance = [x.tolist() for x in np.split(bond_distance.flatten(), np.cumsum(np.array(n_bonds)))][:-1] 76 | 77 | results = [{'smiles': s, 'partial_charge': pc, 'fukui_neu': pn, 78 | 'fukui_elec': pe, 'NMR': nmr, 'bond_order': bo, 'bond_length': bd} 79 | for s, pc, pn, pe, nmr, bo, bd in zip(smiles, partial_charge, partial_neu, 80 | partial_elec, NMR, bond_order, bond_distance)] 81 | 82 | # FIXME the torch server 
currently doesn't support batch input, to be consitent only allow single smiles here. 83 | return results[0] 84 | 85 | def evaluate(self, smiles): 86 | smiles = smiles.split('.') 87 | descriptors = self.inference(self.preprocess(smiles)) 88 | result = self.postprocess(smiles, descriptors) 89 | 90 | return result 91 | 92 | 93 | if __name__ == '__main__': 94 | 95 | handler = ReactivityDescriptor() 96 | 97 | data = 'CCCC.CCC.CCCCC' 98 | descriptors = handler.evaluate(data) 99 | 100 | print(descriptors) -------------------------------------------------------------------------------- /askcos/synthetic/descriptors/model.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch.nn as nn 4 | 5 | from askcos.synthetic.descriptors.mpn import MPN 6 | from askcos.synthetic.descriptors.ffn import MultiReadout 7 | 8 | class MoleculeModel(nn.Module): 9 | """A MoleculeModel is a model which contains a message passing network following by feed-forward layers.""" 10 | 11 | def __init__(self, args, atom_fdim, bond_fdim): 12 | """ 13 | Initializes the MoleculeModel. 14 | 15 | :param classification: Whether the model is a classification model. 16 | """ 17 | super(MoleculeModel, self).__init__() 18 | self.create_encoder(args, atom_fdim, bond_fdim) 19 | self.create_ffn(args) 20 | 21 | def create_encoder(self, args: Namespace, atom_fdim, bond_fdim): 22 | """ 23 | Creates the message passing encoder for the model. 24 | 25 | :param args: Arguments. 26 | """ 27 | self.encoder = MPN(args, atom_fdim, bond_fdim) 28 | 29 | def create_ffn(self, args: Namespace): 30 | """ 31 | Creates the feed-forward network for the model. 32 | 33 | :param args: Arguments. 34 | """ 35 | 36 | # Create readout layer 37 | self.readout = MultiReadout(args, args.atom_targets, args.bond_targets, 38 | args.atom_constraints, args.bond_constraints) 39 | 40 | def forward(self, input): 41 | """ 42 | Runs the MoleculeModel on input. 
    def __init__(self, args: Namespace, atom_fdim: int, bond_fdim: int):
        """Initializes the MPNEncoder.

        Stores the hyperparameters read from ``args`` and constructs the
        learned layers, unless ``args.features_only`` is set, in which case
        no layers are created.

        :param args: Arguments namespace; must provide hidden_size, bias,
            depth, dropout, undirected, atom_messages, features_only and
            use_input_features.
        :param atom_fdim: Atom features dimension.
        :param bond_fdim: Bond features dimension.
        """
        super(MPNEncoder, self).__init__()
        self.atom_fdim = atom_fdim
        self.bond_fdim = bond_fdim
        self.hidden_size = args.hidden_size
        self.bias = args.bias
        self.depth = args.depth
        self.dropout = args.dropout
        self.layers_per_message = 1
        self.undirected = args.undirected
        self.atom_messages = args.atom_messages
        self.features_only = args.features_only
        self.use_input_features = args.use_input_features
        self.args = args

        if self.features_only:
            # No message-passing layers are needed in features-only mode.
            return

        # Dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Activation
        self.act_func = nn.ReLU()

        # Cached zeros (presumably used as padding elsewhere — not
        # referenced within this constructor).
        self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False)

        # Input: atom messages consume atom features; bond messages (the
        # default) consume bond features.
        input_dim = self.atom_fdim if self.atom_messages else self.bond_fdim
        self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)

        if self.atom_messages:
            # Atom messages additionally concatenate neighbor bond features.
            w_h_input_size = self.hidden_size + self.bond_fdim
        else:
            w_h_input_size = self.hidden_size

        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)

        # hidden state readout: separate heads for atom and bond hiddens
        self.W_o_a = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size)
        self.W_o_b = nn.Linear(self.bond_fdim + self.hidden_size, self.hidden_size)
def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    Select rows of ``source`` for every entry of ``index``, preserving index shape.

    :param source: A tensor of shape :code:`(num_bonds, hidden_size)` containing message features.
    :param index: A tensor of shape :code:`(num_atoms/num_bonds, max_num_bonds)` containing the atom or bond
                  indices to select from :code:`source`.
    :return: A tensor of shape :code:`(num_atoms/num_bonds, max_num_bonds, hidden_size)` containing the message
             features corresponding to the atoms/bonds specified in index.
    """
    # Flatten the index, gather the rows, then restore the index shape with
    # the source's feature dimensions appended.
    flat_rows = source.index_select(dim=0, index=index.view(-1))
    return flat_rows.view(index.size() + source.size()[1:])
class ForwardProduct:
    """
    A class to store a single forward product for reaction enumeration.

    Attributes:
        smiles_list (list of str): SMILES of the individual product fragments
        smiles (str): SMILES of the combined product
        template_ids (list): ids of the templates that produced this product
        num_examples (int): number of examples supporting this product
        edits: edit representation of the transformation, if available
    """
    def __init__(self, smiles_list=None, smiles='', template_id=-1, num_examples=0,
                 edits=None, template_ids=None):
        # Bug fix: the original default `smiles_list=[]` was a mutable default
        # argument shared across every instance, so one product's appends
        # leaked into all later products. Use None as the sentinel instead;
        # this is backward-compatible for all callers that pass a list.
        self.smiles_list = smiles_list if smiles_list is not None else []
        self.smiles = smiles
        self.template_ids = [template_id]
        if template_ids:
            # An explicit id list overrides the single template_id.
            self.template_ids = template_ids
        self.num_examples = num_examples
        self.edits = edits

    def get_edits(self):
        """Return the stored edits representation (may be None)."""
        return self.edits

    def get_smiles(self):
        """Return the combined product SMILES."""
        return self.smiles

    def as_dict(self):
        """Return a JSON-serializable summary of this product."""
        return {
            'smiles': self.smiles,
            'template_ids': [str(x) for x in self.template_ids],
            'num_examples': self.num_examples,
        }
class TestTransformer(unittest.TestCase):
    """Tests for the forward enumeration ForwardTransformer."""

    def setUp(self):
        # A freshly loaded transformer for every test.
        self.ft = transformer.ForwardTransformer()
        self.ft.load()

    def test_01_get_outcomes(self):
        target = 'NC(=O)[C@H](CCC=O)N1C(=O)c2ccccc2C1=O'
        outcome = self.ft.get_outcomes(target)
        # get_outcomes returns a (smiles, outcomes) pair; pin the smiles
        # type and the expected number of enumerated outcomes.
        self.assertEqual(str, type(outcome[0]))
        self.assertEqual(181, len(outcome[1]))
self.assertAlmostEqual(e['prob'], r['prob'], places=4) 29 | self.assertAlmostEqual(e['mol_wt'], r['mol_wt'], places=2) 30 | 31 | 32 | class TestFastFilter(unittest.TestCase): 33 | 34 | @classmethod 35 | def setUpClass(cls): 36 | """This method is run once before all tests in this class.""" 37 | cls.model = ff.FastFilterScorer() 38 | cls.model.load(model_path=gc.FAST_FILTER_MODEL['model_path']) 39 | 40 | def test_01_evaluate(self): 41 | result = self.model.evaluate('CCO.CC(=O)O', 'CCOC(=O)C') 42 | expected = [[{'outcome': {'smiles': 'CCOC(=O)C', 'template_ids': [], 'num_examples': 0}, 43 | 'score': 0.9789425730705261, 'rank': 1.0, 'prob': 0.9789425730705261}]] 44 | self.assertEqual(expected, result) 45 | 46 | def test_02_evaluate(self): 47 | result = self.model.evaluate('[CH3:1][C:2](=[O:3])[O:4][CH:5]1[CH:6]([O:7][C:8]([CH3:9])=[O:10])[CH:11]([CH2:12][O:13][C:14]([CH3:15])=[O:16])[O:17][CH:18]([O:19][CH2:20][CH2:21][CH2:22][CH2:23][CH2:24][CH2:25][CH2:26][CH2:27][CH2:28][CH3:29])[CH:30]1[O:31][C:32]([CH3:33])=[O:34].[CH3:35][O-:36].[CH3:38][OH:39].[Na+:37]', 'CCCCCCCCCCOC1OC(CO)C(O)C(O)C1O') 48 | expected = [[{'outcome': {'smiles': 'CCCCCCCCCCOC1OC(CO)C(O)C(O)C1O', 'template_ids': [], 'num_examples': 0}, 49 | 'score': 0.9983256459236145, 'rank': 1.0, 'prob': 0.9983256459236145}]] 50 | self.assertEqual(expected, result) 51 | 52 | def test_03_evaluate(self): 53 | result = self.model.evaluate('CNC.Cc1ccc(S(=O)(=O)OCCOC(c2ccccc2)c2ccccc2)cc1', 'CN(C)CCOC(c1ccccc1)c2ccccc2') 54 | expected = [[{'outcome': {'smiles': 'CN(C)CCOC(c1ccccc1)c2ccccc2', 'template_ids': [], 'num_examples': 0}, 55 | 'score': 0.9968607425689697, 'rank': 1.0, 'prob': 0.9968607425689697}]] 56 | self.assertEqual(expected, result) 57 | 58 | def test_04_filter_with_threshold(self): 59 | flag_result, score_result = self.model.filter_with_threshold('CCO.CC(=O)O', 'CCOC(=O)C', 0.75) 60 | expected_flag = [[True]] 61 | expected_score = 0.978942573071 62 | self.assertEqual(expected_flag, flag_result) 63 | 
@unittest.skip('Non-deterministic')
class TestEvaluator(unittest.TestCase):
    """Regression test comparing Evaluator output against a pickled fixture."""

    def test_01_evaluate(self):
        evaluator = ev.Evaluator(celery=False)
        result = evaluator.evaluate(
            'CCCCO.CCCCBr', 'O=C1CCCCCCCO1', [(20, '', '', '', '', '')],
            forward_scorer=gc.templatefree, return_all_outcomes=True,
        )

        fixture = os.path.join(os.path.dirname(__file__), 'test_data/evaluator.pkl')
        with open(fixture, 'rb') as t:
            expected = pickle.load(t, encoding='iso-8859-1')

        self.assertEqual(1, len(result))
        expected, result = expected[0], result[0]

        # Exact matches for the top-level metadata.
        for key in ('number_of_outcomes', 'context'):
            self.assertEqual(expected[key], result[key])
        self.assertEqual(expected['target']['smiles'], result['target']['smiles'])

        # Top product: exact fields first, then scores to fixed precision.
        exp_top, res_top = expected['top_product'], result['top_product']
        for key in ('num_examples', 'template_ids', 'smiles', 'rank'):
            self.assertEqual(exp_top[key], res_top[key])
        self.assertAlmostEqual(exp_top['score'], res_top['score'], places=4)
        self.assertAlmostEqual(exp_top['prob'], res_top['prob'], places=4)

        # Every outcome, compared pairwise in order.
        for e, r in zip(expected['outcomes'], result['outcomes']):
            self.assertEqual(e['outcome'], r['outcome'])
            self.assertEqual(e['rank'], r['rank'])
            self.assertAlmostEqual(e['score'], r['score'], places=4)
            self.assertAlmostEqual(e['prob'], r['prob'], places=4)
            self.assertAlmostEqual(e['mol_wt'], r['mol_wt'], places=2)
def get_bin_feature(r, max_natoms):
    """Build pairwise binary features for every atom pair of reaction SMILES ``r``.

    Atoms are addressed by their atom-map number minus one. For each ordered
    pair (i, j) the feature vector holds: a no-bond flag, the bond features
    (when a bond exists), and four component-membership flags; positions
    beyond the real atom count are zero-padded.

    :param r: atom-mapped reactant SMILES, components separated by '.'
    :param max_natoms: size to pad the atom dimension to
    :return: numpy array of shape (max_natoms, max_natoms, binary_fdim)
    """
    # Map each atom (map number - 1) to the index of its '.'-component.
    comp = {}
    for i, s in enumerate(r.split('.')):
        mol = Chem.MolFromSmiles(s)
        for atom in mol.GetAtoms():
            comp[atom.GetIntProp('molAtomMapNumber') - 1] = i
    n_comp = len(r.split('.'))
    rmol = Chem.MolFromSmiles(r)
    n_atoms = rmol.GetNumAtoms()
    # Symmetric lookup of RDKit bond objects keyed by mapped atom indices.
    bond_map = {}
    for bond in rmol.GetBonds():
        a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
        a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
        bond_map[(a1,a2)] = bond_map[(a2,a1)] = bond

    features = []
    for i in range(max_natoms):
        for j in range(max_natoms):
            f = np.zeros((binary_fdim,))
            if i >= n_atoms or j >= n_atoms or i == j:
                # Padding positions and self-pairs get an all-zero vector.
                features.append(f)
                continue
            if (i,j) in bond_map:
                bond = bond_map[(i,j)]
                f[1:1+bond_fdim] = bond_features(bond)
            else:
                f[0] = 1.0  # explicit "no bond between i and j" flag
            # Component-membership flags: different/same component, and
            # whether the whole reaction has one or multiple components.
            f[-4] = 1.0 if comp[i] != comp[j] else 0.0
            f[-3] = 1.0 if comp[i] == comp[j] else 0.0
            f[-2] = 1.0 if n_comp == 1 else 0.0
            f[-1] = 1.0 if n_comp > 1 else 0.0
            features.append(f)
    return np.vstack(features).reshape((max_natoms,max_natoms,binary_fdim))
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from askcos.synthetic.evaluation.rexgen_direct.core_wln_global.mol_graph import max_nb
from askcos.synthetic.evaluation.rexgen_direct.core_wln_global.nn import *


def gated_convnet(graph_inputs, batch_size=64, hidden_size=300, depth=3, res_block=2):
    """Gated graph convolution over a padded batch of molecular graphs.

    Returns a pair: (per-atom hidden states masked to real atoms,
    per-molecule fingerprint from sum-pooling the gated atom states).
    """
    input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
    layers = [input_atom]
    atom_features = input_atom
    for step in range(depth):
        # Gather each atom's neighbor atom/bond features, then gate them.
        nei_atoms = tf.gather_nd(atom_features, atom_graph)
        nei_bonds = tf.gather_nd(input_bond, bond_graph)
        nei_input = tf.concat([nei_atoms, nei_bonds], 3)
        gated_nei = linearND(nei_input, hidden_size, "nei_hidden_%d" % step) \
            * tf.nn.sigmoid(linearND(nei_input, hidden_size, "nei_gate_%d" % step))
        # Zero out padded neighbor slots before sum-pooling over neighbors.
        mask_nei = tf.reshape(
            tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32),
            [batch_size, -1, max_nb, 1])
        pooled_nei = tf.reduce_sum(gated_nei * mask_nei, -2)
        gated_self = linearND(atom_features, hidden_size, "self_hidden_%d" % step) \
            * tf.nn.sigmoid(linearND(atom_features, hidden_size, "self_gate_%d" % step))
        atom_features = (pooled_nei + gated_self) * node_mask
        # Periodic residual connection, reaching back two saved layers.
        if res_block is not None and step % res_block == 0 and step > 0:
            atom_features = atom_features + layers[-2]
        layers.append(atom_features)
    output_gate = tf.nn.sigmoid(linearND(atom_features, hidden_size, "out_gate"))
    output = node_mask * (output_gate * atom_features)
    fp = tf.reduce_sum(output, 1)
    return atom_features * node_mask, fp


def rcnn_wl_last(graph_inputs, batch_size, hidden_size, depth, training=True):
    """Weisfeiler-Lehman style message passing; kernels taken from the last layer only."""
    input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
    atom_features = tf.nn.relu(linearND(input_atom, hidden_size, "atom_embedding", init_bias=None))
    layers = []
    for step in range(depth):
        with tf.variable_scope("WL", reuse=(step > 0)) as scope:
            nei_atoms = tf.gather_nd(atom_features, atom_graph)
            nei_bonds = tf.gather_nd(input_bond, bond_graph)
            h_nei_atom = linearND(nei_atoms, hidden_size, "nei_atom", init_bias=None)
            h_nei_bond = linearND(nei_bonds, hidden_size, "nei_bond", init_bias=None)
            h_nei = h_nei_atom * h_nei_bond
            # Mask padded neighbor slots before pooling.
            mask_nei = tf.reshape(
                tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32),
                [batch_size, -1, max_nb, 1])
            f_nei = tf.reduce_sum(h_nei * mask_nei, -2)
            f_self = linearND(atom_features, hidden_size, "self_atom", init_bias=None)
            # WL kernel for this depth; only the deepest one is returned below.
            layers.append(f_nei * f_self * node_mask)
            # Label update feeding the next round of message passing.
            l_nei = tf.concat([nei_atoms, nei_bonds], 3)
            nei_label = tf.nn.relu(linearND(l_nei, hidden_size, "label_U2"))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            atom_features = tf.nn.relu(linearND(new_label, hidden_size, "label_U1"))
    kernels = layers[-1]
    fp = tf.reduce_sum(kernels, 1)
    return kernels, fp
elem_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
             'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu',
             'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh',
             'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs', 'unknown']

# Atom feature width: element one-hot + degree (6) + explicit valence (6)
# + implicit valence (6) + aromaticity flag.
atom_fdim = len(elem_list) + 6 + 6 + 6 + 1
bond_fdim = 6
max_nb = 10  # maximum supported neighbors per atom


def onek_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Values not in the set fall back to the LAST entry, so the final
    element should be a catch-all such as ``'unknown'``.
    """
    if x not in allowable_set:
        x = allowable_set[-1]
    # Comprehension instead of list(map(lambda ...)) — same boolean list.
    return [x == s for s in allowable_set]


def atom_features(atom):
    """Featurize an RDKit atom: element, degree, valences, aromaticity."""
    return np.array(onek_encoding_unk(atom.GetSymbol(), elem_list)
                    + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
                    + onek_encoding_unk(atom.GetExplicitValence(), [1, 2, 3, 4, 5, 6])
                    + onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
                    + [atom.GetIsAromatic()], dtype=np.float32)


def bond_features(bond):
    """Featurize an RDKit bond: bond-type one-hot (4), conjugation, ring flag."""
    bt = bond.GetBondType()
    return np.array([bt == Chem.rdchem.BondType.SINGLE,
                     bt == Chem.rdchem.BondType.DOUBLE,
                     bt == Chem.rdchem.BondType.TRIPLE,
                     bt == Chem.rdchem.BondType.AROMATIC,
                     bond.GetIsConjugated(),
                     bond.IsInRing()], dtype=np.float32)


def smiles2graph(smiles, idxfunc=lambda x: x.GetIdx()):
    """Convert a SMILES string to (fatoms, fbonds, atom_nb, bond_nb, num_nbs).

    ``idxfunc`` maps an atom to its row index (default: RDKit atom index).
    Raises ValueError for unparseable SMILES; raises Exception when an atom
    index is out of range or an atom has more than ``max_nb`` neighbors.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        raise ValueError("Could not parse smiles string:", smiles)

    n_atoms = mol.GetNumAtoms()
    n_bonds = max(mol.GetNumBonds(), 1)  # at least one row so fbonds is never empty
    fatoms = np.zeros((n_atoms, atom_fdim))
    fbonds = np.zeros((n_bonds, bond_fdim))
    atom_nb = np.zeros((n_atoms, max_nb), dtype=np.int32)
    bond_nb = np.zeros((n_atoms, max_nb), dtype=np.int32)
    num_nbs = np.zeros((n_atoms,), dtype=np.int32)

    for atom in mol.GetAtoms():
        idx = idxfunc(atom)
        if idx >= n_atoms:
            raise Exception(smiles)
        fatoms[idx] = atom_features(atom)

    for bond in mol.GetBonds():
        a1 = idxfunc(bond.GetBeginAtom())
        a2 = idxfunc(bond.GetEndAtom())
        idx = bond.GetIdx()
        if num_nbs[a1] == max_nb or num_nbs[a2] == max_nb:
            raise Exception(smiles)
        atom_nb[a1, num_nbs[a1]] = a2
        atom_nb[a2, num_nbs[a2]] = a1
        bond_nb[a1, num_nbs[a1]] = idx
        bond_nb[a2, num_nbs[a2]] = idx
        num_nbs[a1] += 1
        num_nbs[a2] += 1
        fbonds[idx] = bond_features(bond)
    return fatoms, fbonds, atom_nb, bond_nb, num_nbs


def pack2D(arr_list):
    """Stack 2D arrays into one 3D array, zero-padded to the largest shape."""
    N = max(x.shape[0] for x in arr_list)
    M = max(x.shape[1] for x in arr_list)
    a = np.zeros((len(arr_list), N, M))
    for i, arr in enumerate(arr_list):
        n, m = arr.shape
        a[i, 0:n, 0:m] = arr
    return a


def pack2D_withidx(arr_list):
    """Like pack2D, but each entry becomes (batch_index, value) pairs."""
    N = max(x.shape[0] for x in arr_list)
    M = max(x.shape[1] for x in arr_list)
    a = np.zeros((len(arr_list), N, M, 2))
    for i, arr in enumerate(arr_list):
        n, m = arr.shape
        a[i, 0:n, 0:m, 0] = i
        a[i, 0:n, 0:m, 1] = arr
    return a


def pack1D(arr_list):
    """Stack 1D arrays into one 2D array, zero-padded on the right."""
    N = max(x.shape[0] for x in arr_list)
    a = np.zeros((len(arr_list), N))
    for i, arr in enumerate(arr_list):
        a[i, 0:arr.shape[0]] = arr
    return a


def get_mask(arr_list):
    """Return a (batch, N) mask: 1.0 where a real entry exists, 0.0 in padding."""
    N = max(x.shape[0] for x in arr_list)
    a = np.zeros((len(arr_list), N))
    for i, arr in enumerate(arr_list):
        # Slice assignment instead of a per-element Python loop.
        a[i, :arr.shape[0]] = 1
    return a


def smiles2graph_list(smiles_list, idxfunc=lambda x: x.GetIdx()):
    """Batch ``smiles2graph`` over a list of SMILES into padded arrays."""
    res = [smiles2graph(x, idxfunc) for x in smiles_list]
    fatom_list, fbond_list, gatom_list, gbond_list, nb_list = zip(*res)
    return (pack2D(fatom_list), pack2D(fbond_list), pack2D_withidx(gatom_list),
            pack2D_withidx(gbond_list), pack1D(nb_list), get_mask(fatom_list))


if __name__ == "__main__":
    import sys
    np.set_printoptions(threshold=sys.maxsize)
    a, b, c, d, e, f = smiles2graph_list(["c1cccnc1", 'c1nccc2n1ccc2'])
    print(a)
    print(b)
    print(c)
    print(d)
    print(e)
    print(f)
import os
import pickle
import sys
import unittest

import numpy as np

import askcos.synthetic.evaluation.rexgen_direct.core_wln_global.mol_graph as mg
import askcos.synthetic.evaluation.rexgen_direct.core_wln_global.mol_graph_rich as mgr


class TestMolGraph(unittest.TestCase):
    """Regression tests for mol_graph against pickled reference arrays."""

    def test_01_smiles2graph_list(self):
        np.set_printoptions(threshold=sys.maxsize)
        result = mg.smiles2graph_list(["c1cccnc1", 'c1nccc2n1ccc2'])
        with open(os.path.join(os.path.dirname(__file__), 'test_data/mg_smiles2graph_list.pkl'), 'rb') as t:
            expected = pickle.load(t, encoding='iso-8859-1')
        # zip() would silently truncate on a length mismatch; check it first.
        self.assertEqual(len(expected), len(result))
        for e, r in zip(expected, result):
            # assert_array_equal reports shape/value diffs, unlike (e == r).all().
            np.testing.assert_array_equal(e, r)


class TestMolGraphRich(unittest.TestCase):
    """Regression tests for mol_graph_rich against pickled reference arrays."""

    def test_01_smiles2graph_list(self):
        np.set_printoptions(threshold=sys.maxsize)
        result = mgr.smiles2graph_list(["c1cccnc1", 'c1nccc2n1ccc2'])
        with open(os.path.join(os.path.dirname(__file__), 'test_data/mgr_smiles2graph_list.pkl'), 'rb') as t:
            expected = pickle.load(t, encoding='iso-8859-1')
        self.assertEqual(len(expected), len(result))
        for e, r in zip(expected, result):
            np.testing.assert_array_equal(e, r)


if __name__ == '__main__':
    res = unittest.main(verbosity=3, exit=False)
class TFFP():
    '''Template-free forward predictor.

    Wraps the two-stage rexgen_direct pipeline: a core finder that scores
    candidate bond changes, and a candidate ranker that enumerates and
    ranks the resulting products.
    '''

    def __init__(self):
        # Load both models up front so repeated predict() calls are cheap.
        self.finder = DirectCoreFinder(batch_size=1)
        self.finder.load_model()
        self.ranker = DirectCandRanker()
        self.ranker.load_model()

    def predict(self, smi, top_n=100, atommap=False):
        """Predict reaction outcomes for the reactant SMILES ``smi``.

        Returns (atom-mapped reactant SMILES, ranked outcome list).
        Raises ValueError when ``smi`` cannot be parsed.
        """
        m = Chem.MolFromSmiles(smi)
        if not m and smi.endswith('.'):
            # Tolerate a single trailing dot (empty last component).
            m = Chem.MolFromSmiles(smi[:-1])
        if not m:
            raise ValueError('Could not parse molecule for TFFP! {}'.format(smi))
        # If atom-map numbers are missing or duplicated, renumber atoms 1..N
        # so the downstream models see a fully, uniquely mapped molecule.
        atom_mappings = [a.GetAtomMapNum() for a in m.GetAtoms()]
        if len(set(atom_mappings)) != len(atom_mappings):
            for i, a in enumerate(m.GetAtoms()):
                a.SetIntProp('molAtomMapNumber', i + 1)
        rsmi_am = Chem.MolToSmiles(m, isomericSmiles=False)
        (react, bond_preds, bond_scores, cur_att_score) = self.finder.predict(rsmi_am)
        outcomes = self.ranker.predict(react, bond_preds, bond_scores,
                                       scores=True, top_n=top_n, atommap=atommap)
        return rsmi_am, outcomes


if __name__ == "__main__":
    tffp = TFFP()
    if len(sys.argv) < 2:
        react = 'CCCO.CCCBr'
    else:
        react = str(sys.argv[1])

    print(react)
    print(tffp.predict(react))
20 | self.assertEqual(len(expected), len(result)) 21 | 22 | for e, r in zip(expected, result): 23 | self.assertEqual(e['smiles'], r['smiles']) 24 | self.assertEqual(e['rank'], r['rank']) 25 | self.assertAlmostEqual(e['prob'], r['prob'], places=4) 26 | self.assertAlmostEqual(e['score'], r['score'], places=4) 27 | 28 | 29 | if __name__ == '__main__': 30 | res = unittest.main(verbosity=3, exit=False) 31 | -------------------------------------------------------------------------------- /askcos/synthetic/evaluation/rexgen_direct/rank_diff_wln/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/evaluation/rexgen_direct/rank_diff_wln/__init__.py -------------------------------------------------------------------------------- /askcos/synthetic/evaluation/rexgen_direct/rank_diff_wln/edit_mol_direct_useScores.py: -------------------------------------------------------------------------------- 1 | import rdkit 2 | from rdkit import Chem 3 | from optparse import OptionParser 4 | 5 | from rdkit import RDLogger 6 | lg = RDLogger.logger() 7 | lg.setLevel(4) 8 | 9 | BOND_TYPE = [0, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC] 10 | BOND_FLOAT_TO_TYPE = { 11 | 0.0: BOND_TYPE[0], 12 | 1.0: BOND_TYPE[1], 13 | 2.0: BOND_TYPE[2], 14 | 3.0: BOND_TYPE[3], 15 | 1.5: BOND_TYPE[4], 16 | } 17 | def copy_edit_mol(mol): 18 | new_mol = Chem.RWMol(Chem.MolFromSmiles('')) 19 | for atom in mol.GetAtoms(): 20 | new_atom = Chem.Atom(atom.GetSymbol()) 21 | new_atom.SetFormalCharge(atom.GetFormalCharge()) # TODO: How to deal with changing formal charge? 
def get_product_smiles(rmol, edits, tatoms):
    """Apply ``edits`` to ``rmol``; retry on the kekulized form if the result is empty."""
    smiles = edit_mol(rmol, edits, tatoms)
    if len(smiles) != 0:
        return smiles
    # Aromatic perception can block the edit; try again after kekulizing.
    try:
        Chem.Kekulize(rmol)
    except Exception:
        return smiles
    return edit_mol(rmol, edits, tatoms)


def edit_mol(rmol, edits, tatoms):
    """Return '.'-joined product SMILES after applying bond edits.

    Each edit is (x, y, new_bond_order, score) with x/y being 0-based
    atom-map numbers; ``tatoms`` is the set of 0-based map numbers of
    interest — fragments sharing no atom with it are dropped.
    """
    new_mol = Chem.RWMol(rmol)
    for a in new_mol.GetAtoms():
        a.SetNumExplicitHs(0)

    # (atom-map number - 1) -> RDKit atom index.
    amap = {atom.GetAtomMapNum() - 1: atom.GetIdx() for atom in rmol.GetAtoms()}

    for x, y, t, v in edits:
        ix, iy = amap[x], amap[y]
        if new_mol.GetBondBetweenAtoms(ix, iy) is not None:
            new_mol.RemoveBond(ix, iy)
        if t > 0:
            new_mol.AddBond(ix, iy, BOND_FLOAT_TO_TYPE[t])

    fragments = Chem.MolToSmiles(new_mol.GetMol()).split('.')
    kept = []
    for frag in fragments:
        mol = Chem.MolFromSmiles(frag)
        if mol is None:
            continue
        frag_atoms = {atom.GetAtomMapNum() - 1 for atom in mol.GetAtoms()}
        if not (frag_atoms & tatoms):
            continue
        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(0)
        kept.append(mol)

    return '.'.join(sorted(Chem.MolToSmiles(mol) for mol in kept))
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from askcos.synthetic.evaluation.rexgen_direct.rank_diff_wln.mol_graph_direct_useScores import max_nb
from askcos.synthetic.evaluation.rexgen_direct.rank_diff_wln.nn import *


def _neighbor_mask(num_nbs):
    """Float mask of shape [batch, atoms, max_nb, 1] zeroing padded neighbor slots.

    Shared helper: this block was previously duplicated verbatim in every
    function below and rebuilt on every loop iteration even though it
    depends only on ``num_nbs`` (it creates no variables, so hoisting it
    leaves the computed graph values unchanged).
    """
    mask = tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32)
    target_shape = tf.concat([tf.shape(num_nbs), [max_nb, 1]], 0)
    mask = tf.reshape(mask, target_shape)
    mask.set_shape([None, None, max_nb, 1])
    return mask


def rcnn_wl_last(graph_inputs, hidden_size, depth, training=True):
    """WL-style message passing; returns (last-layer kernels, sum-pooled fingerprint)."""
    input_atom, input_bond, atom_graph, bond_graph, num_nbs = graph_inputs
    atom_features = tf.nn.relu(linearND(input_atom, hidden_size, "atom_embedding", init_bias=None))
    mask_nei = _neighbor_mask(num_nbs)  # loop-invariant
    layers = []
    for i in range(depth):
        with tf.variable_scope("WL", reuse=(i > 0)):
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            fbond_nei = tf.gather_nd(input_bond, bond_graph)
            h_nei_atom = linearND(fatom_nei, hidden_size, "nei_atom", init_bias=None)
            h_nei_bond = linearND(fbond_nei, hidden_size, "nei_bond", init_bias=None)
            f_nei = tf.reduce_sum(h_nei_atom * h_nei_bond * mask_nei, -2)
            f_self = linearND(atom_features, hidden_size, "self_atom", init_bias=None)
            layers.append(f_nei * f_self)
            # Label update for the next round.
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(linearND(l_nei, hidden_size, "label_U2"))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            atom_features = tf.nn.relu(linearND(new_label, hidden_size, "label_U1"))
    kernels = layers[-1]
    fp = tf.reduce_sum(kernels, 1)
    return kernels, fp


def rcnn_wl_only(graph_inputs, hidden_size, depth, training=True):
    """Run WL label updates only; returns the final per-atom features."""
    input_atom, input_bond, atom_graph, bond_graph, num_nbs = graph_inputs
    atom_features = tf.nn.relu(linearND(input_atom, hidden_size, "atom_embedding", init_bias=None))
    mask_nei = _neighbor_mask(num_nbs)  # loop-invariant
    for i in range(depth):
        with tf.variable_scope("WL", reuse=(i > 0)):
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            fbond_nei = tf.gather_nd(input_bond, bond_graph)
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(linearND(l_nei, hidden_size, "label_U2"))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            atom_features = tf.nn.relu(linearND(new_label, hidden_size, "label_U1"))
    return atom_features


def wl_diff_net(graph_inputs, atom_features, hidden_size, depth):
    """Continue WL updates from given atom features, then sum-pool over atoms."""
    input_atom, input_bond, atom_graph, bond_graph, num_nbs = graph_inputs
    mask_nei = _neighbor_mask(num_nbs)  # loop-invariant
    for i in range(depth):
        with tf.variable_scope("WL", reuse=(i > 0)):
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            fbond_nei = tf.gather_nd(input_bond, bond_graph)
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(linearND(l_nei, hidden_size, "label_U2"))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            atom_features = tf.nn.relu(linearND(new_label, hidden_size, "label_U1"))
    return tf.reduce_sum(atom_features, -2)
import rdkit
import rdkit.Chem as Chem
import tensorflow.compat.v1 as tf
from ..utils.nn import linearND, linear
from .mol_graph import atom_fdim as adim, bond_fdim as bdim, max_nb, smiles2graph_test, bond_types
from .models import *
from .edit_mol import edit_mol


class CandRanker(object):
    """Ranks candidate products of a reaction given predicted core bonds."""

    def __init__(self, hidden_size, depth, MAX_NCAND=2000, TOPK=5):
        self.hidden_size = hidden_size
        self.depth = depth
        self.MAX_NCAND = MAX_NCAND
        self.TOPK = TOPK

    def load_model(self, model_path):
        """Build the TF1 scoring graph and restore weights from ``model_path``."""
        hidden_size = self.hidden_size
        depth = self.depth

        self.graph = tf.Graph()
        with self.graph.as_default():
            input_atom = tf.placeholder(tf.float32, [None, None, adim])
            input_bond = tf.placeholder(tf.float32, [None, None, bdim])
            atom_graph = tf.placeholder(tf.int32, [None, None, max_nb, 2])
            bond_graph = tf.placeholder(tf.int32, [None, None, max_nb, 2])
            num_nbs = tf.placeholder(tf.int32, [None, None])
            self.leaf_nodes = [input_atom, input_bond, atom_graph, bond_graph, num_nbs]

            graph_inputs = (input_atom, input_bond, atom_graph, bond_graph, num_nbs)
            with tf.variable_scope("encoder"):
                _, fp = rcnn_wl_last(graph_inputs, hidden_size=hidden_size, depth=depth)

            # Row 0 is the reactant; remaining rows are candidate products.
            reactant = fp[0:1, :]
            candidates = fp[1:, :]
            candidates = candidates - reactant
            candidates = linear(candidates, hidden_size, "candidate")
            match = tf.nn.relu(candidates)
            self.score = tf.squeeze(linear(match, 1, "score"), [1])
            self.scaled_score = tf.nn.softmax(self.score)
            cur_k = tf.minimum(self.TOPK, tf.shape(self.score)[0])
            _, self.topk_cands = tf.nn.top_k(self.score, cur_k)

            self.session = tf.Session()
            saver = tf.train.Saver()
            saver.restore(self.session, tf.train.latest_checkpoint(model_path))

    def predict_one(self, r, core, scores=True, top_n=100):
        """Rank candidate products for reactant SMILES ``r`` given 1-based core bonds.

        Returns a list of outcome dicts ({'rank', 'smiles'} plus
        {'score', 'prob'} when ``scores`` is True).
        """
        core = [(x - 1, y - 1) for x, y in core]
        ncore = len(core)
        # Shrink the core until candidate enumeration stays tractable.
        while True:
            src_tuple, core_conf = smiles2graph_test(r, core[:ncore])
            if len(core_conf) <= self.MAX_NCAND:
                break
            ncore -= 1
        feed_map = {x: y for x, y in zip(self.leaf_nodes, src_tuple)}
        if scores:
            (cur_scores, cur_probs, candidates) = self.session.run(
                [self.score, self.scaled_score, self.topk_cands], feed_dict=feed_map)
        else:
            candidates = self.session.run(self.topk_cands, feed_dict=feed_map)

        rmol = Chem.MolFromSmiles(r)
        rbonds = {}
        for bond in rmol.GetBonds():
            a1 = bond.GetBeginAtom().GetAtomMapNum()
            a2 = bond.GetEndAtom().GetAtomMapNum()
            t = bond_types.index(bond.GetBondType()) + 1
            a1, a2 = min(a1, a2), max(a1, a2)
            rbonds[(a1, a2)] = t

        cand_smiles = []
        cand_scores = []
        cand_probs = []
        for idx in candidates:
            edits = []
            for x, y, t in core_conf[idx]:
                x, y = x + 1, y + 1
                # Keep only changes relative to the reactant's bonds.
                if ((x, y) not in rbonds and t > 0) or ((x, y) in rbonds and rbonds[(x, y)] != t):
                    edits.append((x, y, t))
            cand = edit_mol(rmol, edits)
            cand_smiles.append(cand)
            # BUG FIX: cur_scores/cur_probs exist only when scores=True;
            # previously these appends raised NameError for scores=False.
            if scores:
                cand_scores.append(cur_scores[idx])
                cand_probs.append(cur_probs[idx])

        outcomes = []
        if scores:
            for i in range(min(len(cand_smiles), top_n)):
                outcomes.append({
                    'rank': i + 1,
                    'smiles': cand_smiles[i],
                    'score': cand_scores[i],
                    'prob': cand_probs[i],
                })
        else:
            for i in range(min(len(cand_smiles), top_n)):
                outcomes.append({
                    'rank': i + 1,
                    'smiles': cand_smiles[i],
                })

        return outcomes
import rdkit
from rdkit import Chem
from optparse import OptionParser

from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(4)

# Index in this list == bond order used in edit tuples (0 = no bond).
BOND_TYPE = [0, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]


def edit_mol(rmol, edits):
    """Apply bond edits to an atom-mapped molecule; return unmapped product SMILES.

    ``edits``: iterable of (x, y, t) where x/y are 1-based atom-map numbers
    and t indexes BOND_TYPE (0 removes the bond). Explicit hydrogen counts
    are adjusted to compensate for bond-order changes.
    """
    new_mol = Chem.RWMol(rmol)

    # amap: atom-map number -> atom index; numH: explicit-H budget per atom.
    amap = {}
    numH = {}
    for atom in rmol.GetAtoms():
        amap[atom.GetIntProp('molAtomMapNumber')] = atom.GetIdx()
        numH[atom.GetIntProp('molAtomMapNumber')] = atom.GetNumExplicitHs()

    for x, y, t in edits:
        bond = new_mol.GetBondBetweenAtoms(amap[x], amap[y])
        if bond is not None:
            # Removing a bond frees valence: return the hydrogens.
            val = BOND_TYPE.index(bond.GetBondType())
            new_mol.RemoveBond(amap[x], amap[y])
            numH[x] += val
            numH[y] += val

        if t > 0:
            new_mol.AddBond(amap[x], amap[y], BOND_TYPE[t])
            numH[x] -= t
            numH[y] -= t

    for atom in new_mol.GetAtoms():
        val = numH[atom.GetIntProp('molAtomMapNumber')]
        if val >= 0:
            atom.SetNumExplicitHs(val)

    pred_mol = new_mol.GetMol()
    for atom in pred_mol.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    # NOTE: removed unused locals (n_atoms, a1, a2) from the original.
    return Chem.MolToSmiles(pred_mol)
fgold.readline() 62 | rex,_ = gold.split() 63 | r,_,p = rex.split('>') 64 | rmol = Chem.MolFromSmiles(r) 65 | pmol = Chem.MolFromSmiles(p) 66 | 67 | patoms = set() 68 | pbonds = {} 69 | for bond in pmol.GetBonds(): 70 | a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') 71 | a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') 72 | t = BOND_TYPE.index(bond.GetBondType()) 73 | a1,a2 = min(a1,a2),max(a1,a2) 74 | pbonds[(a1,a2)] = t 75 | patoms.add(a1) 76 | patoms.add(a2) 77 | 78 | rbonds = {} 79 | for bond in rmol.GetBonds(): 80 | a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') 81 | a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') 82 | t = BOND_TYPE.index(bond.GetBondType()) 83 | a1,a2 = min(a1,a2),max(a1,a2) 84 | if a1 in patoms or a2 in patoms: 85 | rbonds[(a1,a2)] = t 86 | 87 | rk = 10 88 | for idx,edits in enumerate(line.split('|')): 89 | cbonds = [] 90 | pred = dict(rbonds) 91 | for edit in edits.split(): 92 | x,y,t = edit.split('-') 93 | x,y,t = int(x),int(y),int(t) 94 | cbonds.append((x,y,t)) 95 | if t == 0 and (x,y) in rbonds: 96 | del pred[(x,y)] 97 | if t > 0: 98 | pred[(x,y)] = t 99 | 100 | if pred == pbonds: 101 | rk = idx + 1 102 | break 103 | for atom in pmol.GetAtoms(): 104 | atom.ClearProp('molAtomMapNumber') 105 | psmiles = Chem.MolToSmiles(pmol) 106 | psmiles = set(psmiles.split('.')) 107 | pred_smiles = set(edit_mol(r, cbonds).split('.')) 108 | if psmiles <= pred_smiles: 109 | rk = idx + 1 110 | break 111 | rank.append(rk) 112 | 113 | n = 1.0 * len(rank) 114 | top1,top3,top5 = 0,0,0 115 | for idx in rank: 116 | if idx == 1: top1 += 1 117 | if idx <= 3: top3 += 1 118 | if idx <= 5: top5 += 1 119 | 120 | print('%.4f, %.4f, %.4f' % (top1 / n, top3 / n, top5 / n)) 121 | -------------------------------------------------------------------------------- /askcos/synthetic/evaluation/rexgen_release/CandRanker/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 
import tensorflow.compat.v1 as tf
from .mol_graph import max_nb
from ..utils.nn import *


def rcnn_wl_last(graph_inputs, hidden_size, depth, training=True):
    """WL-style message passing; returns (last-layer kernels, sum-pooled fingerprint).

    Note: an obsolete commented-out variant of this function (kept in a
    module-level string) was removed as dead code.
    """
    input_atom, input_bond, atom_graph, bond_graph, num_nbs = graph_inputs
    atom_features = tf.nn.relu(linearND(input_atom, hidden_size, "atom_embedding", init_bias=None))
    # Neighbor mask is loop-invariant (depends only on num_nbs): build it once.
    mask_nei = tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32)
    target_shape = tf.concat([tf.shape(num_nbs), [max_nb, 1]], 0)
    mask_nei = tf.reshape(mask_nei, target_shape)
    mask_nei.set_shape([None, None, max_nb, 1])
    layers = []
    for i in range(depth):
        with tf.variable_scope("WL", reuse=(i > 0)):
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            fbond_nei = tf.gather_nd(input_bond, bond_graph)
            h_nei_atom = linearND(fatom_nei, hidden_size, "nei_atom", init_bias=None)
            h_nei_bond = linearND(fbond_nei, hidden_size, "nei_bond", init_bias=None)
            f_nei = tf.reduce_sum(h_nei_atom * h_nei_bond * mask_nei, -2)
            f_self = linearND(atom_features, hidden_size, "self_atom", init_bias=None)
            layers.append(f_nei * f_self)
            # Label update feeding the next round of message passing.
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(linearND(l_nei, hidden_size, "label_U2"))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            atom_features = tf.nn.relu(linearND(new_label, hidden_size, "label_U1"))
    kernels = layers[-1]
    fp = tf.reduce_sum(kernels, 1)
    return kernels, fp
class CoreFinder(object):
    """Predicts likely reaction-center atom pairs for batches of reactant SMILES."""

    def __init__(self, hidden_size, depth, batch_size=20):
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.depth = depth

    def load_model(self, model_path):
        """Build the pair-scoring graph and restore weights from ``model_path``."""
        hidden_size = self.hidden_size
        batch_size = self.batch_size
        depth = self.depth

        self.graph = tf.Graph()
        with self.graph.as_default():
            # Placeholders for the padded molecular-graph batch.
            input_atom = tf.placeholder(tf.float32, [batch_size, None, adim])
            input_bond = tf.placeholder(tf.float32, [batch_size, None, bdim])
            atom_graph = tf.placeholder(tf.int32, [batch_size, None, max_nb, 2])
            bond_graph = tf.placeholder(tf.int32, [batch_size, None, max_nb, 2])
            num_nbs = tf.placeholder(tf.int32, [batch_size, None])
            node_mask = tf.placeholder(tf.float32, [batch_size, None])
            binary = tf.placeholder(tf.float32, [batch_size, None, None, binary_fdim])
            validity = tf.placeholder(tf.float32, [batch_size, None, None])
            core_size = tf.placeholder(tf.int32)

            self.leaf_nodes = [input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask, binary, validity, core_size]

            node_mask = tf.expand_dims(node_mask, -1)
            graph_inputs = (input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask)
            with tf.variable_scope("encoder"):
                atom_hiddens, _ = rcnn_wl_last(graph_inputs, batch_size=batch_size, hidden_size=hidden_size, depth=depth)

            # Pairwise atom representation via broadcast sum of embeddings.
            atom_hiddens1 = tf.reshape(atom_hiddens, [batch_size, 1, -1, hidden_size])
            atom_hiddens2 = tf.reshape(atom_hiddens, [batch_size, -1, 1, hidden_size])
            atom_pair = atom_hiddens1 + atom_hiddens2

            # Attention over atom pairs, conditioned on binary pair features.
            att_hidden = tf.nn.relu(linearND(atom_pair, hidden_size, scope="att_atom_feature", init_bias=None) + linearND(binary, hidden_size, scope="att_bin_feature"))
            att_score = linearND(att_hidden, 1, scope="att_scores")
            att_score = tf.nn.sigmoid(att_score)
            att_context = att_score * atom_hiddens1
            att_context = tf.reduce_sum(att_context, 2)

            att_context1 = tf.reshape(att_context, [batch_size, 1, -1, hidden_size])
            att_context2 = tf.reshape(att_context, [batch_size, -1, 1, hidden_size])
            att_pair = att_context1 + att_context2

            pair_hidden = linearND(atom_pair, hidden_size, scope="atom_feature", init_bias=None) + linearND(binary, hidden_size, scope="bin_feature", init_bias=None) + linearND(att_pair, hidden_size, scope="ctx_feature")
            pair_hidden = tf.nn.relu(pair_hidden)
            # The +10000 offset pushes valid pairs ahead of invalid ones in top_k.
            score = tf.squeeze(linearND(pair_hidden, 1, scope="scores"), [3]) + validity * 10000

            score = tf.reshape(score, [batch_size, -1])
            _, self.topk = tf.nn.top_k(score, core_size)

            self.session = tf.Session()
            saver = tf.train.Saver()
            saver.restore(self.session, tf.train.latest_checkpoint(model_path))

    def predict(self, reactants, num_core):
        """Return, per reactant SMILES, the top (x, y) atom pairs (1-indexed)."""
        reaction_cores = []
        batch_size = self.batch_size
        # Request twice as many since only pairs with x < y are kept below.
        num_core *= 2

        for start in range(0, len(reactants), batch_size):
            batch = reactants[start:start + batch_size]
            graph_tuple = smiles2graph_batch(batch)
            cur_bin, cur_validity = get_all_batch(batch)
            leaf_values = graph_tuple + (cur_bin, cur_validity, num_core)
            feed_map = {node: value for node, value in zip(self.leaf_nodes, leaf_values)}
            cur_topk = self.session.run(self.topk, feed_dict=feed_map)
            cur_dim = cur_validity.shape[1]

            for i in range(batch_size):
                pairs = []
                for j in range(num_core):
                    # Recover 2D pair coordinates from the flattened index.
                    x, y = divmod(cur_topk[i, j], cur_dim)
                    if x < y and cur_validity[i, x, y] == 1:
                        pairs.append((x + 1, y + 1))
                reaction_cores.append(pairs)

        return reaction_cores
def get_bin_feature(r, max_natoms):
    """Build the (max_natoms, max_natoms, binary_fdim) pairwise feature tensor
    for atom-mapped reactant SMILES ``r``.

    Per pair (i, j): slot 0 = "no bond", slots 1..bond_fdim = bond features if
    bonded, last four slots encode same/different component membership and
    whether the reactant set has one or several components. Atom indices are
    taken from molAtomMapNumber - 1, so ``r`` must be fully atom-mapped.
    """
    # Map each atom (by map number - 1) to the index of its '.'-separated component.
    comp = {}
    for i, s in enumerate(r.split('.')):
        mol = Chem.MolFromSmiles(s)
        for atom in mol.GetAtoms():
            comp[atom.GetIntProp('molAtomMapNumber') - 1] = i
    n_comp = len(r.split('.'))
    rmol = Chem.MolFromSmiles(r)
    n_atoms = rmol.GetNumAtoms()
    # Symmetric lookup of bonds keyed by (map_i, map_j) in both orders.
    bond_map = {}
    for bond in rmol.GetBonds():
        a1 = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
        a2 = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
        bond_map[(a1,a2)] = bond_map[(a2,a1)] = bond

    features = []
    for i in range(max_natoms):
        for j in range(max_natoms):
            f = np.zeros((binary_fdim,))
            # Padding rows/cols and the diagonal stay all-zero.
            if i >= n_atoms or j >= n_atoms or i == j:
                features.append(f)
                continue
            if (i,j) in bond_map:
                bond = bond_map[(i,j)]
                f[1:1+bond_fdim] = bond_features(bond)
            else:
                f[0] = 1.0
            f[-4] = 1.0 if comp[i] != comp[j] else 0.0
            f[-3] = 1.0 if comp[i] == comp[j] else 0.0
            f[-2] = 1.0 if n_comp == 1 else 0.0
            f[-1] = 1.0 if n_comp > 1 else 0.0
            features.append(f)
    return np.vstack(features).reshape((max_natoms,max_natoms,binary_fdim))

def get_all_batch(smiles_list):
    """Return (features, validity) for a batch of atom-mapped reactant SMILES.

    ``features``: (batch, max_natoms, max_natoms, binary_fdim) from
    get_bin_feature, padded to the largest molecule in the batch.
    ``validity``: (batch, max_natoms, max_natoms) with 1 for real off-diagonal
    atom pairs and 0 for padding and the diagonal.
    """
    max_natoms = 0
    mol_list = []
    for r in smiles_list:
        rmol = Chem.MolFromSmiles(r)
        mol_list.append(rmol)
        if rmol.GetNumAtoms() > max_natoms:
            max_natoms = rmol.GetNumAtoms()

    validity = np.ones( (len(mol_list), max_natoms, max_natoms) )
    eye = np.eye(max_natoms)
    for i,mol in enumerate(mol_list):
        n = mol.GetNumAtoms()
        # Zero out padding beyond this molecule's atom count, then the diagonal.
        validity[i, :, n:] = 0
        validity[i, n:, :] = 0
        validity[i] -= eye

    features = []
    for r in smiles_list:
        features.append(get_bin_feature(r, max_natoms))

    return np.array(features), validity
def gated_convnet(graph_inputs, batch_size=64, hidden_size=300, depth=3, res_block=2):
    """Gated graph convolution encoder over a padded molecular-graph batch.

    Args:
        graph_inputs: tuple (input_atom, input_bond, atom_graph, bond_graph,
            num_nbs, node_mask) as produced by smiles2graph_batch (node_mask
            already expanded to rank 3 by the caller).
        batch_size, hidden_size, depth: graph dimensions / model size.
        res_block: add a residual connection every ``res_block`` layers
            (``None`` disables residuals).

    Returns:
        (masked atom features of shape (batch, atoms, hidden), molecule
        fingerprint of shape (batch, hidden)).
    """
    input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
    layers = [input_atom]
    atom_features = input_atom
    for i in range(depth):
        # Gather each atom's neighbor atom/bond features: (batch, atoms, max_nb, fdim).
        fatom_nei = tf.gather_nd(atom_features, atom_graph)
        fbond_nei = tf.gather_nd(input_bond, bond_graph)
        f_nei = tf.concat([fatom_nei, fbond_nei], 3)
        # Gated transform of the neighborhood.
        h_nei = linearND(f_nei, hidden_size, "nei_hidden_%d" % i)
        g_nei = tf.nn.sigmoid(linearND(f_nei, hidden_size, "nei_gate_%d" % i))
        f_nei = h_nei * g_nei
        # Mask out padded neighbor slots before summing over neighbors.
        mask_nei = tf.reshape(tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32), [batch_size, -1, max_nb, 1])
        f_nei = tf.reduce_sum(f_nei * mask_nei, -2)
        # Gated self transform.
        h_self = linearND(atom_features, hidden_size, "self_hidden_%d" % i)
        g_self = tf.nn.sigmoid(linearND(atom_features, hidden_size, "self_gate_%d" % i))
        f_self = h_self * g_self
        atom_features = (f_nei + f_self) * node_mask
        # Periodic residual connection to the layer two steps back.
        if res_block is not None and i % res_block == 0 and i > 0:
            atom_features = atom_features + layers[-2]
        layers.append(atom_features)
    output_gate = tf.nn.sigmoid(linearND(atom_features, hidden_size, "out_gate"))
    output = node_mask * (output_gate * atom_features)
    fp = tf.reduce_sum(output, 1)
    return atom_features * node_mask, fp


def rcnn_wl_last(graph_inputs, batch_size, hidden_size, depth, training=True):
    """Weisfeiler-Lehman network encoder; returns the last layer's kernels.

    Args:
        graph_inputs: tuple (input_atom, input_bond, atom_graph, bond_graph,
            num_nbs, node_mask), node_mask already rank 3.
        batch_size, hidden_size, depth: graph dimensions / model size.
        training: unused; kept for interface compatibility with callers.

    Returns:
        (kernels of shape (batch, atoms, hidden), fingerprint (batch, hidden)).
    """
    input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
    atom_features = tf.nn.relu(linearND(input_atom, hidden_size, "atom_embedding", init_bias=None))
    # Bond features and the neighbor mask do not change across WL iterations,
    # so compute them once outside the loop.
    fbond_nei = tf.gather_nd(input_bond, bond_graph)
    mask_nei = tf.reshape(tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32), [batch_size, -1, max_nb, 1])
    layers = []
    for i in range(depth):
        # reuse=(i>0): all WL iterations share one set of weights.
        with tf.variable_scope("WL", reuse=(i > 0)) as scope:
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            h_nei_atom = linearND(fatom_nei, hidden_size, "nei_atom", init_bias=None)
            h_nei_bond = linearND(fbond_nei, hidden_size, "nei_bond", init_bias=None)
            h_nei = h_nei_atom * h_nei_bond
            f_nei = tf.reduce_sum(h_nei * mask_nei, -2)
            f_self = linearND(atom_features, hidden_size, "self_atom", init_bias=None)
            layers.append(f_nei * f_self * node_mask)
            # WL relabeling: combine self label with aggregated neighbor labels.
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(linearND(l_nei, hidden_size, "label_U2"))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            new_label = linearND(new_label, hidden_size, "label_U1")
            atom_features = tf.nn.relu(new_label)
    # Only the final iteration's kernels are used.
    kernels = layers[-1]
    fp = tf.reduce_sum(kernels, 1)
    return kernels, fp
def onek_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Values not in the set fall back to the set's last entry (conventionally
    the 'unknown' slot), so exactly one position is True.
    """
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
def pack2D(arr_list):
    """Zero-pad a list of 2-D arrays into one (len, N, M) float array,
    where N and M are the largest dims across the list."""
    N = max(x.shape[0] for x in arr_list)
    M = max(x.shape[1] for x in arr_list)
    a = np.zeros((len(arr_list), N, M))
    for i, arr in enumerate(arr_list):
        n, m = arr.shape
        a[i, :n, :m] = arr
    return a

def pack2D_withidx(arr_list):
    """Like pack2D, but each cell becomes a (batch_index, value) pair so the
    result can be used directly with tf.gather_nd across the batch axis."""
    N = max(x.shape[0] for x in arr_list)
    M = max(x.shape[1] for x in arr_list)
    a = np.zeros((len(arr_list), N, M, 2))
    for i, arr in enumerate(arr_list):
        n, m = arr.shape
        a[i, :n, :m, 0] = i
        a[i, :n, :m, 1] = arr
    return a

def pack1D(arr_list):
    """Zero-pad a list of 1-D arrays into one (len, N) float array."""
    N = max(x.shape[0] for x in arr_list)
    a = np.zeros((len(arr_list), N))
    for i, arr in enumerate(arr_list):
        a[i, :arr.shape[0]] = arr
    return a

def get_mask(arr_list):
    """Return a (len, N) 0/1 mask marking the real (non-padded) rows of each
    array in arr_list. Vectorized slice assignment instead of a per-element
    Python loop."""
    N = max(x.shape[0] for x in arr_list)
    a = np.zeros((len(arr_list), N))
    for i, arr in enumerate(arr_list):
        a[i, :arr.shape[0]] = 1
    return a

def smiles2graph_batch(smiles_list, idxfunc=lambda x: x.GetIdx()):
    """Convert a list of SMILES into a padded batch of graph tensors:
    (atom features, bond features, atom adjacency with batch idx, bond
    adjacency with batch idx, neighbor counts, node mask)."""
    res = map(lambda x: smiles2graph(x, idxfunc), smiles_list)
    fatom_list, fbond_list, gatom_list, gbond_list, nb_list = zip(*res)
    return pack2D(fatom_list), pack2D(fbond_list), pack2D_withidx(gatom_list), pack2D_withidx(gbond_list), pack1D(nb_list), get_mask(fatom_list)
class TFFP():
    '''Template-free forward predictor.

    Wraps a CoreFinder (reaction-center prediction) and a CandRanker
    (candidate product ranking), restoring both models from checkpoints
    bundled next to this file.
    '''
    def __init__(self):
        froot = os.path.dirname(__file__)
        # batch_size=1: predict() always feeds a single molecule.
        self.finder = CoreFinder(hidden_size=300, depth=3, batch_size=1)
        print(os.path.join(froot, 'CoreFinder', 'uspto-300-3'))
        self.finder.load_model(os.path.join(froot, 'CoreFinder', 'uspto-300-3'))
        self.ranker = CandRanker(hidden_size=320, depth=3, TOPK=100)
        self.ranker.load_model(os.path.join(froot, 'CandRanker', 'uspto-320-3'))

    def predict(self, smi, top_n=100, num_core=8):
        """Predict ranked product candidates for reactant SMILES ``smi``.

        Args:
            smi: reactant SMILES (atom mapping is (re)assigned here).
            top_n: number of ranked outcomes to return.
            num_core: number of reaction-center pairs to consider.

        Returns:
            The ranked outcome list from CandRanker.predict_one.

        Raises:
            ValueError: if ``smi`` cannot be parsed (including empty input;
                the previous ``smi[-1]`` check raised IndexError on '').
        """
        m = Chem.MolFromSmiles(smi)
        if m is None and smi.endswith('.'):
            # Tolerate a stray trailing separator by retrying without it.
            m = Chem.MolFromSmiles(smi[:-1])
        if m is None:
            raise ValueError('Could not parse molecule for TFFP! {}'.format(smi))
        # Assign 1-based atom map numbers; downstream models key on them.
        for i, a in enumerate(m.GetAtoms()):
            a.SetIntProp('molAtomMapNumber', i + 1)
        s = Chem.MolToSmiles(m)
        rcores = self.finder.predict([s], num_core=num_core)[0]
        return self.ranker.predict_one(s, rcores, scores=True, top_n=top_n)
-------------------------------------------------------------------------------- /askcos/synthetic/evaluation/test_data/evaluator.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/evaluation/test_data/evaluator.pkl -------------------------------------------------------------------------------- /askcos/synthetic/evaluation/test_data/template_free.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/evaluation/test_data/template_free.pkl -------------------------------------------------------------------------------- /askcos/synthetic/impurity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/impurity/__init__.py -------------------------------------------------------------------------------- /askcos/synthetic/selectivity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/selectivity/__init__.py -------------------------------------------------------------------------------- /askcos/synthetic/selectivity/electronegs.py: -------------------------------------------------------------------------------- 1 | electronegs = { 2 | 1 : 2.2 , 3 | 2 : 0 , 4 | 3 : 0.98 , 5 | 4 : 1.57 , 6 | 5 : 2.04 , 7 | 6 : 2.55 , 8 | 7 : 3.04 , 9 | 8 : 3.44 , 10 | 9 : 3.98 , 11 | 10 : 0 , 12 | 11 : 0.93 , 13 | 12 : 1.31 , 14 | 13 : 1.61 , 15 | 14 : 1.9 , 16 | 15 : 2.19 , 17 | 16 : 2.58 , 18 | 17 : 3.16 , 19 | 18 : 0 , 20 | 19 : 0.82 , 21 | 20 : 1 , 22 | 21 : 1.36 , 23 | 22 : 1.54 , 24 | 23 : 1.63 , 25 | 
24 : 1.66 , 26 | 25 : 1.55 , 27 | 26 : 1.83 , 28 | 27 : 1.88 , 29 | 28 : 1.91 , 30 | 29 : 1.9 , 31 | 30 : 1.65 , 32 | 31 : 1.81 , 33 | 32 : 2.01 , 34 | 33 : 2.18 , 35 | 34 : 2.55 , 36 | 35 : 2.96 , 37 | 36 : 3 , 38 | 37 : 0.82 , 39 | 38 : 0.95 , 40 | 39 : 1.22 , 41 | 40 : 1.33 , 42 | 41 : 1.6 , 43 | 42 : 2.16 , 44 | 43 : 1.9 , 45 | 44 : 2.2 , 46 | 45 : 2.28 , 47 | 46 : 2.2 , 48 | 47 : 1.93 , 49 | 48 : 1.69 , 50 | 49 : 1.78 , 51 | 50 : 1.96 , 52 | 51 : 2.05 , 53 | 52 : 2.1 , 54 | 53 : 2.66 , 55 | 54 : 2.6 , 56 | 55 : 0.79 , 57 | 56 : 0.89 , 58 | 57 : 1.1 , 59 | 58 : 1.12 , 60 | 59 : 1.13 , 61 | 60 : 1.14 , 62 | 61 : 0 , 63 | 62 : 1.17 , 64 | 63 : 0 , 65 | 64 : 1.2 , 66 | 65 : 0 , 67 | 66 : 1.22 , 68 | 67 : 1.23 , 69 | 68 : 1.24 , 70 | 69 : 1.25 , 71 | 70 : 0 , 72 | 71 : 1.27 , 73 | 72 : 1.3 , 74 | 73 : 1.5 , 75 | 74 : 2.36 , 76 | 75 : 1.9 , 77 | 76 : 2.2 , 78 | 77 : 2.2 , 79 | 78 : 2.28 , 80 | 79 : 2.54 , 81 | 80 : 2 , 82 | 81 : 1.62 , 83 | 82 : 2.33 , 84 | 83 : 2.02 , 85 | 84 : 2 , 86 | 85 : 2.2 , 87 | 86 : 0 , 88 | 87 : 0 , 89 | 88 : 0.9 , 90 | 89 : 1.1 , 91 | 90 : 1.3 , 92 | 91 : 1.5 , 93 | 92 : 1.38 , 94 | 93 : 1.36 , 95 | 94 : 1.28 , 96 | 95 : 1.3 , 97 | 96 : 1.3 , 98 | 97 : 1.3 , 99 | 98 : 1.3 , 100 | 99 : 1.3 , 101 | 100 : 1.3 , 102 | 101 : 1.3 , 103 | 102 : 1.3 , 104 | } -------------------------------------------------------------------------------- /askcos/synthetic/selectivity/general_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/synthetic/selectivity/general_model/__init__.py -------------------------------------------------------------------------------- /askcos/synthetic/selectivity/general_model/data_loading.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from askcos.synthetic.selectivity.mol_graph import smiles2graph_pr, pack1D, 
def gnn_data_generation(smiles, products):
    """Build padded GNN input tensors for one reactant set and its candidate
    products.

    One graph pair is generated per '.'-separated product; the reactant graph
    is repeated alongside each. Returns the concatenation of the 8 reactant
    tensors and the 7 product tensors (atom/bond features, indexed adjacency,
    neighbor counts, masks, core masks, plus reactant binary features).
    """
    size = len(products.split('.'))
    # core_buffer=2: include atoms within 2 bonds of the reacting core
    # in the core mask — presumably; confirm against smiles2graph_pr.
    prs_extend = [smiles2graph_pr(p, smiles, idxfunc=lambda x: x.GetIdx(), core_buffer=2)
                  for p in products.split('.')]
    smiles_extend = [smiles] * size

    res_extend, prods_extend = zip(*prs_extend)
    # graph_inputs for reactants
    fatom_list, fbond_list, gatom_list, gbond_list, nb_list, core_mask = zip(*res_extend)
    res_graph_inputs = (pack2D(fatom_list), pack2D(fbond_list), pack2D_withidx(gatom_list),
                        pack2D_withidx(gbond_list), pack1D(nb_list), get_mask(fatom_list),
                        pack1D(core_mask), binary_features_batch(smiles_extend))

    # graph_inputs for products
    fatom_list, fbond_list, gatom_list, gbond_list, nb_list, core_mask = zip(*prods_extend)
    prods_graph_inputs = (pack2D(fatom_list), pack2D(fbond_list), pack2D_withidx(gatom_list),
                          pack2D_withidx(gbond_list), pack1D(nb_list), get_mask(fatom_list),
                          pack1D(core_mask))

    return res_graph_inputs + prods_graph_inputs



def qm_gnn_data_generation(smiles, products, reagents, qm_descriptors):
    """Build padded QM-GNN input tensors for one reaction: reactant graphs
    augmented with QM atom descriptors, reacting-core indices, connectivity
    features, and a separate reagent graph.

    ``qm_descriptors`` is a lookup consumed by smiles2graph_pr_qm; its exact
    schema is defined there.
    """
    prs_extend = [smiles2graph_pr_qm(smiles, p, reagents, qm_descriptors) for p in products.split('.')]

    fatom_list, fatom_qm_list, fbond_list, gatom_list, gbond_list, nb_list, cores, connect, \
        rg_fatom_list, rg_fbond_list, rg_gatom_list, rg_gbond_list, rg_nb_list = zip(*prs_extend)

    res_graph_inputs = (pack2D(fatom_list), pack2D(fbond_list), pack2D_withidx(gatom_list),
                        pack2D_withidx(gbond_list), pack1D(nb_list), get_mask(fatom_list),
                        # Identical copies are stacked per candidate so every model
                        # input shares the same leading dimension ("trick tensorflow").
                        np.stack([pack2D_cores(cores)] * len(prs_extend)),
                        pack2D(fatom_qm_list),
                        np.stack([np.concatenate(connect, axis=0)] * len(prs_extend)),
                        pack2D(rg_fatom_list), pack2D(rg_fbond_list), pack2D_withidx(rg_gatom_list),
                        pack2D_withidx(rg_gbond_list), pack1D(rg_nb_list), get_mask(rg_fatom_list),
                        )

    return res_graph_inputs
layers.Dense(self.hidden_size, kernel_initializer=tf.random_normal_initializer(stddev=0.1), use_bias=False) 32 | self.self_atom = layers.Dense(self.hidden_size, kernel_initializer=tf.random_normal_initializer(stddev=0.1), use_bias=False) 33 | self.label_U2 = layers.Dense(self.hidden_size, activation=K.relu, kernel_initializer=tf.random_normal_initializer(stddev=0.1)) 34 | self.label_U1 = layers.Dense(self.hidden_size, activation=K.relu, kernel_initializer=tf.random_normal_initializer(stddev=0.1)) 35 | self.node_reshape = layers.Reshape((-1,1)) 36 | super(WLN_Layer, self).build(input_shape) 37 | 38 | def call(self, graph_inputs): 39 | input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask, __ = graph_inputs 40 | #calculate the initial atom features using only its own features (no neighbors) 41 | atom_features = self.atom_features(input_atom) 42 | layers = [] 43 | for i in range(self.depth): 44 | fatom_nei = tf.gather_nd(atom_features, tf.dtypes.cast(atom_graph,tf.int64)) #(batch, #atoms, max_nb, hidden) 45 | fbond_nei = tf.gather_nd(input_bond, tf.dtypes.cast(bond_graph, tf.int64)) #(batch, #atoms, max_nb, #bond features) 46 | h_nei_atom = self.nei_atom(fatom_nei) #(batch, #atoms, max_nb, hidden) 47 | h_nei_bond = self.nei_bond(fbond_nei) #(batch, #atoms, max_nb, hidden) 48 | h_nei = h_nei_atom * h_nei_bond #(batch, #atoms, max_nb, hidden) 49 | mask_nei = K.reshape(tf.sequence_mask(K.reshape(num_nbs, [-1]), self.max_nb, dtype=tf.float32), [K.shape(input_atom)[0],-1, self.max_nb,1]) 50 | f_nei = K.sum(h_nei * mask_nei, axis=-2, keepdims=False) #(batch, #atoms, hidden) sum across atoms 51 | f_self = self.self_atom(atom_features) #(batch, #atoms, hidden) 52 | layers.append(f_nei * f_self * self.node_reshape(node_mask))#, -1)) 53 | l_nei = K.concatenate([fatom_nei, fbond_nei], axis=3) #(batch, #atoms, max_nb, ) 54 | pre_label = self.label_U2(l_nei) 55 | nei_label = K.sum(pre_label * mask_nei, axis=-2, keepdims=False) 56 | new_label = 
def wln_loss(y_true, y_pred):
    """Segment-wise softmax cross-entropy over per-reaction candidate scores.

    ``y_true`` holds 0/1 labels with one positive per reaction; the running
    cumulative sum of the labels is used to derive a segment id for every
    candidate, so a softmax is computed within each reaction's candidate set.
    NOTE(review): this segmentation is only correct if the positive label is
    the first candidate of each reaction — verify against the data pipeline.

    Returns the summed cross-entropy, divided by the number of candidates
    when that count is statically known.
    """
    flat_label = K.cast(K.reshape(y_true, [-1]), 'float32')
    flat_score = K.reshape(y_pred, [-1])

    # Segment id per candidate: cumsum of one-hot labels, shifted to start at 0.
    reaction_seg = K.cast(tf.math.cumsum(flat_label), 'int32') - tf.constant([1], dtype='int32')

    # Numerically stable softmax within each segment.
    max_seg = tf.gather(tf.math.segment_max(flat_score, reaction_seg), reaction_seg)
    exp_score = tf.exp(flat_score - max_seg)

    softmax_denominator = tf.gather(tf.math.segment_sum(exp_score, reaction_seg), reaction_seg)
    softmax_score = exp_score / softmax_denominator

    softmax_score = tf.clip_by_value(softmax_score, K.epsilon(), 1 - K.epsilon())

    loss = -tf.reduce_sum(flat_label * tf.math.log(softmax_score))
    # In graph mode the flat length can be statically unknown (None); the
    # original code handled that with a bare `except:` around the division,
    # which also swallowed unrelated errors. Check the static shape instead.
    size = flat_score.shape[0]
    if size is None:
        return loss
    return loss / size
class QMWLNPairwiseAtomClassifier(tf.keras.Model):
    '''
    WLN-based selectivity classifier using QM-augmented atom features.

    Reactant and reagent molecules are encoded by separate ``WLN_Layer``
    instances; a ``Global_Attention`` over the reagent atoms supplies
    context for the reactant atoms.  Precomputed QM atom descriptors
    (width ``qm_size``) are concatenated onto the learned atom embeddings.
    Atom-pair representations for the reacting core are gathered, combined
    with pairwise connectivity features, and reduced to one score per
    candidate reaction via a segment mean.
    '''

    def __init__(self, hidden_size=200, qm_size=160, depth=4, max_nb=10):
        # hidden_size: width of the learned WLN atom embeddings
        # qm_size: width of the precomputed QM atom-feature vectors
        # depth: number of WLN message-passing iterations
        # max_nb: maximum number of neighbors per atom in the graph encoding
        super(QMWLNPairwiseAtomClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.qm_size = qm_size
        # Separate encoders for reactant and reagent graphs
        self.reactants_WLN = WLN_Layer(hidden_size, depth, max_nb)
        self.reagents_WLN = WLN_Layer(hidden_size, depth, max_nb)
        self.attention = Global_Attention(hidden_size)

        # Input width is the concatenation of atom embedding (hidden_size),
        # attention context (hidden_size), and QM features (qm_size), plus
        # 10 connectivity features appended in call(); hence the layer size.
        self.reaction_score0 = layers.Dense(2*hidden_size+qm_size+10, activation=K.relu, kernel_initializer=tf.random_normal_initializer(stddev=0.01), use_bias=False)

        # Final scalar score per reaction
        self.reaction_score = layers.Dense(1, kernel_initializer=tf.random_normal_initializer(stddev=0.01), use_bias=False)

        # Broadcast helpers used to form all-pairs sums of atom embeddings
        self.reshape1 = layers.Reshape((1, -1, 2*hidden_size+qm_size))
        self.reshape2 = layers.Reshape((-1, 1, 2*hidden_size+qm_size))

    def call(self, inputs):
        # Input layout (inferred from the indexing below — confirm against
        # the data pipeline that feeds this model):
        #   inputs[:6]  -> reactant graph tensors for WLN_Layer
        #   inputs[6]   -> reacting-core index pairs (batch dim stripped)
        #   inputs[7]   -> per-atom QM feature matrix
        #   inputs[8]   -> pairwise connectivity features (batch dim stripped)
        #   inputs[9:]  -> reagent graph tensors for WLN_Layer
        res_inputs = inputs[:6]
        res_core_mask = inputs[6][0]
        fatom_qm = inputs[7]
        connect = inputs[8][0]

        rg_inputs = inputs[9:]

        #machine learned representation
        res_atom_hidden = self.reactants_WLN(res_inputs)
        rg_atom_hidden = self.reagents_WLN(rg_inputs)

        # Reagent-aware context vector for each reactant atom
        res_att_context = self.attention(res_atom_hidden, rg_atom_hidden)

        # Augment learned embeddings with attention context and QM features
        res_atom_hidden = K.concatenate([res_atom_hidden, res_att_context, fatom_qm], axis=-1)

        #select out reacting cores
        # All-pairs sum of atom vectors via broadcasting, then gather only
        # the atom pairs belonging to the reacting core.
        atom_hiddens1 = self.reshape1(res_atom_hidden)
        atom_hiddens2 = self.reshape2(res_atom_hidden)
        atom_pair = atom_hiddens1 + atom_hiddens2

        atom_pair = tf.gather_nd(atom_pair, res_core_mask)
        reaction_hidden = K.concatenate([atom_pair, connect], axis=-1)
        reaction_hidden = self.reaction_score0(reaction_hidden)

        # First column of the core mask groups pairs by reaction; average the
        # pair representations within each group before scoring.
        reaction_seg = res_core_mask[:, 0]
        reaction_hidden = tf.math.segment_mean(reaction_hidden, reaction_seg)
        reaction_score = self.reaction_score(reaction_hidden)
        return reaction_score
class GeneralSelectivity(unittest.TestCase):
    """Smoke tests for the GNN-based general selectivity predictors."""

    def _run_and_check(self, predictor, rxn_smiles):
        """Run a single prediction and verify the shared result contract:
        two outcomes, dict entries, and a near-certain top probability."""
        outcomes = predictor.predict(rxn_smiles)
        self.assertEqual(len(outcomes), 2)
        self.assertEqual(type(outcomes[0]), dict)
        self.assertAlmostEqual(outcomes[0]['prob'], 1, 2)

    def test_qm_gnn_predictor_reagents(self):
        """Test qm_gnn predictor"""
        self._run_and_check(
            QmGnnGeneralSelectivityPredictor(),
            'CC(COc1n[nH]cc1)C.CC(C)(OC(c1c(Cl)nc(Cl)cc1)=O)C>CN(C=O)C.O>CC(OC(c1ccc(n2ccc(OCC(C)C)n2)nc1Cl)=O)(C)C',
        )

    def test_qm_gnn_predictor_no_reagents(self):
        """Test qm_gnn predictor"""
        self._run_and_check(
            QmGnnGeneralSelectivityPredictorNoReagent(),
            'CC(COc1n[nH]cc1)C.CC(C)(OC(c1c(Cl)nc(Cl)cc1)=O)C>>CC(OC(c1ccc(n2ccc(OCC(C)C)n2)nc1Cl)=O)(C)C',
        )

    def test_gnn_predictor(self):
        """Test qm_gnn predictor"""
        self._run_and_check(
            GnnGeneralSelectivityPredictor(),
            'CC(COc1n[nH]cc1)C.CC(C)(OC(c1c(Cl)nc(Cl)cc1)=O)C>>CC(OC(c1ccc(n2ccc(OCC(C)C)n2)nc1Cl)=O)(C)C',
        )
class tf_predictor():
    """
    TensorFlow 1.x (graph-mode) multitask site-selectivity predictor.

    Wraps a WLN graph encoder with a pairwise attention mechanism and one
    sigmoid output head per task; tasks are defined by the pickled
    ``task_dict.pkl`` shipped next to this module.  ``build()`` constructs
    the graph and restores trained weights from ``model_path``; call
    ``web_predictor()`` afterwards to score molecules.
    """

    def __init__(self, depth=5, hidden_size=300, batch_size=1):
        # depth: number of WLN message-passing iterations
        # hidden_size: width of the learned atom embeddings
        # batch_size: fixed batch dimension baked into the placeholders
        self.depth = depth
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        # Mapping of task name <-> output column index
        with open(os.path.join(os.path.dirname(__file__), 'task_dict.pkl'), 'rb') as f:
            self.task_dict = pk.load(f)
        self.task_dict_rev = {v: k for k, v in self.task_dict.items()}
        self.saver = None
        self.num_tasks = len(self.task_dict)
        self.save_path = model_path
        # Graph featurizer bound to the atom-index convention this model expects
        self.smiles2graph_batch = partial(_s2g, idxfunc=lambda x:x.GetIdx(), include_electronegs=True)
        self.adim = adim(include_electronegs=True)  # atom feature dimension
        self.bdim = bdim                            # bond feature dimension
        self.max_nb = max_nb                        # max neighbors per atom

    def build(self):
        """Construct the TF1 inference graph and restore trained weights.

        NOTE: placeholder creation order and variable scopes must match the
        training code exactly, otherwise checkpoint restore fails.
        """
        # Unpack for convenience
        batch_size = self.batch_size
        hidden_size = self.hidden_size
        adim = self.adim
        bdim = self.bdim
        max_nb = self.max_nb
        depth = self.depth

        self.session = tf.Session()

        # Graph-structured inputs; None dims allow variable molecule sizes
        input_atom = tf.placeholder(tf.float32, [batch_size, None, adim])
        input_bond = tf.placeholder(tf.float32, [batch_size, None, bdim])
        atom_graph = tf.placeholder(tf.int32, [batch_size, None, max_nb, 2])
        bond_graph = tf.placeholder(tf.int32, [batch_size, None, max_nb, 2])
        num_nbs = tf.placeholder(tf.int32, [batch_size, None])
        node_mask = tf.placeholder(tf.float32, [batch_size, None])
        self._src_holder = [input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask]
        # Pairwise (atom x atom) binary features, e.g. bond type / same-fragment flags
        self._binary = tf.placeholder(tf.float32, [batch_size, None, None, binary_fdim])

        node_mask_exp = tf.expand_dims(node_mask, -1)
        graph_inputs = (input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask_exp)

        #WLN-NN model
        with tf.variable_scope("encoder"):
            atom_hiddens, _ = rcnn_wl_last(graph_inputs, batch_size=batch_size,
                                           hidden_size=hidden_size, depth=depth)

        # For each pair of atoms, compute an attention score
        atom_hiddens1 = tf.reshape(atom_hiddens, [batch_size, 1, -1, hidden_size])
        atom_hiddens2 = tf.reshape(atom_hiddens, [batch_size, -1, 1, hidden_size])
        atom_pair = atom_hiddens1 + atom_hiddens2
        att_hidden = tf.nn.relu(linearND(atom_pair, hidden_size, scope="att_atom_feature", init_bias=None) +
                                linearND(self._binary, hidden_size, scope="att_bin_feature"))
        att_score = linearND(att_hidden, 1, scope="att_scores")
        att_score = tf.nn.sigmoid(att_score)

        # Use the attention scores to compute the "att_context" global features
        att_context = att_score * atom_hiddens1
        att_context = tf.reduce_sum(att_context, 2)

        # Compute selectivity toward each atom based on the local and global representations
        atom_logits_alltasks = tf.squeeze(linearND(atom_hiddens, self.num_tasks, scope="local_score") +
                                          linearND(att_context, self.num_tasks, scope="global_score"))

        # Per-atom, per-task likelihoods in [0, 1]
        self.atom_likelihoods_smiles = tf.sigmoid(atom_logits_alltasks)

        self.saver = tf.train.Saver()
        self.saver.restore(self.session, self.save_path)

    def web_predictor(self, smiles):
        '''
        Predictor that will give atom scores for a smiles list

        Args:
            smiles (list of str): SMILES strings to score; only the first
                entry is reported in the results (batch_size is 1).

        Returns:
            list of dict: one dict per task with keys ``smiles``, ``task``,
            and ``atom_scores`` (a tuple of floats, one per atom).
        '''
        cur_batch_size = 1
        src_tuple = self.smiles2graph_batch(smiles)
        cur_bin = binary_features_batch(smiles)
        feed_map = {x:y for x,y in zip(self._src_holder, src_tuple)}
        feed_map.update({self._binary:cur_bin})

        alikelihoods = self.session.run(self.atom_likelihoods_smiles, feed_dict=feed_map)

        # Transpose (atom x task) scores into one result entry per task
        results = []
        for i,j in enumerate(list(zip(*alikelihoods))):
            results.append({
                'smiles':smiles[0],
                'task': self.task_dict_rev[i],
                'atom_scores': tuple([float(x) for x in j])})

        return results
class SiteSelectivity(unittest.TestCase):
    """Regression test for the multitask site-selectivity predictor."""

    def setUp(self):
        """This method is run once before each test in this class."""
        # Leftover graphs/sessions from other models can cause unexpected
        # errors here, so start every test from a clean TF state.
        tf.keras.backend.clear_session()

    def test_01_predict(self):
        """Test that the Site_Predictor works as expected."""
        toluene = 'Cc1ccccc1'
        results = Site_Predictor().predict(toluene)
        # One result entry per task, each a dict with per-atom scores
        self.assertEqual(len(results), 123)
        self.assertEqual(type(results[0]), dict)
        # Toluene has 7 heavy atoms -> 7 atom scores
        self.assertEqual(len(results[0].get('atom_scores')), 7)
import os, sys, json
import urllib.parse  # explicit submodule import: `import urllib` alone does not guarantee urllib.parse exists
import rdkit.Chem as Chem
from askcos.utilities.io.name_parser import name_to_molecule, urlopen

## Syntax of banned list:
# json-dumped dictionary of "name": [isomeric_smiles, flat_smiles]
# where smiles could be None

# Where is the banned chemical list stored? (next to this script)
banned_names_fpath = os.path.join(os.path.dirname(__file__), 'banned_names.txt')
banned_list_fpath = os.path.join(os.path.dirname(__file__), 'banned_list.json')

# Default: empty mapping of name -> (isomeric_smiles, flat_smiles)
banned = {}

# Open the existing list if present; a missing or corrupt file is not fatal.
# Narrowed from a bare `except:` so Ctrl-C / SystemExit are not swallowed.
try:
    with open(banned_list_fpath, 'r') as fid:
        banned = json.load(fid)
except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
    print('Warning: did not find existing banned_list.json file')

# Check that all names from the plain-text list are in the banned_list
try:
    with open(banned_names_fpath, 'r') as fid:
        names = [line.split('#')[0].strip() for line in fid.readlines()]
    for name in names:
        if not name:
            continue
        if name not in banned:
            banned[name] = (None, None)
except OSError:  # narrowed from a bare `except:`
    print('Warning: did not find banned_names.txt file')

# Try to fill in SMILES that are missing automatically
for (name, smis) in banned.items():

    if smis[0] is None:  # use NIH resolver
        escaped_name = urllib.parse.quote(name)
        try:
            mol = name_to_molecule(escaped_name)
            if mol is None:
                continue

            banned[name] = (
                Chem.MolToSmiles(mol, True),
                Chem.MolToSmiles(mol, False),
            )
            print('Parsed {} --> {}'.format(name, banned[name]))
        except Exception:  # best-effort: resolver/network failures just skip this name
            pass

    else:  # try to canonicalize the stored SMILES
        mol = Chem.MolFromSmiles(smis[0])
        if mol is None:
            # Previously crashed here on an unparseable stored SMILES;
            # keep the old entry instead of raising.
            continue
        banned[name] = (
            Chem.MolToSmiles(mol, True),
            Chem.MolToSmiles(mol, False),
        )


# Resave list
with open(banned_list_fpath, 'w') as fid:
    json.dump(banned, fid, indent=4, sort_keys=True)
class Pricer:
    """
    Looks up the price-per-gram (ppg) of chemicals that are buyable.

    Prices come either from a MongoDB collection (``use_db=True``) or from
    a local gzipped JSON file loaded into a pandas DataFrame.
    """

    def __init__(self, use_db=True, BUYABLES_DB=None):
        # BUYABLES_DB: optional pre-existing MongoDB collection handle
        # use_db: prefer the database backend over the flat file
        self.BUYABLES_DB = BUYABLES_DB
        self.use_db = use_db
        self.prices = None  # DataFrame backend, populated by load_from_file

    def load(self, file_name=gc.BUYABLES['file_name']):
        """
        Load pricer information. Either create connection to MongoDB or load from local file.
        """
        if not self.use_db:
            self.load_from_file(file_name)
        else:
            self.load_databases(file_name)

    def load_databases(self, file_name=gc.BUYABLES['file_name']):
        """
        Load the pricing data from the online database.

        If connection to MongoDB cannot be made, fallback and try to load from local file.
        """
        client = MongoClient(serverSelectionTimeoutMS=1000, **gc.MONGO)

        try:
            client.server_info()
        except errors.ServerSelectionTimeoutError:
            MyLogger.print_and_log('Cannot connect to mongodb to load prices', pricer_loc)
            # Fall back to the flat-file backend
            self.use_db = False
            self.load(file_name=file_name)
            return

        database = client[gc.BUYABLES['database']]
        self.BUYABLES_DB = database[gc.BUYABLES['collection']]

    def dump_to_file(self, file_path):
        """
        Write prices to a local file
        """
        self.prices.to_json(file_path, orient='records', compression='gzip')

    def load_from_file(self, file_name):
        """
        Load buyables information from local file
        """
        if not os.path.isfile(file_name):
            MyLogger.print_and_log('Buyables file does not exist: {}'.format(file_name), pricer_loc)
            return
        self.prices = pd.read_json(
            file_name,
            orient='records',
            dtype={'smiles': 'object', 'source': 'object', 'ppg': 'float'},
            compression='gzip',
        )
        MyLogger.print_and_log('Loaded prices from flat file', pricer_loc)

    def lookup_smiles(self, smiles, source=None, alreadyCanonical=False, isomericSmiles=True):
        """
        Looks up a price by SMILES. Canonicalize smiles string unless
        the user specifies that the smiles string is definitely already
        canonical. If the DB connection does not exist, look up from
        prices dictionary attribute, otherwise lookup from DB.
        If multiple entries exist in the DB, return the lowest price.

        Args:
            smiles (str): SMILES string to look up
            source (list or str, optional): buyables sources to consider;
                if ``None`` (default), include all sources, otherwise
                must be single source or list of sources to consider;
            alreadyCanonical (bool, optional): whether SMILES string is already
                canonical; if ``False`` (default), SMILES will be canonicalized
            isomericSmiles (bool, optional): whether to generate isomeric
                SMILES string when performing canonicalization
        """
        if not alreadyCanonical:
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                return 0.
            smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles)

        if source == []:
            # If no sources are allowed, there is no need to perform lookup.
            # The empty list is checked explicitly, since None means source
            # will not be included in the query, and '' is a valid source.
            return 0.0

        if self.use_db:
            query = {'smiles': smiles}
            if isinstance(source, list):
                query['source'] = {'$in': source}
            elif source is not None:
                query['source'] = source
            matches = self.BUYABLES_DB.find(query)
            return min((doc['ppg'] for doc in matches), default=0.0)

        if self.prices is not None:
            mask = self.prices['smiles'] == smiles
            if isinstance(source, list):
                mask = mask & (self.prices['source'].isin(source))
            elif source is not None:
                mask = mask & (self.prices['source'] == source)
            return min(self.prices.loc[mask]['ppg'], default=0.0)

        return 0.0
class SmilesFixer():
    '''
    This class stores RDKit reactions which help turn molecules with
    weird representations into ones with common ones (so they will match
    a template)
    '''
    def __init__(self):
        # Normalization reactions, applied repeatedly until convergence.
        self.rxns = [
            # Double bonds on aromatic rings (dist 1)
            AllChem.ReactionFromSmarts('[NH0:1]=[a:2][nH:3]>>[NH:1][a:2][nH0:3]'),
            AllChem.ReactionFromSmarts('[NH1:1]=[a:2][nH:3]>>[NH2:1][a:2][nH0:3]'),
            AllChem.ReactionFromSmarts('[OH0:1]=[a:2][nH:3]>>[OH:1][a:2][nH0:3]'),
            # Double bonds on aromatic rings (dist 2)
            AllChem.ReactionFromSmarts('[NH0:1]=[a:2][a:4][nH:3]>>[NH:1][a:2][a:4][nH0:3]'),
            AllChem.ReactionFromSmarts('[NH1:1]=[a:2][a:4][nH:3]>>[NH2:1][a:2][a:4][nH0:3]'),
            AllChem.ReactionFromSmarts('[OH0:1]=[a:2][a:4][nH:3]>>[OH:1][a:2][a:4][nH0:3]'),
            # Double bonds on aromatic rings (dist 2)
            AllChem.ReactionFromSmarts('[NH0:1]=[a:2][a:4][a:5][nH:3]>>[NH:1][a:2][a:4][a:5][nH0:3]'),
            AllChem.ReactionFromSmarts('[NH1:1]=[a:2][a:4][a:5][nH:3]>>[NH2:1][a:2][a:4][a:5][nH0:3]'),
            AllChem.ReactionFromSmarts('[OH0:1]=[a:2][a:4][a:5][nH:3]>>[OH:1][a:2][a:4][a:5][nH0:3]'),
            # Iminol / amide
            AllChem.ReactionFromSmarts('[NH0:1]=[C:2]-[OH:3]>>[NH1:1]-[C:2]=[OH0:3]'),
            AllChem.ReactionFromSmarts('[NH1:1]=[C:2]-[OH:3]>>[NH2:1]-[C:2]=[OH0:3]'),
            # Thiourea
            AllChem.ReactionFromSmarts('[NH0:1]=[C:2]-[SH:3]>>[NH1:1]-[C:2]=[SH0:3]'),
            AllChem.ReactionFromSmarts('[NH1:1]=[C:2]-[SH:3]>>[NH2:1]-[C:2]=[SH0:3]'),
            # Azide
            AllChem.ReactionFromSmarts('[NH0:1][NH0:2]=[NH0;-:3]>>[NH0;-:1]=[NH0;+:2]=[NH0;-:3]'),
            # Cyanide salts
            AllChem.ReactionFromSmarts('([K,Na;H1:1].[C;X1;H0:2]#[N:3])>>[*;H0:1][*:2]#[N:3]'),
            AllChem.ReactionFromSmarts('([Cu:1].[C;X1;H0:2]#[N:3])>>[*:1][*:2]#[N:3]'),
            # Grinards
            AllChem.ReactionFromSmarts('([MgH+:1].[C;v3:2][*:3])>>[Mg+:1][*:2][*:3]'),
            # Coordinated tin
            AllChem.ReactionFromSmarts('([SnH4:1].[C;v3:2].[C;v3:3].[C;v3:4].[C;v3:5])>>[Sn:1]([*:2])([*:3])([*:4])[*:5]'),
            AllChem.ReactionFromSmarts('([SnH3:1].[C;v3:2].[C;v3:3].[C;v3:4])>>[Sn:1]([*:2])([*:3])[*:4]'),
            AllChem.ReactionFromSmarts('([SnH2:1].[C;v3:2].[C;v3:3])>>[Sn:1]([*:2])[*:3]'),
            AllChem.ReactionFromSmarts('([SnH1:1].[C;v3:2])>>[Sn:1][*:2]'),
        ]

    def fix_smiles(self, old_smiles, removeMap=True):
        '''
        For a given SMILES string, this function "fixes" common mistakes
        found in the Lowe parsed database:
        - N=c[nH] structures are turned into the normal [NH]-c[n] forms
        - iminols are turned into amides/carbamates

        It applies the reactions in self.rxns until the SMILES string
        doesn't change.

        Args:
            old_smiles (str): SMILES string to normalize.
            removeMap (bool, optional): clear atom map numbers first.

        Returns:
            str: the normalized isomeric SMILES, or the input unchanged
            if it cannot be parsed.
        '''
        mol = Chem.MolFromSmiles(old_smiles)
        # BUG FIX: the None check must come before touching atoms; the old
        # code called mol.GetAtoms() first and crashed on unparseable input.
        if not mol:
            return old_smiles
        if removeMap:
            for atom in mol.GetAtoms():
                atom.ClearProp('molAtomMapNumber')

        new_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
        old_smiles = ''
        # Iterate to a fixed point: keep applying rules until no rule changes
        # the canonical SMILES anymore.
        while new_smiles != old_smiles:
            old_smiles = new_smiles
            for rxn in self.rxns:
                outcomes = rxn.RunReactants((mol,))
                if not outcomes:
                    continue
                else:
                    mol = outcomes[0][0]
                    Chem.SanitizeMol(mol)
                    new_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)

        return new_smiles
def group_results(original, outcomes, **kwargs):
    '''Cluster the similar transformed outcomes together

    Args:
        original (str): SMILES string of original target molecule
        outcomes (list of str): List containing SMILES strings of outcomes to be clustered
        feature (str, optional): Only use features disappearing from 'original', appearing in 'outcomes' or a combination of 'all'. (default: {'original'})
        cluster_method (str, optional): Method to use for clustering ['kmeans', 'hdbscan'] (default: {'kmeans'})
        fp_type (str, optional): Type of fingerprinting method to use. (default: {'morgan'})
        fp_length (int, optional): Fixed-length folding of fingerprint. (default: {512})
        fp_radius (int, optional): Radius to use for fingerprint. (default: {1})
        scores (list or float, optional): List of precursor outcome scores to number clusters i.e. - cluster 1 contains precursor outcome with best score. (default: {None})

    Returns:
        list of int: Cluster indices for outcomes, 0-based
    '''
    fp_type = kwargs.get('fp_type', 'morgan')
    fp_length = kwargs.get('fp_length', 512)
    fp_radius = kwargs.get('fp_radius', 1)
    if fp_type == 'morgan':
        fp_generator = lambda x: AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), fp_radius, nBits=fp_length)
    else:
        # BUG FIX: previously formatted with undefined name `fingerprint`,
        # which raised NameError instead of the intended error message.
        raise Exception('Fatal error: fingerprint type {} is not supported.'.format(fp_type))

    cluster_method = kwargs.get('cluster_method', 'kmeans')
    feature = kwargs.get('feature', 'original')
    scores = kwargs.get('scores')

    if not outcomes:
        return []

    # calc fingerprint
    original_fp = np.array(fp_generator(original))
    outcomes_fp = np.array([fp_generator(i) for i in outcomes])

    # Signed difference: positive bits disappeared from the target,
    # negative bits appeared in the outcome.
    diff_fp = original_fp - outcomes_fp

    if 'original' == feature:
        diff_fp = diff_fp.clip(min=0)
    elif 'outcomes' == feature:
        diff_fp = diff_fp.clip(max=0)
    elif 'all' == feature:
        pass
    else:
        raise Exception('Fatal error: feature={} is not recognized.'.format(feature))

    # calc cluster indices
    res = []
    if 'hdbscan' == cluster_method:
        clusterer = hdbscan.HDBSCAN(min_cluster_size=5, gen_min_span_tree=False)
        clusterer.fit(diff_fp)
        res = clusterer.labels_
        # non-clustered inputs have id -1, make them appear as individual clusters
        max_cluster = np.amax(res)
        for i in range(len(res)):
            if res[i] == -1:
                max_cluster += 1
                res[i] = max_cluster
    elif 'kmeans' == cluster_method:
        # Grow k until the within-cluster inertia is negligible
        for cluster_size in range(len(diff_fp)):
            kmeans = cluster.KMeans(n_clusters=cluster_size+1).fit(diff_fp)
            if kmeans.inertia_ < 1:
                break
        res = kmeans.labels_
    else:
        raise Exception('Fatal error: cluster_method={} is not recognized.'.format(cluster_method))

    res = [int(i) for i in res]

    if scores is not None:
        if len(scores) != len(res):
            raise Exception('Fatal error: length of score ({}) and smiles ({}) are different.'.format(len(scores), len(res)))
        # Renumber clusters so cluster 0 holds the best-scoring outcome
        best_cluster_score = {}
        for cluster_id, score in zip(res, scores):
            best_cluster_score[cluster_id] = max(
                best_cluster_score.get(cluster_id, -float('inf')),
                score
            )
        print('best_cluster_score: ', best_cluster_score)
        new_order = list(sorted(best_cluster_score.items(), key=lambda x: -x[1]))
        order_mapping = {new_order[n][0]: n for n in range(len(new_order))}
        print('order_mapping: ', order_mapping)
        res = [order_mapping[n] for n in res]

    return res
doc['RXD_CATXRN']: reagents[xrn] += 1 42 | for xrn in doc['RXD_RGTXRN']: catalysts[xrn] += 1 43 | if doc['RXD_P'] != -1: 44 | P = string_or_range_to_float(doc['RXD_P']) 45 | if P: pressures.append(P) 46 | if doc['RXD_T'] != -1: 47 | T = string_or_range_to_float(doc['RXD_T']) 48 | if T: temps.append(T) 49 | if doc['RXD_TIM'] != -1: 50 | t = string_or_range_to_float(doc['RXD_TIM']) 51 | if t: times.append(t) 52 | if doc['RXD_NYD'] != -1: yields.append(float(doc['RXD_NYD'])) 53 | conditions = [] 54 | 55 | # Solvents 56 | solvent_string = '' 57 | for (solxrn, count) in sorted(solvents.items(), key = lambda x: x[1], reverse = True)[:5]: 58 | doc = CHEMICAL_DB.find_one({'_id': solxrn}, ['SMILES', 'IDE_CN']) 59 | cn = doc['IDE_CN'] if doc else solxrn 60 | solvent_string += '{} ({:.0f}%); '.format(cn, count*100.0/N_id) 61 | conditions.append('SOLVENT: ' + solvent_string) 62 | 63 | # Reagents 64 | reagent_string = '' 65 | for (rgtxrn, count) in sorted(reagents.items(), key = lambda x: x[1], reverse = True)[:5]: 66 | doc = CHEMICAL_DB.find_one({'_id': rgtxrn}, ['SMILES', 'IDE_CN']) 67 | cn = doc['IDE_CN'] if doc else rgtxrn 68 | reagent_string += '{} ({:.0f}%); '.format(cn, count*100.0/N_id) 69 | conditions.append('REAGENT: ' + reagent_string) 70 | 71 | # Catalysts 72 | catalyst_string = '' 73 | for (catxrn, count) in sorted(catalysts.items(), key = lambda x: x[1], reverse = True)[:5]: 74 | doc = CHEMICAL_DB.find_one({'_id': catxrn}, ['SMILES', 'IDE_CN']) 75 | cn = doc['IDE_CN'] if doc else catxrn 76 | catalyst_string += '{} ({:.0f}%); '.format(cn, count*100.0/N_id) 77 | conditions.append('CATALYST: ' + catalyst_string) 78 | 79 | # Time 80 | if times: 81 | conditions.append( 82 | 'TIME: {:.1f} +/- {:.1f} hours (min {:.1f}, max {:.1f}, N={})'.format( 83 | np.mean(times), np.std(times), min(times), max(times), len(times) 84 | ) 85 | ) 86 | else: 87 | conditions.append('TIME UNKNOWN') 88 | 89 | # Temp 90 | if temps: 91 | conditions.append( 92 | 'TEMP: {:.1f} +/- {:.1f} Celsius 
def chem_dict(_id, children=None, **kwargs):
    """Return a chemical dictionary in the format required by the website.

    Chemical object as expected by the website. Removes ``rct_of``,
    ``prod_of``, and ``depth`` information from the ``tree_dict`` entry.

    Args:
        _id (int): Chemical ID.
        children (list, optional): Children of the node. Defaults to a fresh
            empty list per call.
        **kwargs: The ``tree_dict`` to be modified.

    Returns:
        dict: The modified ``tree_dict``.
    """
    kwargs.pop('rct_of', None)
    kwargs.pop('prod_of', None)
    kwargs.pop('depth', None)
    kwargs['id'] = _id
    kwargs['is_chemical'] = True
    # Fix: the original used a mutable default argument (children=[]), so
    # every caller that mutated the returned dict's 'children' list shared
    # and corrupted one list object across all subsequent calls.
    kwargs['children'] = [] if children is None else children
    return kwargs


def rxn_dict(_id, smiles, children=None, **kwargs):
    """Return a reaction dictionary in the format required by the website.

    Reaction object as expected by the website. Removes ``rcts``, ``prod``,
    and ``depth`` information from the ``tree_dict`` entry.

    Args:
        _id (int): Reaction ID.
        smiles (str): SMILES string of reaction.
        children (list, optional): Children of the node. Defaults to a fresh
            empty list per call.
        **kwargs: The ``tree_dict`` to be modified.

    Returns:
        dict: The modified ``tree_dict``.
    """
    kwargs.pop('rcts', None)
    kwargs.pop('prod', None)
    kwargs.pop('depth', None)
    kwargs['id'] = _id
    kwargs['is_reaction'] = True
    # Same mutable-default fix as chem_dict above.
    kwargs['children'] = [] if children is None else children
    kwargs['smiles'] = smiles
    return kwargs
@unittest.skipIf(not db_available(), 'Skipping because mongo db is not available.')
class TestDBChemHistorian(unittest.TestCase):
    """Tests for ChemHistorian backed by a live mongo database."""

    @classmethod
    def setUpClass(cls):
        """Create a DB-backed ChemHistorian and the document used by tests."""
        cls.chemhistorian = ChemHistorian(use_db=True)
        cls.chemhistorian.load()
        cls.new_doc = {
            '_id': 'test_id',
            'smiles': 'CCCCO',
            'as_product': 1,
            'as_reactant': 1,
            'template_set': 'test_template_set',
        }

    @classmethod
    def tearDownClass(cls):
        """Remove the document inserted during the tests."""
        cls.chemhistorian.CHEMICALS_DB.delete_one(cls.new_doc)

    def test_01_lookup_smiles(self):
        """Test that we can look up a SMILES string in chemhistorian."""
        looked_up = self.chemhistorian.lookup_smiles('CCCCO')
        looked_up.pop('_id')
        self.assertEqual({'as_product': 2726, 'as_reactant': 17450}, looked_up)

    def test_02_use_db_template_set(self):
        """Test that when using the mongoDB, we can lookup in different template sets."""
        self.chemhistorian.CHEMICALS_DB.insert_one(self.new_doc)
        from_reaxys = self.chemhistorian.lookup_smiles('CCCCO', template_set='reaxys')
        from_reaxys.pop('_id')
        self.assertEqual({'as_product': 2726, 'as_reactant': 17450}, from_reaxys)
        from_new_set = self.chemhistorian.lookup_smiles('CCCCO', template_set='test_template_set')
        expected_new = {
            'as_product': self.new_doc['as_product'],
            'as_reactant': self.new_doc['as_reactant'],
            '_id': self.new_doc['_id'],
        }
        self.assertEqual(expected_new, from_new_set)


if __name__ == '__main__':
    res = unittest.main(verbosity=3, exit=False)
https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/utilities/historian/test_data/chemicals.json.gz -------------------------------------------------------------------------------- /askcos/utilities/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASKCOS/askcos-core/c1ebf21b7f9c848c6e9488b16aaea504e005a1ca/askcos/utilities/io/__init__.py -------------------------------------------------------------------------------- /askcos/utilities/io/draw_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from PIL import Image, ImageChops 5 | 6 | import askcos.utilities.io.draw as draw 7 | 8 | 9 | class TestDraw(unittest.TestCase): 10 | """This class contains unit tests for the :mod:`askcos.utilities.io.draw` module.""" 11 | 12 | def test_01_ReactionStringToImage(self): 13 | """Test drawing of a reaction string.""" 14 | rxn_string = 'Fc1ccc(C2(Cn3cncn3)CO2)c(F)c1.c1nc[nH]n1.Cl.O=C([O-])O.[Na+]>>OC(Cn1cncn1)(Cn1cncn1)c1ccc(F)cc1F' 15 | result = draw.ReactionStringToImage(rxn_string, strip=True) 16 | expected = Image.open(os.path.join(os.path.dirname(__file__), 'test_data/draw_test_rxn_string.png')) 17 | self.assertIsNone(ImageChops.difference(result, expected).getbbox()) 18 | 19 | def test_02_ReactionStringToImageRetro(self): 20 | """Test drawing of a reaction string for retrosynthesis.""" 21 | rxn_string = 'Fc1ccc(C2(Cn3cncn3)CO2)c(F)c1.c1nc[nH]n1.Cl.O=C([O-])O.[Na+]>>OC(Cn1cncn1)(Cn1cncn1)c1ccc(F)cc1F' 22 | result = draw.ReactionStringToImage(rxn_string, strip=True, retro=True) 23 | expected = Image.open(os.path.join(os.path.dirname(__file__), 'test_data/draw_retro_test_rxn_string.png')) 24 | self.assertIsNone(ImageChops.difference(result, expected).getbbox()) 25 | 26 | def test_03_TransformStringToImage(self): 27 | """Test drawing of a reaction transform.""" 28 | 
import os
import askcos.global_config as gc


def make_directory(dir_name):
    """Create ``dir_name`` under the current working directory if needed.

    Args:
        dir_name (str): Directory name relative to the current working
            directory.

    Returns:
        str: Full path of the (possibly pre-existing) directory.
    """
    path = os.path.join(os.getcwd(), dir_name)
    # Fix: the original isdir()-then-mkdir() pair was racy (TOCTOU) and
    # raised FileExistsError if another process created the directory in
    # between; exist_ok makes creation idempotent.
    os.makedirs(path, exist_ok=True)
    return path

################################################################################
# Where are local files stored?
################################################################################

def get_retrotransformer_achiral_path(dbname, collname, mincount_retro):
    """Return the path of the pickled achiral retro-transformer dump."""
    return os.path.join(gc.local_db_dumps,
                        'retrotransformer_achiral_using_%s-%s_mincount%i.pkl' % (dbname, collname, mincount_retro))

def get_retrotransformer_chiral_path(dbname, collname, mincount_retro, mincount_retro_chiral):
    """Return the path of the pickled chiral retro-transformer dump.

    NOTE: all arguments are currently ignored; the path comes entirely from
    the global config (the parameterized filename was retired).
    """
    return os.path.join(gc.local_db_dumps, gc.RETRO_TRANSFORMS_CHIRAL['file_name'])

def get_synthtransformer_path(dbname, collname, mincount):
    """Return the path of the pickled forward-synthesis transformer dump."""
    return os.path.join(gc.local_db_dumps,
                        'synthtransformer_using_%s-%s_mincount%i.pkl' % (dbname, collname, mincount))

def get_pricer_path(chem_dbname, chem_collname, buyable_dbname, buyable_collname):
    """Return the path of the pickled buyables/pricer dump.

    NOTE: all arguments are currently ignored; the path comes entirely from
    the global config (the parameterized filename was retired).
    """
    return os.path.join(gc.local_db_dumps, gc.BUYABLES['file_name'])

def get_abraham_solvents_path():
    """Return the path of the pickled Abraham solvent parameter table."""
    return os.path.join(gc.local_db_dumps, 'abraham_solvents.pkl')
23 | """ 24 | log_file = select_log_path() 25 | levels = { 26 | 0: 'INFO', 27 | 1: 'WARN', 28 | 2: 'ERROR', 29 | 3: 'FATAL' 30 | } 31 | time_zero = 0 32 | 33 | @staticmethod 34 | def initialize_logFile(root='', name=''): 35 | """Clear previous log file and set initialization time.""" 36 | if name: 37 | MyLogger.log_file = select_log_path(root, name) 38 | if os.path.isfile(MyLogger.log_file): 39 | os.remove(MyLogger.log_file) 40 | MyLogger.time_zero = time.time() 41 | 42 | @staticmethod 43 | def print_and_log(text, location, level=0): 44 | """Print message to stdout and write to log file.""" 45 | time_elapsed = time.time() - MyLogger.time_zero 46 | 47 | tag = '{}@{}'.format(MyLogger.levels[level], location)[:25] 48 | outstr = '{:25s}: [{:04.3f}s]\t{}'.format(tag, time_elapsed, text) 49 | 50 | print(outstr) 51 | 52 | with open(MyLogger.log_file, 'a') as f: 53 | f.write(outstr) 54 | f.write('\n') 55 | 56 | if level == 3: 57 | quit() 58 | -------------------------------------------------------------------------------- /askcos/utilities/io/model_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import askcos.global_config as gc 3 | from pymongo import MongoClient 4 | from askcos.utilities.io.logger import MyLogger 5 | from askcos.utilities.buyable.pricer import Pricer 6 | from askcos.synthetic.context.nearestneighbor import NNContextRecommender 7 | from askcos.synthetic.context.neuralnetwork import NeuralNetContextRecommender 8 | from askcos.synthetic.enumeration.transformer import ForwardTransformer 9 | from askcos.retrosynthetic.transformer import RetroTransformer 10 | # from askcos.synthetic.evaluation.template_based import TemplateNeuralNetScorer 11 | from askcos.synthetic.evaluation.template_free import TemplateFreeNeuralNetScorer 12 | from askcos.synthetic.evaluation.fast_filter import FastFilterScorer 13 | import sys 14 | model_loader_loc = 'model_loader' 15 | 16 | 17 | def load_Retro_Transformer(): 18 | 
def load_Pricer(chemical_database, buyable_database):
    """Load a pricer from the chemicals and buyable-chemicals databases.

    Args:
        chemical_database: Chemicals database identifier, passed straight to
            ``Pricer.load`` (presumably a collection name — confirm there).
        buyable_database: Buyables database identifier, likewise passed
            through to ``Pricer.load``.

    Returns:
        Pricer: Loaded pricing model, ready to query.
    """
    MyLogger.print_and_log('Loading pricing model...', model_loader_loc)
    pricerModel = Pricer()
    pricerModel.load(chemical_database, buyable_database)
    MyLogger.print_and_log('Pricer Loaded.', model_loader_loc)
    return pricerModel


def load_Forward_Transformer(mincount=100, worker_no = 0):
    """Load the forward prediction neural network.

    Args:
        mincount (int): Unused in this function; kept for interface
            compatibility with callers.
        worker_no (int): Worker index; log messages are only emitted for
            worker 0 to avoid duplicate output from parallel workers.

    Returns:
        ForwardTransformer: Loaded forward transformer.
    """
    if worker_no==0:
        MyLogger.print_and_log('Loading forward prediction model...', model_loader_loc)
    transformer = ForwardTransformer()
    transformer.load(worker_no = worker_no)
    if worker_no==0:
        MyLogger.print_and_log('Forward transformer loaded.', model_loader_loc)
    return transformer


def load_fastfilter():
    """Load the fast filter scorer from the globally configured model path.

    Returns:
        FastFilterScorer: Loaded fast filter model.
    """
    ff = FastFilterScorer()
    ff.load(model_path =gc.FAST_FILTER_MODEL['model_path'])
    return ff


def load_templatebased(mincount=25, celery=False, worker_no = 0):
    """Load the template-based forward scorer.

    NOTE: currently disabled — the implementation is commented out below and
    this function always returns None. Callers must handle a None scorer.
    """
    # transformer = None
    # if not celery:
    #     transformer = load_Forward_Transformer(mincount=mincount, worker_no = worker_no)
    # scorer = TemplateNeuralNetScorer(forward_transformer=transformer, celery=celery)
    # scorer.load(gc.PREDICTOR['trained_model_path'], worker_no = worker_no)
    # return scorer
    return None


def load_templatefree():
    """Load the template-free forward scorer.

    Returns:
        TemplateFreeNeuralNetScorer: Scorer instance (constructed directly;
        no weights are loaded here).
    """
    # Still has to be implemented
    return TemplateFreeNeuralNetScorer()
def name_to_molecule(name):
    """Resolve an identifier (SMILES string or chemical name) to an RDKit Mol.

    First attempts to parse ``name`` directly as SMILES; if that fails, the
    NIH CACTUS structure resolver is queried over the network and the
    returned SMILES is parsed.

    Args:
        name (str): SMILES string or chemical name.

    Returns:
        Chem.Mol: Parsed molecule.

    Raises:
        ValueError: If the resolver's SMILES cannot be parsed by RDKit.
    """
    try:
        mol = Chem.MolFromSmiles(name)
        if mol:
            return mol
    except Exception:
        # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit. Any parse failure falls through to the resolver.
        pass

    smiles = urlopen(
        'https://cactus.nci.nih.gov/chemical/structure/{}/smiles'.format(name)).read()
    if isinstance(smiles, bytes):
        # Fix: on Python 3, urlopen().read() returns bytes, which
        # Chem.MolFromSmiles cannot parse; decode to str first.
        smiles = smiles.decode('utf-8')
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        raise ValueError(
            'Could not resolve SMILES ({}) in a way parseable by RDKit, from identifier: {}'.format(smiles, name))

    return mol
def convert_pickled_bytes_2_to_3(data):
    """Recursively decode ``bytes`` leaves of a Python-2 unpickled object.

    Walks dicts, tuples, and lists; every ``bytes`` value (including dict
    keys) is decoded to ``str``. Anything else is returned unchanged.
    """
    if isinstance(data, bytes):
        return data.decode()
    if isinstance(data, dict):
        return {convert_pickled_bytes_2_to_3(key): convert_pickled_bytes_2_to_3(value)
                for key, value in data.items()}
    if isinstance(data, tuple):
        return tuple(convert_pickled_bytes_2_to_3(item) for item in data)
    if isinstance(data, list):
        return [convert_pickled_bytes_2_to_3(item) for item in data]
    return data


def load(file):
    """Unpickle from *file*, transparently handling Python-2 pickles."""
    if sys.version_info[0] < 3:
        return pickle.load(file)
    # On Python 3, read Python-2 str objects as bytes, then decode them.
    raw = pickle.load(file, encoding='bytes')
    return convert_pickled_bytes_2_to_3(raw)


def dump(data, file, *args, **kwargs):
    """Pickle *data* to *file*.

    Always uses protocol 2 so Python 2 can still read the output.
    """
    pickle.dump(data, file, 2)
def clean_reactant_mapping(reactants):
    """Remaps atoms for reactants.

    Clears any existing atom map numbers, then assigns fresh 1-based map
    numbers in atom-index order. The Mol is modified in place and returned.

    Args:
        reactants (Chem.Mol): Reactants to remap.

    Returns:
        Chem.Mol: Reactants with remapped atoms.

    Raises:
        ValueError: If ``reactants`` is falsy (e.g. RDKit failed to parse).
    """
    if not reactants:
        MyLogger.print_and_log('Could not parse reactants {}'.format(reactants),reactants_loc)
        raise ValueError('Could not parse reactants')
    if gc.DEBUG:
        print('Number of reactant atoms: {}'.format(len(reactants.GetAtoms())))
    # Fix (idiom): the original used list comprehensions purely for their
    # side effects, allocating throwaway lists; plain loops are clearer.
    for atom in reactants.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')
    if gc.DEBUG:
        print('Reactants w/o map: {}'.format(Chem.MolToSmiles(reactants)))
    # Add new atom map numbers, 1-based, in atom-index order.
    for index, atom in enumerate(reactants.GetAtoms()):
        atom.SetProp('molAtomMapNumber', str(index + 1))
    if gc.DEBUG:
        print('Reactants w/ map: {}'.format(Chem.MolToSmiles(reactants)))
    return reactants
def string_or_range_to_float(text):
    """Translate a number or range in string format to a float.

    If the string represents a range (e.g. ``'20 - 30'``, ``'-20 - 0'``,
    ``'-20 - -10'``), return the midpoint of that range.

    Args:
        text (str): Number or range string to be converted to a float.

    Returns:
        float or None: Number converted / midpoint of range, or None if it
        could not be converted.
    """

    try:
        return float(text)
    except Exception as e:
        x = text.split('-')
        if text.count('-') == 1:  # 20 - 30
            try:
                return (float(x[0]) + float(x[1])) / 2.0
            except Exception as e:
                print(e)
        elif text.count('-') == 2:  # -20 - 0
            try:
                # Fix: x[0] is the empty string before the leading '-', so
                # the endpoints are -x[1] and x[2]. The original indexed
                # x[0]/x[1], so float('') always raised and None came back.
                return (-float(x[1]) + float(x[2])) / 2.0
            except Exception as e:
                print(e)
        elif text.count('-') == 3:  # -20 - -10
            try:
                # Fix: endpoints are -x[1] and -x[3]; x[0] and x[2] are the
                # empty strings produced by the two leading '-' signs.
                return (-float(x[1]) - float(x[3])) / 2.0
            except Exception as e:
                print(e)
        else:
            print(e)
    return None
20 | """ 21 | def g(*a, **kw): 22 | return threadsafe_iter(f(*a, **kw)) 23 | return g -------------------------------------------------------------------------------- /askcos/utilities/with_dummy.py: -------------------------------------------------------------------------------- 1 | class with_dummy(object): 2 | def __init__(self, *args, **kwargs): 3 | pass 4 | def __enter__(self, *args, **kwargs): 5 | pass 6 | def __exit__(self, *args, **kwargs): 7 | pass -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. ASKCOS documentation master file, created by 2 | sphinx-quickstart on Tue Jun 18 16:10:12 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to ASKCOS's documentation! 7 | ================================== 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 4 13 | 14 | modules 15 | 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | ASKCOS 2 | ====== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | askcos 8 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: askcos 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.7.8 7 | - celery=4.4 8 | - cython=0.29.21 9 | - h5py=2.10.0 10 | - hdbscan=0.8.23 11 | - jupyter 12 | - keras=2.2.4 13 | - markdown 14 | - matplotlib=3.3.0 15 | - networkx=2.4 16 | - numpy=1.19.1 17 | - pandas=1.1.0 18 | - pango=1.42.4 19 | - pillow=7.2.0 20 | - pip 21 | - pycairo=1.19.1 22 | - pymongo=3.10.1 23 | - pytorch>=1.4 24 | - pyyaml=5.3.1 25 | - rdkit=2020.03.6 26 | - requests 27 | - scipy=1.5.2 28 | - scikit-learn=0.23.1 29 | - sphinx 30 | - tensorflow=2.0.0 31 | - tensorflow-estimator=2.0.0 32 | - tqdm 33 | - cairosvg=2.4.2 34 | - pip: 35 | - rdchiral 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | celery==4.4 2 | Cython==0.29.21 3 | h5py==2.10.0 4 | hdbscan==0.8.23 5 | keras==2.2.4 6 | Markdown==3.2.2 7 | matplotlib==3.3.0 8 | networkx==2.4 9 | numpy==1.19.0 10 | pandas==1.0.5 11 | pillow==7.2.0 12 | pycairo==1.19.1 13 | pymongo==3.10.1 14 | torch>=1.4.0 15 | PyYAML==5.3.1 16 | rdchiral==1.0.0 17 | scipy==1.5.1 18 | scikit-learn==0.23.1 19 | sphinx==3.1.2 20 | tensorflow==2.0.0 21 | tqdm==4.48.0 22 | CairoSVG==2.4.2 23 | --------------------------------------------------------------------------------